Merge branch 'master' into crl-support

# Conflicts:
#	api/api.go
#	authority/config/config.go
#	cas/softcas/softcas.go
#	db/db.go
pull/731/head
Raal Goff 2 years ago
commit 60671b07d7

@ -0,0 +1,56 @@
name: Bug Report
description: File a bug report
title: "[Bug]: "
labels: ["bug", "needs triage"]
body:
- type: markdown
attributes:
value: |
Thanks for taking the time to fill out this bug report!
- type: textarea
id: steps
attributes:
label: Steps to Reproduce
description: Tell us how to reproduce this issue.
placeholder: These are the steps!
validations:
required: true
- type: textarea
id: your-env
attributes:
label: Your Environment
value: |-
* OS -
* `step-ca` Version -
validations:
required: true
- type: textarea
id: expected-behavior
attributes:
label: Expected Behavior
description: What did you expect to happen?
validations:
required: true
- type: textarea
id: actual-behavior
attributes:
label: Actual Behavior
description: What happens instead?
validations:
required: true
- type: textarea
id: context
attributes:
label: Additional Context
description: Add any other context about the problem here.
validations:
required: false
- type: textarea
id: contributing
attributes:
label: Contributing
value: |
Vote on this issue by adding a 👍 reaction.
To contribute a fix for this issue, leave a comment (and link to your pull request, if you've opened one already).
validations:
required: false

@ -1,27 +0,0 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: bug, needs triage
assignees: ''
---
### Subject of the issue
Describe your issue here.
### Your environment
* OS -
* Version -
### Steps to reproduce
Tell us how to reproduce this issue. Please provide a working demo, you can use [this template](https://plnkr.co/edit/XorWgI?p=preview) as a base.
### Expected behaviour
Tell us what should happen
### Actual behaviour
Tell us what happens instead
### Additional context
Add any other context about the problem here.

@ -1,12 +1,20 @@
--- ---
name: Documentation Request name: Documentation Request
about: Request documentation for a feature about: Request documentation for a feature
title: '' title: '[Docs]:'
labels: documentation, needs triage labels: docs, needs triage
assignees: '' assignees: ''
--- ---
## Hello!
<!-- Please leave this section as-is, it's designed to help others in the community know how to interact with our GitHub issues. -->
- Vote on this issue by adding a 👍 reaction
- If you want to document this feature, comment to let us know (we'll work with you on design, scheduling, etc.)
## Affected area/feature
<!--- <!---
Tell us which feature you'd like to see documented. Tell us which feature you'd like to see documented.
- Where would you like that documentation to live (command line usage output, website, github markdown on the repo)? - Where would you like that documentation to live (command line usage output, website, github markdown on the repo)?

@ -1,13 +1,24 @@
--- ---
name: Enhancement name: Enhancement
about: Suggest an enhancement to step certificates about: Suggest an enhancement to step-ca
title: '' title: ''
labels: enhancement, needs triage labels: enhancement, needs triage
assignees: '' assignees: ''
--- ---
### What would you like to be added ## Hello!
<!-- Please leave this section as-is,
it's designed to help others in the community know how to interact with our GitHub issues. -->
- Vote on this issue by adding a 👍 reaction
- If you want to implement this feature, comment to let us know (we'll work with you on design, scheduling, etc.)
### Why this is needed ## Issue details
<!-- Enhancement requests are most helpful when they describe the problem you're having
as well as articulating the potential solution you'd like to see built. -->
## Why is this needed?
<!-- Let us know why you think this enhancement would be good for the project or community. -->

@ -1,4 +1,20 @@
### Description <!---
Please describe your pull request. Please provide answers in the spaces below each prompt, where applicable.
Not every PR requires responses for each prompt.
Use your discretion.
-->
#### Name of feature:
#### Pain or issue this feature alleviates:
#### Why is this important to the project (if not answered above):
#### Is there documentation on how to use this feature? If so, where?
#### In what environments or workflows is this feature supported?
#### In what environments or workflows is this feature explicitly NOT supported (if any)?
#### Supporting links/other PRs/issues:
💔Thank you! 💔Thank you!

@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
with: with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'v1.45.0' version: 'v1.45.2'
# Optional: working directory, useful for monorepos # Optional: working directory, useful for monorepos
# working-directory: somedir # working-directory: somedir
@ -139,7 +139,7 @@ jobs:
name: Run GoReleaser name: Run GoReleaser
uses: goreleaser/goreleaser-action@5a54d7e660bda43b405e8463261b3d25631ffe86 # v2.7.0 uses: goreleaser/goreleaser-action@5a54d7e660bda43b405e8463261b3d25631ffe86 # v2.7.0
with: with:
version: latest version: 'v1.7.0'
args: release --rm-dist args: release --rm-dist
env: env:
GITHUB_TOKEN: ${{ secrets.PAT }} GITHUB_TOKEN: ${{ secrets.PAT }}

@ -33,7 +33,7 @@ jobs:
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
with: with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
version: 'v1.45.0' version: 'v1.45.2'
# Optional: working directory, useful for monorepos # Optional: working directory, useful for monorepos
# working-directory: somedir # working-directory: somedir
@ -59,8 +59,9 @@ jobs:
- -
name: Codecov name: Codecov
if: matrix.go == '1.18' if: matrix.go == '1.18'
uses: codecov/codecov-action@v1.2.1 uses: codecov/codecov-action@v2
with: with:
file: ./coverage.out # optional token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.out # optional
name: codecov-umbrella # optional name: codecov-umbrella # optional
fail_ci_if_error: true # optional (default = false) fail_ci_if_error: true # optional (default = false)

@ -230,42 +230,3 @@ scoop:
# Your app's license # Your app's license
# Default is empty. # Default is empty.
license: "Apache-2.0" license: "Apache-2.0"
#dockers:
# - dockerfile: docker/Dockerfile
# goos: linux
# goarch: amd64
# use_buildx: true
# image_templates:
# - "smallstep/step-cli:latest"
# - "smallstep/step-cli:{{ .Tag }}"
# build_flag_templates:
# - "--platform=linux/amd64"
# - dockerfile: docker/Dockerfile
# goos: linux
# goarch: 386
# use_buildx: true
# image_templates:
# - "smallstep/step-cli:latest"
# - "smallstep/step-cli:{{ .Tag }}"
# build_flag_templates:
# - "--platform=linux/386"
# - dockerfile: docker/Dockerfile
# goos: linux
# goarch: arm
# goarm: 7
# use_buildx: true
# image_templates:
# - "smallstep/step-cli:latest"
# - "smallstep/step-cli:{{ .Tag }}"
# build_flag_templates:
# - "--platform=linux/arm/v7"
# - dockerfile: docker/Dockerfile
# goos: linux
# goarch: arm64
# use_buildx: true
# image_templates:
# - "smallstep/step-cli:latest"
# - "smallstep/step-cli:{{ .Tag }}"
# build_flag_templates:
# - "--platform=linux/arm64/v8"

@ -4,16 +4,70 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased - 0.18.3] - DATE ### TEMPLATE -- do not alter or remove
---
## [x.y.z] - aaaa-bb-cc
### Added ### Added
- Added support for renew after expiry using the claim `allowRenewAfterExpiry`. ### Changed
### Deprecated
### Removed
### Fixed
### Security
---
## [Unreleased]
### Changed
- Certificates signed by an issuer using an RSA key will be signed using the same algorithm as the issuer certificate was signed with. The signature will no longer default to PKCS #1. For example, if the issuer certificate was signed using RSA-PSS with SHA-256, a new certificate will also be signed using RSA-PSS with SHA-256.
## [0.20.0] - 2022-05-26
### Added
- Added Kubernetes auth method for Vault RAs.
- Added support for reporting provisioners to linkedca.
- Added support for certificate policies on authority level.
- Added a Dockerfile with a step-ca build with HSM support.
- A few new WithXX methods for instantiating authorities
### Changed
- Context usage in HTTP APIs.
- Changed authentication for Vault RAs.
- Error message returned to client when authenticating with expired certificate.
- Strip padding from ACME CSRs.
### Deprecated
- HTTP API handler types.
### Fixed
- Fixed SSH revocation.
- CA client dial context for js/wasm target.
- Incomplete `extraNames` support in templates.
- SCEP GET request support.
- Large SCEP request handling.
## [0.19.0] - 2022-04-19
### Added
- Added support for certificate renewals after expiry using the claim `allowRenewalAfterExpiry`.
- Added support for `extraNames` in X.509 templates. - Added support for `extraNames` in X.509 templates.
- Added `armv5` builds.
- Added RA support using a Vault instance as the CA.
- Added `WithX509SignerFunc` authority option.
- Added a new `/roots.pem` endpoint to download the CA roots in PEM format.
- Added support for Azure `Managed Identity` tokens.
- Added support for automatic configuration of linked RAs.
- Added support for the `--context` flag. It's now possible to start the
CA with `step-ca --context=abc` to use the configuration from context `abc`.
When a context has been configured and no configuration file is provided
on startup, the configuration for the current context is used.
- Added startup info logging and option to skip it (`--quiet`).
- Added support for renaming the CA (Common Name).
### Changed ### Changed
- Made SCEP CA URL paths dynamic - Made SCEP CA URL paths dynamic.
- Support two latest versions of Go (1.17, 1.18) - Support two latest versions of Go (1.17, 1.18).
- Upgrade go.step.sm/crypto to v0.16.1.
- Upgrade go.step.sm/linkedca to v0.15.0.
### Deprecated ### Deprecated
- Go 1.16 support.
### Removed ### Removed
### Fixed ### Fixed
- Fixed admin credentials on RAs.
- Fixed ACME HTTP-01 challenges for IPv6 identifiers.
- Various improvements under the hood.
### Security ### Security
## [0.18.2] - 2022-03-01 ## [0.18.2] - 2022-03-01
@ -49,7 +103,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Support for multiple certificate authority contexts. - Support for multiple certificate authority contexts.
- Support for generating extractable keys and certificates on a pkcs#11 module. - Support for generating extractable keys and certificates on a pkcs#11 module.
### Changed ### Changed
- Support two latest versions of golang (1.16, 1.17) - Support two latest versions of Go (1.16, 1.17)
### Deprecated ### Deprecated
- go 1.15 support - go 1.15 support

@ -151,7 +151,7 @@ integration: bin/$(BINNAME)
######################################### #########################################
fmt: fmt:
$Q gofmt -l -w $(SRC) $Q gofmt -l -s -w $(SRC)
lint: lint:
$Q golangci-lint run --timeout=30m $Q golangci-lint run --timeout=30m

@ -54,7 +54,7 @@ Setting up a *public key infrastructure* (PKI) is out of reach for many small te
- [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with automated enrollment, renewal, and passive revocation - [Short-lived certificates](https://smallstep.com/blog/passive-revocation.html) with automated enrollment, renewal, and passive revocation
- Capable of high availability (HA) deployment using [root federation](https://smallstep.com/blog/step-v0.8.3-federation-root-rotation.html) and/or multiple intermediaries - Capable of high availability (HA) deployment using [root federation](https://smallstep.com/blog/step-v0.8.3-federation-root-rotation.html) and/or multiple intermediaries
- Can operate as [an online intermediate CA for an existing root CA](https://smallstep.com/docs/tutorials/intermediate-ca-new-ca) - Can operate as [an online intermediate CA for an existing root CA](https://smallstep.com/docs/tutorials/intermediate-ca-new-ca)
- [Badger, BoltDB, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases) - [Badger, BoltDB, Postgres, and MySQL database backends](https://smallstep.com/docs/step-ca/configuration#databases)
### ⚙️ Many ways to automate ### ⚙️ Many ways to automate
@ -68,6 +68,7 @@ You can issue certificates in exchange for:
- [Cloud instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/), for VMs on AWS, GCP, and Azure - [Cloud instance identity documents](https://smallstep.com/blog/embarrassingly-easy-certificates-on-aws-azure-gcp/), for VMs on AWS, GCP, and Azure
- [Single-use, short-lived JWK tokens](https://smallstep.com/docs/step-ca/provisioners#jwk) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc. - [Single-use, short-lived JWK tokens](https://smallstep.com/docs/step-ca/provisioners#jwk) issued by your CD tool — Puppet, Chef, Ansible, Terraform, etc.
- A trusted X.509 certificate (X5C provisioner) - A trusted X.509 certificate (X5C provisioner)
- A host certificate from your Nebula network
- A SCEP challenge (SCEP provisioner) - A SCEP challenge (SCEP provisioner)
- An SSH host certificates needing renewal (the SSHPOP provisioner) - An SSH host certificates needing renewal (the SSHPOP provisioner)
- Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/provisioners) - Learn more in our [provisioner documentation](https://smallstep.com/docs/step-ca/provisioners)

@ -7,6 +7,8 @@ import (
"time" "time"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/certificates/authority/policy"
) )
// Account is a subset of the internal account type containing only those // Account is a subset of the internal account type containing only those
@ -43,15 +45,63 @@ func KeyToID(jwk *jose.JSONWebKey) (string, error) {
return base64.RawURLEncoding.EncodeToString(kid), nil return base64.RawURLEncoding.EncodeToString(kid), nil
} }
// PolicyNames contains ACME account level policy names
type PolicyNames struct {
DNSNames []string `json:"dns"`
IPRanges []string `json:"ips"`
}
// X509Policy contains ACME account level X.509 policy
type X509Policy struct {
Allowed PolicyNames `json:"allow"`
Denied PolicyNames `json:"deny"`
AllowWildcardNames bool `json:"allowWildcardNames"`
}
// Policy is an ACME Account level policy
type Policy struct {
X509 X509Policy `json:"x509"`
}
func (p *Policy) GetAllowedNameOptions() *policy.X509NameOptions {
if p == nil {
return nil
}
return &policy.X509NameOptions{
DNSDomains: p.X509.Allowed.DNSNames,
IPRanges: p.X509.Allowed.IPRanges,
}
}
func (p *Policy) GetDeniedNameOptions() *policy.X509NameOptions {
if p == nil {
return nil
}
return &policy.X509NameOptions{
DNSDomains: p.X509.Denied.DNSNames,
IPRanges: p.X509.Denied.IPRanges,
}
}
// AreWildcardNamesAllowed returns if wildcard names
// like *.example.com are allowed to be signed.
// Defaults to false.
func (p *Policy) AreWildcardNamesAllowed() bool {
if p == nil {
return false
}
return p.X509.AllowWildcardNames
}
// ExternalAccountKey is an ACME External Account Binding key. // ExternalAccountKey is an ACME External Account Binding key.
type ExternalAccountKey struct { type ExternalAccountKey struct {
ID string `json:"id"` ID string `json:"id"`
ProvisionerID string `json:"provisionerID"` ProvisionerID string `json:"provisionerID"`
Reference string `json:"reference"` Reference string `json:"reference"`
AccountID string `json:"-"` AccountID string `json:"-"`
KeyBytes []byte `json:"-"` HmacKey []byte `json:"-"`
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
BoundAt time.Time `json:"boundAt,omitempty"` BoundAt time.Time `json:"boundAt,omitempty"`
Policy *Policy `json:"policy,omitempty"`
} }
// AlreadyBound returns whether this EAK is already bound to // AlreadyBound returns whether this EAK is already bound to
@ -68,6 +118,6 @@ func (eak *ExternalAccountKey) BindTo(account *Account) error {
} }
eak.AccountID = account.ID eak.AccountID = account.ID
eak.BoundAt = time.Now() eak.BoundAt = time.Now()
eak.KeyBytes = []byte{} // clearing the key bytes; can only be used once eak.HmacKey = []byte{} // clearing the key bytes; can only be used once
return nil return nil
} }

@ -7,8 +7,9 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
) )
func TestKeyToID(t *testing.T) { func TestKeyToID(t *testing.T) {
@ -95,7 +96,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
ID: "eakID", ID: "eakID",
ProvisionerID: "provID", ProvisionerID: "provID",
Reference: "ref", Reference: "ref",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
}, },
acct: &Account{ acct: &Account{
ID: "accountID", ID: "accountID",
@ -108,7 +109,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
ID: "eakID", ID: "eakID",
ProvisionerID: "provID", ProvisionerID: "provID",
Reference: "ref", Reference: "ref",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
AccountID: "someAccountID", AccountID: "someAccountID",
BoundAt: boundAt, BoundAt: boundAt,
}, },
@ -138,7 +139,7 @@ func TestExternalAccountKey_BindTo(t *testing.T) {
assert.Equals(t, ae.Subproblems, tt.err.Subproblems) assert.Equals(t, ae.Subproblems, tt.err.Subproblems)
} else { } else {
assert.Equals(t, eak.AccountID, acct.ID) assert.Equals(t, eak.AccountID, acct.ID)
assert.Equals(t, eak.KeyBytes, []byte{}) assert.Equals(t, eak.HmacKey, []byte{})
assert.NotNil(t, eak.BoundAt) assert.NotNil(t, eak.BoundAt)
} }
}) })

@ -67,8 +67,11 @@ func (u *UpdateAccountRequest) Validate() error {
} }
// NewAccount is the handler resource for creating new ACME accounts. // NewAccount is the handler resource for creating new ACME accounts.
func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) { func NewAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
payload, err := payloadFromContext(ctx) payload, err := payloadFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -114,7 +117,7 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
return return
} }
eak, err := h.validateExternalAccountBinding(ctx, &nar) eak, err := validateExternalAccountBinding(ctx, &nar)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
@ -125,18 +128,17 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
Contact: nar.Contact, Contact: nar.Contact,
Status: acme.StatusValid, Status: acme.StatusValid,
} }
if err := h.db.CreateAccount(ctx, acc); err != nil { if err := db.CreateAccount(ctx, acc); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating account")) render.Error(w, acme.WrapErrorISE(err, "error creating account"))
return return
} }
if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response if eak != nil { // means that we have a (valid) External Account Binding key that should be bound, updated and sent in the response
err := eak.BindTo(acc) if err := eak.BindTo(acc); err != nil {
if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
if err := h.db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil { if err := db.UpdateExternalAccountKey(ctx, prov.ID, eak); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key")) render.Error(w, acme.WrapErrorISE(err, "error updating external account binding key"))
return return
} }
@ -147,15 +149,18 @@ func (h *Handler) NewAccount(w http.ResponseWriter, r *http.Request) {
httpStatus = http.StatusOK httpStatus = http.StatusOK
} }
h.linker.LinkAccount(ctx, acc) linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(r.Context(), AccountLinkType, acc.ID)) w.Header().Set("Location", linker.GetLink(r.Context(), acme.AccountLinkType, acc.ID))
render.JSONStatus(w, acc, httpStatus) render.JSONStatus(w, acc, httpStatus)
} }
// GetOrUpdateAccount is the api for updating an ACME account. // GetOrUpdateAccount is the api for updating an ACME account.
func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) { func GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -187,16 +192,16 @@ func (h *Handler) GetOrUpdateAccount(w http.ResponseWriter, r *http.Request) {
acc.Contact = uar.Contact acc.Contact = uar.Contact
} }
if err := h.db.UpdateAccount(ctx, acc); err != nil { if err := db.UpdateAccount(ctx, acc); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating account")) render.Error(w, acme.WrapErrorISE(err, "error updating account"))
return return
} }
} }
} }
h.linker.LinkAccount(ctx, acc) linker.LinkAccount(ctx, acc)
w.Header().Set("Location", h.linker.GetLink(ctx, AccountLinkType, acc.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.AccountLinkType, acc.ID))
render.JSON(w, acc) render.JSON(w, acc)
} }
@ -210,8 +215,11 @@ func logOrdersByAccount(w http.ResponseWriter, oids []string) {
} }
// GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account. // GetOrdersByAccountID ACME api for retrieving the list of order urls belonging to an account.
func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) { func GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -222,13 +230,14 @@ func (h *Handler) GetOrdersByAccountID(w http.ResponseWriter, r *http.Request) {
render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID)) render.Error(w, acme.NewError(acme.ErrorUnauthorizedType, "account ID '%s' does not match url param '%s'", acc.ID, accID))
return return
} }
orders, err := h.db.GetOrdersByAccountID(ctx, acc.ID)
orders, err := db.GetOrdersByAccountID(ctx, acc.ID)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
h.linker.LinkOrdersByAccountID(ctx, orders) linker.LinkOrdersByAccountID(ctx, orders)
render.JSON(w, orders) render.JSON(w, orders)
logOrdersByAccount(w, orders) logOrdersByAccount(w, orders)

@ -13,10 +13,12 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
) )
var ( var (
@ -29,6 +31,22 @@ var (
} }
) )
type fakeProvisioner struct{}
func (*fakeProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
return nil
}
func (*fakeProvisioner) AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) {
return nil, nil
}
func (*fakeProvisioner) AuthorizeRevoke(ctx context.Context, token string) error { return nil }
func (*fakeProvisioner) GetID() string { return "" }
func (*fakeProvisioner) GetName() string { return "" }
func (*fakeProvisioner) DefaultTLSCertDuration() time.Duration { return 0 }
func (*fakeProvisioner) GetOptions() *provisioner.Options { return nil }
func newProv() acme.Provisioner { func newProv() acme.Provisioner {
// Initialize provisioners // Initialize provisioners
p := &provisioner.ACME{ p := &provisioner.ACME{
@ -41,6 +59,19 @@ func newProv() acme.Provisioner {
return p return p
} }
func newProvWithOptions(options *provisioner.Options) acme.Provisioner {
// Initialize provisioners
p := &provisioner.ACME{
Type: "ACME",
Name: "test@acme-<test>provisioner.com",
Options: options,
}
if err := p.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
fmt.Printf("%v", err)
}
return p
}
func newACMEProv(t *testing.T) *provisioner.ACME { func newACMEProv(t *testing.T) *provisioner.ACME {
p := newProv() p := newProv()
a, ok := p.(*provisioner.ACME) a, ok := p.(*provisioner.ACME)
@ -50,6 +81,15 @@ func newACMEProv(t *testing.T) *provisioner.ACME {
return a return a
} }
func newACMEProvWithOptions(t *testing.T, options *provisioner.Options) *provisioner.ACME {
p := newProvWithOptions(options)
a, ok := p.(*provisioner.ACME)
if !ok {
t.Fatal("not a valid ACME provisioner")
}
return a
}
func createEABJWS(jwk *jose.JSONWebKey, hmacKey []byte, keyID, u string) (*jose.JSONWebSignature, error) { func createEABJWS(jwk *jose.JSONWebKey, hmacKey []byte, keyID, u string) (*jose.JSONWebSignature, error) {
signer, err := jose.NewSigner( signer, err := jose.NewSigner(
jose.SigningKey{ jose.SigningKey{
@ -296,10 +336,9 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: accID} acc := &acme.Account{ID: accID}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) { MockGetOrdersByAccountID: func(ctx context.Context, id string) ([]string, error) {
@ -315,11 +354,11 @@ func TestHandler_GetOrdersByAccountID(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetOrdersByAccountID(w, req) GetOrdersByAccountID(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -363,6 +402,7 @@ func TestHandler_NewAccount(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.Background(), ctx: context.Background(),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -371,6 +411,7 @@ func TestHandler_NewAccount(t *testing.T) {
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), payloadContextKey, nil) ctx := context.WithValue(context.Background(), payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -379,6 +420,7 @@ func TestHandler_NewAccount(t *testing.T) {
"fail/unmarshal-payload-error": func(t *testing.T) test { "fail/unmarshal-payload-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to "+ err: acme.NewError(acme.ErrorMalformedType, "failed to "+
@ -393,6 +435,7 @@ func TestHandler_NewAccount(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
@ -405,8 +448,9 @@ func TestHandler_NewAccount(t *testing.T) {
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -418,9 +462,10 @@ func TestHandler_NewAccount(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jwk expected in request context"), err: acme.NewErrorISE("jwk expected in request context"),
@ -432,10 +477,11 @@ func TestHandler_NewAccount(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, nil) ctx = context.WithValue(ctx, jwkContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jwk expected in request context"), err: acme.NewErrorISE("jwk expected in request context"),
@ -454,9 +500,9 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"), err: acme.NewError(acme.ErrorExternalAccountRequiredType, "no external account binding provided"),
@ -471,7 +517,7 @@ func TestHandler_NewAccount(t *testing.T) {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -501,18 +547,11 @@ func TestHandler_NewAccount(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
scepProvisioner := &provisioner.SCEP{
Type: "SCEP",
Name: "test@scep-<test>provisioner.com",
}
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
assert.FatalError(t, err)
}
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"), err: acme.NewError(acme.ErrorServerInternalType, "provisioner in context is not an ACME provisioner"),
@ -551,14 +590,13 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
eak := &acme.ExternalAccountKey{ eak := &acme.ExternalAccountKey{
ID: "eakID", ID: "eakID",
ProvisionerID: provID, ProvisionerID: provID,
Reference: "testeak", Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
return test{ return test{
@ -599,8 +637,7 @@ func TestHandler_NewAccount(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
@ -635,11 +672,11 @@ func TestHandler_NewAccount(t *testing.T) {
Status: acme.StatusValid, Status: acme.StatusValid,
Contact: []string{"foo", "bar"}, Contact: []string{"foo", "bar"},
} }
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
acc: acc, acc: acc,
statusCode: 200, statusCode: 200,
@ -664,8 +701,7 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = false prov.RequireEAB = false
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockCreateAccount: func(ctx context.Context, acc *acme.Account) error { MockCreateAccount: func(ctx context.Context, acc *acme.Account) error {
@ -719,8 +755,7 @@ func TestHandler_NewAccount(t *testing.T) {
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -735,7 +770,7 @@ func TestHandler_NewAccount(t *testing.T) {
ID: "eakID", ID: "eakID",
ProvisionerID: provID, ProvisionerID: provID,
Reference: "testeak", Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(), CreatedAt: time.Now(),
}, nil }, nil
}, },
@ -759,11 +794,11 @@ func TestHandler_NewAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.NewAccount(w, req) NewAccount(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -814,6 +849,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.Background(), ctx: context.Background(),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -822,6 +858,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), accContextKey, nil) ctx := context.WithValue(context.Background(), accContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -830,6 +867,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx := context.WithValue(context.Background(), accContextKey, &acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -839,6 +877,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx := context.WithValue(context.Background(), accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -848,6 +887,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx := context.WithValue(context.Background(), accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"), err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-account request payload: unexpected end of JSON input"),
@ -862,6 +902,7 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
ctx := context.WithValue(context.Background(), accContextKey, &acc) ctx := context.WithValue(context.Background(), accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"), err: acme.NewError(acme.ErrorMalformedType, "contact cannot be empty string"),
@ -894,10 +935,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
} }
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
@ -914,11 +954,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
uar := &UpdateAccountRequest{} uar := &UpdateAccountRequest{}
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 200, statusCode: 200,
} }
@ -929,10 +969,9 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
} }
b, err := json.Marshal(uar) b, err := json.Marshal(uar)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error { MockUpdateAccount: func(ctx context.Context, upd *acme.Account) error {
@ -946,11 +985,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
} }
}, },
"ok/post-as-get": func(t *testing.T) test { "ok/post-as-get": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, &acc) ctx = context.WithValue(ctx, accContextKey, &acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isPostAsGet: true})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 200, statusCode: 200,
} }
@ -959,11 +998,11 @@ func TestHandler_GetOrUpdateAccount(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetOrUpdateAccount(w, req) GetOrUpdateAccount(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)

@ -4,8 +4,9 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/smallstep/certificates/acme"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/certificates/acme"
) )
// ExternalAccountBinding represents the ACME externalAccountBinding JWS // ExternalAccountBinding represents the ACME externalAccountBinding JWS
@ -16,7 +17,7 @@ type ExternalAccountBinding struct {
} }
// validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account. // validateExternalAccountBinding validates the externalAccountBinding property in a call to new-account.
func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) { func validateExternalAccountBinding(ctx context.Context, nar *NewAccountRequest) (*acme.ExternalAccountKey, error) {
acmeProv, err := acmeProvisionerFromContext(ctx) acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil { if err != nil {
return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context") return nil, acme.WrapErrorISE(err, "could not load ACME provisioner from context")
@ -47,7 +48,8 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc
return nil, acmeErr return nil, acmeErr
} }
externalAccountKey, err := h.db.GetExternalAccountKey(ctx, acmeProv.ID, keyID) db := acme.MustDatabaseFromContext(ctx)
externalAccountKey, err := db.GetExternalAccountKey(ctx, acmeProv.ID, keyID)
if err != nil { if err != nil {
if _, ok := err.(*acme.Error); ok { if _, ok := err.(*acme.Error); ok {
return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key") return nil, acme.WrapError(acme.ErrorUnauthorizedType, err, "the field 'kid' references an unknown key")
@ -55,11 +57,19 @@ func (h *Handler) validateExternalAccountBinding(ctx context.Context, nar *NewAc
return nil, acme.WrapErrorISE(err, "error retrieving external account key") return nil, acme.WrapErrorISE(err, "error retrieving external account key")
} }
if externalAccountKey == nil {
return nil, acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key")
}
if len(externalAccountKey.HmacKey) == 0 {
return nil, acme.NewError(acme.ErrorServerInternalType, "external account binding key with id '%s' does not have secret bytes", keyID)
}
if externalAccountKey.AlreadyBound() { if externalAccountKey.AlreadyBound() {
return nil, acme.NewError(acme.ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", keyID, externalAccountKey.AccountID, externalAccountKey.BoundAt) return nil, acme.NewError(acme.ErrorUnauthorizedType, "external account binding key with id '%s' was already bound to account '%s' on %s", keyID, externalAccountKey.AccountID, externalAccountKey.BoundAt)
} }
payload, err := eabJWS.Verify(externalAccountKey.KeyBytes) payload, err := eabJWS.Verify(externalAccountKey.HmacKey)
if err != nil { if err != nil {
return nil, acme.WrapErrorISE(err, "error verifying externalAccountBinding signature") return nil, acme.WrapErrorISE(err, "error verifying externalAccountBinding signature")
} }
@ -97,12 +107,12 @@ func keysAreEqual(x, y *jose.JSONWebKey) bool {
// validateEABJWS verifies the contents of the External Account Binding JWS. // validateEABJWS verifies the contents of the External Account Binding JWS.
// The protected header of the JWS MUST meet the following criteria: // The protected header of the JWS MUST meet the following criteria:
// o The "alg" field MUST indicate a MAC-based algorithm //
// o The "kid" field MUST contain the key identifier provided by the CA // - The "alg" field MUST indicate a MAC-based algorithm
// o The "nonce" field MUST NOT be present // - The "kid" field MUST contain the key identifier provided by the CA
// o The "url" field MUST be set to the same value as the outer JWS // - The "nonce" field MUST NOT be present
// - The "url" field MUST be set to the same value as the outer JWS
func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) { func validateEABJWS(ctx context.Context, jws *jose.JSONWebSignature) (string, *acme.Error) {
if jws == nil { if jws == nil {
return "", acme.NewErrorISE("no JWS provided") return "", acme.NewErrorISE("no JWS provided")
} }

@ -9,10 +9,11 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/crypto/jose"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
) )
func Test_keysAreEqual(t *testing.T) { func Test_keysAreEqual(t *testing.T) {
@ -98,8 +99,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
prov := newACMEProv(t) prov := newACMEProv(t)
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{}, db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
@ -143,8 +143,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
createdAt := time.Now() createdAt := time.Now()
return test{ return test{
@ -154,7 +153,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID", ID: "eakID",
ProvisionerID: provID, ProvisionerID: provID,
Reference: "testeak", Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: createdAt, CreatedAt: createdAt,
}, nil }, nil
}, },
@ -168,7 +167,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID", ID: "eakID",
ProvisionerID: provID, ProvisionerID: provID,
Reference: "testeak", Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: createdAt, CreatedAt: createdAt,
}, },
err: nil, err: nil,
@ -189,17 +188,10 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
} }
b, err := json.Marshal(nar) b, err := json.Marshal(nar)
assert.FatalError(t, err) assert.FatalError(t, err)
scepProvisioner := &provisioner.SCEP{
Type: "SCEP",
Name: "test@scep-<test>provisioner.com",
}
if err := scepProvisioner.Init(provisioner.Config{Claims: globalProvisionerClaims}); err != nil {
assert.FatalError(t, err)
}
ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b}) ctx := context.WithValue(context.Background(), payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, &fakeProvisioner{})
ctx = context.WithValue(ctx, provisionerContextKey, scepProvisioner)
return test{ return test{
ctx: ctx, ctx: ctx,
err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"), err: acme.NewError(acme.ErrorServerInternalType, "could not load ACME provisioner from context: provisioner in context is not an ACME provisioner"),
@ -218,8 +210,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
return test{ return test{
db: &acme.MockDB{}, db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
@ -264,8 +255,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{}, db: &acme.MockDB{},
@ -310,8 +300,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -358,8 +347,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -408,8 +396,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -426,6 +413,112 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
err: acme.NewErrorISE("error retrieving external account key"), err: acme.NewErrorISE("error retrieving external account key"),
} }
}, },
"fail/db.GetExternalAccountKey-nil": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
assert.FatalError(t, err)
eab := &ExternalAccountBinding{}
err = json.Unmarshal(rawEABJWS, &eab)
assert.FatalError(t, err)
nar := &NewAccountRequest{
Contact: []string{"foo", "bar"},
ExternalAccountBinding: eab,
}
payloadBytes, err := json.Marshal(nar)
assert.FatalError(t, err)
so := new(jose.SignerOptions)
so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm))
so.WithHeader("url", url)
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
Key: jwk.Key,
}, so)
assert.FatalError(t, err)
jws, err := signer.Sign(payloadBytes)
assert.FatalError(t, err)
raw, err := jws.CompactSerialize()
assert.FatalError(t, err)
parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err)
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{
db: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) {
return nil, nil
},
},
ctx: ctx,
nar: &NewAccountRequest{
Contact: []string{"foo", "bar"},
ExternalAccountBinding: eab,
},
eak: nil,
err: acme.NewError(acme.ErrorUnauthorizedType, "the field 'kid' references an unknown key"),
}
},
"fail/db.GetExternalAccountKey-no-keybytes": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err)
url := fmt.Sprintf("%s/acme/%s/account/new-account", baseURL.String(), escProvName)
rawEABJWS, err := createRawEABJWS(jwk, []byte{1, 3, 3, 7}, "eakID", url)
assert.FatalError(t, err)
eab := &ExternalAccountBinding{}
err = json.Unmarshal(rawEABJWS, &eab)
assert.FatalError(t, err)
nar := &NewAccountRequest{
Contact: []string{"foo", "bar"},
ExternalAccountBinding: eab,
}
payloadBytes, err := json.Marshal(nar)
assert.FatalError(t, err)
so := new(jose.SignerOptions)
so.WithHeader("alg", jose.SignatureAlgorithm(jwk.Algorithm))
so.WithHeader("url", url)
signer, err := jose.NewSigner(jose.SigningKey{
Algorithm: jose.SignatureAlgorithm(jwk.Algorithm),
Key: jwk.Key,
}, so)
assert.FatalError(t, err)
jws, err := signer.Sign(payloadBytes)
assert.FatalError(t, err)
raw, err := jws.CompactSerialize()
assert.FatalError(t, err)
parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err)
prov := newACMEProv(t)
prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
createdAt := time.Now()
return test{
db: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerName, keyID string) (*acme.ExternalAccountKey, error) {
return &acme.ExternalAccountKey{
ID: "eakID",
ProvisionerID: provID,
Reference: "testeak",
CreatedAt: createdAt,
AccountID: "some-account-id",
HmacKey: []byte{},
}, nil
},
},
ctx: ctx,
nar: &NewAccountRequest{
Contact: []string{"foo", "bar"},
ExternalAccountBinding: eab,
},
eak: nil,
err: acme.NewError(acme.ErrorServerInternalType, "external account binding key with id 'eakID' does not have secret bytes"),
}
},
"fail/db.GetExternalAccountKey-wrong-provisioner": func(t *testing.T) test { "fail/db.GetExternalAccountKey-wrong-provisioner": func(t *testing.T) test {
jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
@ -458,8 +551,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -506,8 +598,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
createdAt := time.Now() createdAt := time.Now()
boundAt := time.Now().Add(1 * time.Second) boundAt := time.Now().Add(1 * time.Second)
@ -520,6 +611,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
Reference: "testeak", Reference: "testeak",
CreatedAt: createdAt, CreatedAt: createdAt,
AccountID: "some-account-id", AccountID: "some-account-id",
HmacKey: []byte{1, 3, 3, 7},
BoundAt: boundAt, BoundAt: boundAt,
}, nil }, nil
}, },
@ -565,8 +657,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -575,7 +666,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID", ID: "eakID",
ProvisionerID: provID, ProvisionerID: provID,
Reference: "testeak", Reference: "testeak",
KeyBytes: []byte{1, 2, 3, 4}, HmacKey: []byte{1, 2, 3, 4},
CreatedAt: time.Now(), CreatedAt: time.Now(),
}, nil }, nil
}, },
@ -623,8 +714,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, jwk) ctx := context.WithValue(context.Background(), jwkContextKey, jwk)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -633,7 +723,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID", ID: "eakID",
ProvisionerID: provID, ProvisionerID: provID,
Reference: "testeak", Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(), CreatedAt: time.Now(),
}, nil }, nil
}, },
@ -678,8 +768,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -688,7 +777,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID", ID: "eakID",
ProvisionerID: provID, ProvisionerID: provID,
Reference: "testeak", Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(), CreatedAt: time.Now(),
}, nil }, nil
}, },
@ -734,8 +823,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
prov := newACMEProv(t) prov := newACMEProv(t)
prov.RequireEAB = true prov.RequireEAB = true
ctx := context.WithValue(context.Background(), jwkContextKey, nil) ctx := context.WithValue(context.Background(), jwkContextKey, nil)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -744,7 +832,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
ID: "eakID", ID: "eakID",
ProvisionerID: provID, ProvisionerID: provID,
Reference: "testeak", Reference: "testeak",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Now(), CreatedAt: time.Now(),
}, nil }, nil
}, },
@ -762,10 +850,8 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
db: tc.db, got, err := validateExternalAccountBinding(ctx, tc.nar)
}
got, err := h.validateExternalAccountBinding(tc.ctx, tc.nar)
wantErr := tc.err != nil wantErr := tc.err != nil
gotErr := err != nil gotErr := err != nil
if wantErr != gotErr { if wantErr != gotErr {
@ -787,7 +873,7 @@ func TestHandler_validateExternalAccountBinding(t *testing.T) {
} else { } else {
assert.NotNil(t, tc.eak) assert.NotNil(t, tc.eak)
assert.Equals(t, got.ID, tc.eak.ID) assert.Equals(t, got.ID, tc.eak.ID)
assert.Equals(t, got.KeyBytes, tc.eak.KeyBytes) assert.Equals(t, got.HmacKey, tc.eak.HmacKey)
assert.Equals(t, got.ProvisionerID, tc.eak.ProvisionerID) assert.Equals(t, got.ProvisionerID, tc.eak.ProvisionerID)
assert.Equals(t, got.Reference, tc.eak.Reference) assert.Equals(t, got.Reference, tc.eak.Reference)
assert.Equals(t, got.CreatedAt, tc.eak.CreatedAt) assert.Equals(t, got.CreatedAt, tc.eak.CreatedAt)

@ -2,12 +2,10 @@ package api
import ( import (
"context" "context"
"crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"net"
"net/http" "net/http"
"time" "time"
@ -16,6 +14,7 @@ import (
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
) )
@ -39,111 +38,152 @@ type payloadInfo struct {
isEmptyJSON bool isEmptyJSON bool
} }
// Handler is the ACME API request handler.
type Handler struct {
db acme.DB
backdate provisioner.Duration
ca acme.CertificateAuthority
linker Linker
validateChallengeOptions *acme.ValidateChallengeOptions
prerequisitesChecker func(ctx context.Context) (bool, error)
}
// HandlerOptions required to create a new ACME API request handler. // HandlerOptions required to create a new ACME API request handler.
type HandlerOptions struct { type HandlerOptions struct {
Backdate provisioner.Duration // DB storage backend that implements the acme.DB interface.
// DB storage backend that impements the acme.DB interface. //
// Deprecated: use acme.NewContex(context.Context, acme.DB)
DB acme.DB DB acme.DB
// CA is the certificate authority interface.
//
// Deprecated: use authority.NewContext(context.Context, *authority.Authority)
CA acme.CertificateAuthority
// Backdate is the duration that the CA will subtract from the current time
// to set the NotBefore in the certificate.
Backdate provisioner.Duration
// DNS the host used to generate accurate ACME links. By default the authority // DNS the host used to generate accurate ACME links. By default the authority
// will use the Host from the request, so this value will only be used if // will use the Host from the request, so this value will only be used if
// request.Host is empty. // request.Host is empty.
DNS string DNS string
// Prefix is a URL path prefix under which the ACME api is served. This // Prefix is a URL path prefix under which the ACME api is served. This
// prefix is required to generate accurate ACME links. // prefix is required to generate accurate ACME links.
// E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account -- // E.g. https://ca.smallstep.com/acme/my-acme-provisioner/new-account --
// "acme" is the prefix from which the ACME api is accessed. // "acme" is the prefix from which the ACME api is accessed.
Prefix string Prefix string
CA acme.CertificateAuthority
// PrerequisitesChecker checks if all prerequisites for serving ACME are // PrerequisitesChecker checks if all prerequisites for serving ACME are
// met by the CA configuration. // met by the CA configuration.
PrerequisitesChecker func(ctx context.Context) (bool, error) PrerequisitesChecker func(ctx context.Context) (bool, error)
} }
var mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
return authority.MustFromContext(ctx)
}
// handler is the ACME API request handler.
type handler struct {
opts *HandlerOptions
}
// Route traffic and implement the Router interface. For backward compatibility
// this route adds will add a new middleware that will set the ACME components
// on the context.
//
// Note: this method is deprecated in step-ca, other applications can still use
// this to support ACME, but the recommendation is to use use
// api.Route(api.Router) and acme.NewContext() instead.
func (h *handler) Route(r api.Router) {
client := acme.NewClient()
linker := acme.NewLinker(h.opts.DNS, h.opts.Prefix)
route(r, func(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if ca, ok := h.opts.CA.(*authority.Authority); ok && ca != nil {
ctx = authority.NewContext(ctx, ca)
}
ctx = acme.NewContext(ctx, h.opts.DB, client, linker, h.opts.PrerequisitesChecker)
next(w, r.WithContext(ctx))
}
})
}
// NewHandler returns a new ACME API handler. // NewHandler returns a new ACME API handler.
func NewHandler(ops HandlerOptions) api.RouterHandler { //
transport := &http.Transport{ // Note: this method is deprecated in step-ca, other applications can still use
TLSClientConfig: &tls.Config{ // this to support ACME, but the recommendation is to use use
InsecureSkipVerify: true, // api.Route(api.Router) and acme.NewContext() instead.
}, func NewHandler(opts HandlerOptions) api.RouterHandler {
} return &handler{
client := http.Client{ opts: &opts,
Timeout: 30 * time.Second,
Transport: transport,
}
dialer := &net.Dialer{
Timeout: 30 * time.Second,
}
prerequisitesChecker := func(ctx context.Context) (bool, error) {
// by default all prerequisites are met
return true, nil
}
if ops.PrerequisitesChecker != nil {
prerequisitesChecker = ops.PrerequisitesChecker
}
return &Handler{
ca: ops.CA,
db: ops.DB,
backdate: ops.Backdate,
linker: NewLinker(ops.DNS, ops.Prefix),
validateChallengeOptions: &acme.ValidateChallengeOptions{
HTTPGet: client.Get,
LookupTxt: net.LookupTXT,
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(dialer, network, addr, config)
},
},
prerequisitesChecker: prerequisitesChecker,
} }
} }
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface. This method requires that
func (h *Handler) Route(r api.Router) { // all the acme components, authority, db, client, linker, and prerequisite
getPath := h.linker.GetUnescapedPathSuffix // checker to be present in the context.
// Standard ACME API func Route(r api.Router) {
r.MethodFunc("GET", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) route(r, nil)
r.MethodFunc("HEAD", getPath(NewNonceLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.GetNonce)))))) }
r.MethodFunc("GET", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
r.MethodFunc("HEAD", getPath(DirectoryLinkType, "{provisionerID}"), h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.GetDirectory))))
func route(r api.Router, middleware func(next nextHTTP) nextHTTP) {
commonMiddleware := func(next nextHTTP) nextHTTP {
handler := func(w http.ResponseWriter, r *http.Request) {
// Linker middleware gets the provisioner and current url from the
// request and sets them in the context.
linker := acme.MustLinkerFromContext(r.Context())
linker.Middleware(http.HandlerFunc(checkPrerequisites(next))).ServeHTTP(w, r)
}
if middleware != nil {
handler = middleware(handler)
}
return handler
}
validatingMiddleware := func(next nextHTTP) nextHTTP { validatingMiddleware := func(next nextHTTP) nextHTTP {
return h.baseURLFromRequest(h.lookupProvisioner(h.checkPrerequisites(h.addNonce(h.addDirLink(h.verifyContentType(h.parseJWS(h.validateJWS(next)))))))) return commonMiddleware(addNonce(addDirLink(verifyContentType(parseJWS(validateJWS(next))))))
} }
extractPayloadByJWK := func(next nextHTTP) nextHTTP { extractPayloadByJWK := func(next nextHTTP) nextHTTP {
return validatingMiddleware(h.extractJWK(h.verifyAndExtractJWSPayload(next))) return validatingMiddleware(extractJWK(verifyAndExtractJWSPayload(next)))
} }
extractPayloadByKid := func(next nextHTTP) nextHTTP { extractPayloadByKid := func(next nextHTTP) nextHTTP {
return validatingMiddleware(h.lookupJWK(h.verifyAndExtractJWSPayload(next))) return validatingMiddleware(lookupJWK(verifyAndExtractJWSPayload(next)))
} }
extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP { extractPayloadByKidOrJWK := func(next nextHTTP) nextHTTP {
return validatingMiddleware(h.extractOrLookupJWK(h.verifyAndExtractJWSPayload(next))) return validatingMiddleware(extractOrLookupJWK(verifyAndExtractJWSPayload(next)))
} }
r.MethodFunc("POST", getPath(NewAccountLinkType, "{provisionerID}"), extractPayloadByJWK(h.NewAccount)) getPath := acme.GetUnescapedPathSuffix
r.MethodFunc("POST", getPath(AccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.GetOrUpdateAccount))
r.MethodFunc("POST", getPath(KeyChangeLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.NotImplemented)) // Standard ACME API
r.MethodFunc("POST", getPath(NewOrderLinkType, "{provisionerID}"), extractPayloadByKid(h.NewOrder)) r.MethodFunc("GET", getPath(acme.NewNonceLinkType, "{provisionerID}"),
r.MethodFunc("POST", getPath(OrderLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrder))) commonMiddleware(addNonce(addDirLink(GetNonce))))
r.MethodFunc("POST", getPath(OrdersByAccountLinkType, "{provisionerID}", "{accID}"), extractPayloadByKid(h.isPostAsGet(h.GetOrdersByAccountID))) r.MethodFunc("HEAD", getPath(acme.NewNonceLinkType, "{provisionerID}"),
r.MethodFunc("POST", getPath(FinalizeLinkType, "{provisionerID}", "{ordID}"), extractPayloadByKid(h.FinalizeOrder)) commonMiddleware(addNonce(addDirLink(GetNonce))))
r.MethodFunc("POST", getPath(AuthzLinkType, "{provisionerID}", "{authzID}"), extractPayloadByKid(h.isPostAsGet(h.GetAuthorization))) r.MethodFunc("GET", getPath(acme.DirectoryLinkType, "{provisionerID}"),
r.MethodFunc("POST", getPath(ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"), extractPayloadByKid(h.GetChallenge)) commonMiddleware(GetDirectory))
r.MethodFunc("POST", getPath(CertificateLinkType, "{provisionerID}", "{certID}"), extractPayloadByKid(h.isPostAsGet(h.GetCertificate))) r.MethodFunc("HEAD", getPath(acme.DirectoryLinkType, "{provisionerID}"),
r.MethodFunc("POST", getPath(RevokeCertLinkType, "{provisionerID}"), extractPayloadByKidOrJWK(h.RevokeCert)) commonMiddleware(GetDirectory))
r.MethodFunc("POST", getPath(acme.NewAccountLinkType, "{provisionerID}"),
extractPayloadByJWK(NewAccount))
r.MethodFunc("POST", getPath(acme.AccountLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(GetOrUpdateAccount))
r.MethodFunc("POST", getPath(acme.KeyChangeLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(NotImplemented))
r.MethodFunc("POST", getPath(acme.NewOrderLinkType, "{provisionerID}"),
extractPayloadByKid(NewOrder))
r.MethodFunc("POST", getPath(acme.OrderLinkType, "{provisionerID}", "{ordID}"),
extractPayloadByKid(isPostAsGet(GetOrder)))
r.MethodFunc("POST", getPath(acme.OrdersByAccountLinkType, "{provisionerID}", "{accID}"),
extractPayloadByKid(isPostAsGet(GetOrdersByAccountID)))
r.MethodFunc("POST", getPath(acme.FinalizeLinkType, "{provisionerID}", "{ordID}"),
extractPayloadByKid(FinalizeOrder))
r.MethodFunc("POST", getPath(acme.AuthzLinkType, "{provisionerID}", "{authzID}"),
extractPayloadByKid(isPostAsGet(GetAuthorization)))
r.MethodFunc("POST", getPath(acme.ChallengeLinkType, "{provisionerID}", "{authzID}", "{chID}"),
extractPayloadByKid(GetChallenge))
r.MethodFunc("POST", getPath(acme.CertificateLinkType, "{provisionerID}", "{certID}"),
extractPayloadByKid(isPostAsGet(GetCertificate)))
r.MethodFunc("POST", getPath(acme.RevokeCertLinkType, "{provisionerID}"),
extractPayloadByKidOrJWK(RevokeCert))
} }
// GetNonce just sets the right header since a Nonce is added to each response // GetNonce just sets the right header since a Nonce is added to each response
// by middleware by default. // by middleware by default.
func (h *Handler) GetNonce(w http.ResponseWriter, r *http.Request) { func GetNonce(w http.ResponseWriter, r *http.Request) {
if r.Method == "HEAD" { if r.Method == "HEAD" {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} else { } else {
@ -179,7 +219,7 @@ func (d *Directory) ToLog() (interface{}, error) {
// GetDirectory is the ACME resource for returning a directory configuration // GetDirectory is the ACME resource for returning a directory configuration
// for client configuration. // for client configuration.
func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) { func GetDirectory(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
acmeProv, err := acmeProvisionerFromContext(ctx) acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil { if err != nil {
@ -187,12 +227,13 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
return return
} }
linker := acme.MustLinkerFromContext(ctx)
render.JSON(w, &Directory{ render.JSON(w, &Directory{
NewNonce: h.linker.GetLink(ctx, NewNonceLinkType), NewNonce: linker.GetLink(ctx, acme.NewNonceLinkType),
NewAccount: h.linker.GetLink(ctx, NewAccountLinkType), NewAccount: linker.GetLink(ctx, acme.NewAccountLinkType),
NewOrder: h.linker.GetLink(ctx, NewOrderLinkType), NewOrder: linker.GetLink(ctx, acme.NewOrderLinkType),
RevokeCert: h.linker.GetLink(ctx, RevokeCertLinkType), RevokeCert: linker.GetLink(ctx, acme.RevokeCertLinkType),
KeyChange: h.linker.GetLink(ctx, KeyChangeLinkType), KeyChange: linker.GetLink(ctx, acme.KeyChangeLinkType),
Meta: Meta{ Meta: Meta{
ExternalAccountRequired: acmeProv.RequireEAB, ExternalAccountRequired: acmeProv.RequireEAB,
}, },
@ -201,19 +242,22 @@ func (h *Handler) GetDirectory(w http.ResponseWriter, r *http.Request) {
// NotImplemented returns a 501 and is generally a placeholder for functionality which // NotImplemented returns a 501 and is generally a placeholder for functionality which
// MAY be added at some point in the future but is not in any way a guarantee of such. // MAY be added at some point in the future but is not in any way a guarantee of such.
func (h *Handler) NotImplemented(w http.ResponseWriter, r *http.Request) { func NotImplemented(w http.ResponseWriter, r *http.Request) {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented")) render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "this API is not implemented"))
} }
// GetAuthorization ACME api for retrieving an Authz. // GetAuthorization ACME api for retrieving an Authz.
func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) { func GetAuthorization(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
az, err := h.db.GetAuthorization(ctx, chi.URLParam(r, "authzID")) az, err := db.GetAuthorization(ctx, chi.URLParam(r, "authzID"))
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization")) render.Error(w, acme.WrapErrorISE(err, "error retrieving authorization"))
return return
@ -223,20 +267,23 @@ func (h *Handler) GetAuthorization(w http.ResponseWriter, r *http.Request) {
"account '%s' does not own authorization '%s'", acc.ID, az.ID)) "account '%s' does not own authorization '%s'", acc.ID, az.ID))
return return
} }
if err = az.UpdateStatus(ctx, h.db); err != nil { if err = az.UpdateStatus(ctx, db); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating authorization status")) render.Error(w, acme.WrapErrorISE(err, "error updating authorization status"))
return return
} }
h.linker.LinkAuthorization(ctx, az) linker.LinkAuthorization(ctx, az)
w.Header().Set("Location", h.linker.GetLink(ctx, AuthzLinkType, az.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.AuthzLinkType, az.ID))
render.JSON(w, az) render.JSON(w, az)
} }
// GetChallenge ACME api for retrieving a Challenge. // GetChallenge ACME api for retrieving a Challenge.
func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) { func GetChallenge(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -257,7 +304,7 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
// we'll just ignore the body. // we'll just ignore the body.
azID := chi.URLParam(r, "authzID") azID := chi.URLParam(r, "authzID")
ch, err := h.db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID) ch, err := db.GetChallenge(ctx, chi.URLParam(r, "chID"), azID)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge")) render.Error(w, acme.WrapErrorISE(err, "error retrieving challenge"))
return return
@ -273,29 +320,31 @@ func (h *Handler) GetChallenge(w http.ResponseWriter, r *http.Request) {
render.Error(w, err) render.Error(w, err)
return return
} }
if err = ch.Validate(ctx, h.db, jwk, h.validateChallengeOptions); err != nil { if err = ch.Validate(ctx, db, jwk); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error validating challenge")) render.Error(w, acme.WrapErrorISE(err, "error validating challenge"))
return return
} }
h.linker.LinkChallenge(ctx, ch, azID) linker.LinkChallenge(ctx, ch, azID)
w.Header().Add("Link", link(h.linker.GetLink(ctx, AuthzLinkType, azID), "up")) w.Header().Add("Link", link(linker.GetLink(ctx, acme.AuthzLinkType, azID), "up"))
w.Header().Set("Location", h.linker.GetLink(ctx, ChallengeLinkType, azID, ch.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.ChallengeLinkType, azID, ch.ID))
render.JSON(w, ch) render.JSON(w, ch)
} }
// GetCertificate ACME api for retrieving a Certificate. // GetCertificate ACME api for retrieving a Certificate.
func (h *Handler) GetCertificate(w http.ResponseWriter, r *http.Request) { func GetCertificate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
certID := chi.URLParam(r, "certID")
cert, err := h.db.GetCertificate(ctx, certID) certID := chi.URLParam(r, "certID")
cert, err := db.GetCertificate(ctx, certID)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate")) render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate"))
return return

@ -3,6 +3,7 @@ package api
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
@ -19,11 +20,33 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
) )
type mockClient struct {
get func(url string) (*http.Response, error)
lookupTxt func(name string) ([]string, error)
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
}
func (m *mockClient) Get(u string) (*http.Response, error) { return m.get(u) }
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
return m.tlsDial(network, addr, config)
}
func mockMustAuthority(t *testing.T, a acme.CertificateAuthority) {
t.Helper()
fn := mustAuthority
t.Cleanup(func() {
mustAuthority = fn
})
mustAuthority = func(ctx context.Context) acme.CertificateAuthority {
return a
}
}
func TestHandler_GetNonce(t *testing.T) { func TestHandler_GetNonce(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -38,10 +61,10 @@ func TestHandler_GetNonce(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := &Handler{} // h := &Handler{}
w := httptest.NewRecorder() w := httptest.NewRecorder()
req.Method = tt.name req.Method = tt.name
h.GetNonce(w, req) GetNonce(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -52,7 +75,8 @@ func TestHandler_GetNonce(t *testing.T) {
} }
func TestHandler_GetDirectory(t *testing.T) { func TestHandler_GetDirectory(t *testing.T) {
linker := NewLinker("ca.smallstep.com", "acme") linker := acme.NewLinker("ca.smallstep.com", "acme")
_ = linker
type test struct { type test struct {
ctx context.Context ctx context.Context
statusCode int statusCode int
@ -61,23 +85,14 @@ func TestHandler_GetDirectory(t *testing.T) {
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-provisioner": func(t *testing.T) test { "fail/no-provisioner": func(t *testing.T) test {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
ctx: ctx, ctx: context.Background(),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner in context is not an ACME provisioner"), err: acme.NewErrorISE("provisioner is not in context"),
} }
}, },
"fail/different-provisioner": func(t *testing.T) test { "fail/different-provisioner": func(t *testing.T) test {
prov := &provisioner.SCEP{ ctx := acme.NewProvisionerContext(context.Background(), &fakeProvisioner{})
Type: "SCEP",
Name: "test@scep-<test>provisioner.com",
}
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
@ -88,8 +103,7 @@ func TestHandler_GetDirectory(t *testing.T) {
prov := newProv() prov := newProv()
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
expDir := Directory{ expDir := Directory{
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
@ -108,8 +122,7 @@ func TestHandler_GetDirectory(t *testing.T) {
prov.RequireEAB = true prov.RequireEAB = true
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
expDir := Directory{ expDir := Directory{
NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName), NewNonce: fmt.Sprintf("%s/acme/%s/new-nonce", baseURL.String(), provName),
NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName), NewAccount: fmt.Sprintf("%s/acme/%s/new-account", baseURL.String(), provName),
@ -130,11 +143,11 @@ func TestHandler_GetDirectory(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: linker} ctx := acme.NewLinkerContext(tc.ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetDirectory(w, req) GetDirectory(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -219,7 +232,7 @@ func TestHandler_GetAuthorization(t *testing.T) {
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
return test{ return test{
db: &acme.MockDB{}, db: &acme.MockDB{},
@ -285,10 +298,9 @@ func TestHandler_GetAuthorization(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) { MockGetAuthorization: func(ctx context.Context, id string) (*acme.Authorization, error) {
@ -304,11 +316,11 @@ func TestHandler_GetAuthorization(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme")} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", "/foo/bar", nil) req := httptest.NewRequest("GET", "/foo/bar", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetAuthorization(w, req) GetAuthorization(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -447,11 +459,11 @@ func TestHandler_GetCertificate(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} ctx := acme.NewDatabaseContext(tc.ctx, tc.db)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetCertificate(w, req) GetCertificate(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -491,7 +503,7 @@ func TestHandler_GetChallenge(t *testing.T) {
type test struct { type test struct {
db acme.DB db acme.DB
vco *acme.ValidateChallengeOptions vc acme.Client
ctx context.Context ctx context.Context
statusCode int statusCode int
ch *acme.Challenge ch *acme.Challenge
@ -500,6 +512,7 @@ func TestHandler_GetChallenge(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.Background(), ctx: context.Background(),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -507,6 +520,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), accContextKey, nil), ctx: context.WithValue(context.Background(), accContextKey, nil),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -516,6 +530,7 @@ func TestHandler_GetChallenge(t *testing.T) {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -523,10 +538,11 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload expected in request context"), err: acme.NewErrorISE("payload expected in request context"),
@ -534,7 +550,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/db.GetChallenge-error": func(t *testing.T) test { "fail/db.GetChallenge-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -553,7 +569,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/account-id-mismatch": func(t *testing.T) test { "fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -572,7 +588,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/no-jwk": func(t *testing.T) test { "fail/no-jwk": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -591,7 +607,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/nil-jwk": func(t *testing.T) test { "fail/nil-jwk": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
ctx = context.WithValue(ctx, jwkContextKey, nil) ctx = context.WithValue(ctx, jwkContextKey, nil)
@ -611,7 +627,7 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"fail/validate-challenge-error": func(t *testing.T) test { "fail/validate-challenge-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
@ -639,8 +655,8 @@ func TestHandler_GetChallenge(t *testing.T) {
return acme.NewErrorISE("force") return acme.NewErrorISE("force")
}, },
}, },
vco: &acme.ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(string) (*http.Response, error) { get: func(string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -651,14 +667,13 @@ func TestHandler_GetChallenge(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{isEmptyJSON: true})
_jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0) _jwk, err := jose.GenerateJWK("EC", "P-256", "ES256", "sig", "", 0)
assert.FatalError(t, err) assert.FatalError(t, err)
_pub := _jwk.Public() _pub := _jwk.Public()
ctx = context.WithValue(ctx, jwkContextKey, &_pub) ctx = context.WithValue(ctx, jwkContextKey, &_pub)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -690,8 +705,8 @@ func TestHandler_GetChallenge(t *testing.T) {
URL: u, URL: u,
Error: acme.NewError(acme.ErrorConnectionType, "force"), Error: acme.NewError(acme.ErrorConnectionType, "force"),
}, },
vco: &acme.ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(string) (*http.Response, error) { get: func(string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -703,11 +718,11 @@ func TestHandler_GetChallenge(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: NewLinker("dns", "acme"), validateChallengeOptions: tc.vco} ctx := acme.NewContext(tc.ctx, tc.db, nil, acme.NewLinker("test.ca.smallstep.com", "acme"), nil)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetChallenge(w, req) GetChallenge(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)

@ -9,7 +9,6 @@ import (
"net/url" "net/url"
"strings" "strings"
"github.com/go-chi/chi"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
@ -31,39 +30,11 @@ func logNonce(w http.ResponseWriter, nonce string) {
} }
} }
// baseURLFromRequest determines the base URL which should be used for
// constructing link URLs in e.g. the ACME directory result by taking the
// request Host into consideration.
//
// If the Request.Host is an empty string, we return an empty string, to
// indicate that the configured URL values should be used instead. If this
// function returns a non-empty result, then this should be used in
// constructing ACME link URLs.
func baseURLFromRequest(r *http.Request) *url.URL {
// NOTE: See https://github.com/letsencrypt/boulder/blob/master/web/relative.go
// for an implementation that allows HTTP requests using the x-forwarded-proto
// header.
if r.Host == "" {
return nil
}
return &url.URL{Scheme: "https", Host: r.Host}
}
// baseURLFromRequest is a middleware that extracts and caches the baseURL
// from the request.
// E.g. https://ca.smallstep.com/
func (h *Handler) baseURLFromRequest(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), baseURLContextKey, baseURLFromRequest(r))
next(w, r.WithContext(ctx))
}
}
// addNonce is a middleware that adds a nonce to the response header. // addNonce is a middleware that adds a nonce to the response header.
func (h *Handler) addNonce(next nextHTTP) nextHTTP { func addNonce(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
nonce, err := h.db.CreateNonce(r.Context()) db := acme.MustDatabaseFromContext(r.Context())
nonce, err := db.CreateNonce(r.Context())
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
@ -77,25 +48,31 @@ func (h *Handler) addNonce(next nextHTTP) nextHTTP {
// addDirLink is a middleware that adds a 'Link' response reader with the // addDirLink is a middleware that adds a 'Link' response reader with the
// directory index url. // directory index url.
func (h *Handler) addDirLink(next nextHTTP) nextHTTP { func addDirLink(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Link", link(h.linker.GetLink(r.Context(), DirectoryLinkType), "index")) ctx := r.Context()
linker := acme.MustLinkerFromContext(ctx)
w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
next(w, r) next(w, r)
} }
} }
// verifyContentType is a middleware that verifies that content type is // verifyContentType is a middleware that verifies that content type is
// application/jose+json. // application/jose+json.
func (h *Handler) verifyContentType(next nextHTTP) nextHTTP { func verifyContentType(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var expected []string
p, err := provisionerFromContext(r.Context()) p, err := provisionerFromContext(r.Context())
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
u := url.URL{Path: h.linker.GetUnescapedPathSuffix(CertificateLinkType, p.GetName(), "")} u := &url.URL{
Path: acme.GetUnescapedPathSuffix(acme.CertificateLinkType, p.GetName(), ""),
}
var expected []string
if strings.Contains(r.URL.String(), u.EscapedPath()) { if strings.Contains(r.URL.String(), u.EscapedPath()) {
// GET /certificate requests allow a greater range of content types. // GET /certificate requests allow a greater range of content types.
expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"} expected = []string{"application/jose+json", "application/pkix-cert", "application/pkcs7-mime"}
@ -117,7 +94,7 @@ func (h *Handler) verifyContentType(next nextHTTP) nextHTTP {
} }
// parseJWS is a middleware that parses a request body into a JSONWebSignature struct. // parseJWS is a middleware that parses a request body into a JSONWebSignature struct.
func (h *Handler) parseJWS(next nextHTTP) nextHTTP { func parseJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
@ -142,17 +119,19 @@ func (h *Handler) parseJWS(next nextHTTP) nextHTTP {
// The JWS Unprotected Header [RFC7515] MUST NOT be used // The JWS Unprotected Header [RFC7515] MUST NOT be used
// The JWS Payload MUST NOT be detached // The JWS Payload MUST NOT be detached
// The JWS Protected Header MUST include the following fields: // The JWS Protected Header MUST include the following fields:
// * “alg” (Algorithm) // - “alg” (Algorithm).
// * This field MUST NOT contain “none” or a Message Authentication Code // This field MUST NOT contain “none” or a Message Authentication Code
// (MAC) algorithm (e.g. one in which the algorithm registry description // (MAC) algorithm (e.g. one in which the algorithm registry description
// mentions MAC/HMAC). // mentions MAC/HMAC).
// * “nonce” (defined in Section 6.5) // - “nonce” (defined in Section 6.5)
// * “url” (defined in Section 6.4) // - “url” (defined in Section 6.4)
// * Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste> // - Either “jwk” (JSON Web Key) or “kid” (Key ID) as specified below<Paste>
func (h *Handler) validateJWS(next nextHTTP) nextHTTP { func validateJWS(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(r.Context()) db := acme.MustDatabaseFromContext(ctx)
jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
@ -202,7 +181,7 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
} }
// Check the validity/freshness of the Nonce. // Check the validity/freshness of the Nonce.
if err := h.db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil { if err := db.DeleteNonce(ctx, acme.Nonce(hdr.Nonce)); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
@ -235,10 +214,12 @@ func (h *Handler) validateJWS(next nextHTTP) nextHTTP {
// extractJWK is a middleware that extracts the JWK from the JWS and saves it // extractJWK is a middleware that extracts the JWK from the JWS and saves it
// in the context. Make sure to parse and validate the JWS before running this // in the context. Make sure to parse and validate the JWS before running this
// middleware. // middleware.
func (h *Handler) extractJWK(next nextHTTP) nextHTTP { func extractJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(r.Context()) db := acme.MustDatabaseFromContext(ctx)
jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
@ -264,7 +245,7 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
ctx = context.WithValue(ctx, jwkContextKey, jwk) ctx = context.WithValue(ctx, jwkContextKey, jwk)
// Get Account OR continue to generate a new one OR continue Revoke with certificate private key // Get Account OR continue to generate a new one OR continue Revoke with certificate private key
acc, err := h.db.GetAccountByKeyID(ctx, jwk.KeyID) acc, err := db.GetAccountByKeyID(ctx, jwk.KeyID)
switch { switch {
case errors.Is(err, acme.ErrNotFound): case errors.Is(err, acme.ErrNotFound):
// For NewAccount and Revoke requests ... // For NewAccount and Revoke requests ...
@ -283,63 +264,44 @@ func (h *Handler) extractJWK(next nextHTTP) nextHTTP {
} }
} }
// lookupProvisioner loads the provisioner associated with the request.
// Responds 404 if the provisioner does not exist.
func (h *Handler) lookupProvisioner(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return
}
p, err := h.ca.LoadProvisionerByName(name)
if err != nil {
render.Error(w, err)
return
}
acmeProv, ok := p.(*provisioner.ACME)
if !ok {
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return
}
ctx = context.WithValue(ctx, provisionerContextKey, acme.Provisioner(acmeProv))
next(w, r.WithContext(ctx))
}
}
// checkPrerequisites checks if all prerequisites for serving ACME // checkPrerequisites checks if all prerequisites for serving ACME
// are met by the CA configuration. // are met by the CA configuration.
func (h *Handler) checkPrerequisites(next nextHTTP) nextHTTP { func checkPrerequisites(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
ok, err := h.prerequisitesChecker(ctx) // If the function is not set assume that all prerequisites are met.
if err != nil { checkFunc, ok := acme.PrerequisitesCheckerFromContext(ctx)
render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites")) if ok {
return ok, err := checkFunc(ctx)
} if err != nil {
if !ok { render.Error(w, acme.WrapErrorISE(err, "error checking acme provisioner prerequisites"))
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites")) return
return }
if !ok {
render.Error(w, acme.NewError(acme.ErrorNotImplementedType, "acme provisioner configuration lacks prerequisites"))
return
}
} }
next(w, r.WithContext(ctx)) next(w, r)
} }
} }
// lookupJWK loads the JWK associated with the acme account referenced by the // lookupJWK loads the JWK associated with the acme account referenced by the
// kid parameter of the signed payload. // kid parameter of the signed payload.
// Make sure to parse and validate the JWS before running this middleware. // Make sure to parse and validate the JWS before running this middleware.
func (h *Handler) lookupJWK(next nextHTTP) nextHTTP { func lookupJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
kidPrefix := h.linker.GetLink(ctx, AccountLinkType, "") kidPrefix := linker.GetLink(ctx, acme.AccountLinkType, "")
kid := jws.Signatures[0].Protected.KeyID kid := jws.Signatures[0].Protected.KeyID
if !strings.HasPrefix(kid, kidPrefix) { if !strings.HasPrefix(kid, kidPrefix) {
render.Error(w, acme.NewError(acme.ErrorMalformedType, render.Error(w, acme.NewError(acme.ErrorMalformedType,
@ -349,7 +311,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
} }
accID := strings.TrimPrefix(kid, kidPrefix) accID := strings.TrimPrefix(kid, kidPrefix)
acc, err := h.db.GetAccount(ctx, accID) acc, err := db.GetAccount(ctx, accID)
switch { switch {
case nosql.IsErrNotFound(err): case nosql.IsErrNotFound(err):
render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID)) render.Error(w, acme.NewError(acme.ErrorAccountDoesNotExistType, "account with ID '%s' not found", accID))
@ -372,7 +334,7 @@ func (h *Handler) lookupJWK(next nextHTTP) nextHTTP {
// extractOrLookupJWK forwards handling to either extractJWK or // extractOrLookupJWK forwards handling to either extractJWK or
// lookupJWK based on the presence of a JWK or a KID, respectively. // lookupJWK based on the presence of a JWK or a KID, respectively.
func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP { func extractOrLookupJWK(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
@ -385,13 +347,13 @@ func (h *Handler) extractOrLookupJWK(next nextHTTP) nextHTTP {
// and it can be used to check if a JWK exists. This flow is used when the ACME client // and it can be used to check if a JWK exists. This flow is used when the ACME client
// signed the payload with a certificate private key. // signed the payload with a certificate private key.
if canExtractJWKFrom(jws) { if canExtractJWKFrom(jws) {
h.extractJWK(next)(w, r) extractJWK(next)(w, r)
return return
} }
// default to looking up the JWK based on KeyID. This flow is used when the ACME client // default to looking up the JWK based on KeyID. This flow is used when the ACME client
// signed the payload with an account private key. // signed the payload with an account private key.
h.lookupJWK(next)(w, r) lookupJWK(next)(w, r)
} }
} }
@ -408,7 +370,7 @@ func canExtractJWKFrom(jws *jose.JSONWebSignature) bool {
// verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context. // verifyAndExtractJWSPayload extracts the JWK from the JWS and saves it in the context.
// Make sure to parse and validate the JWS before running this middleware. // Make sure to parse and validate the JWS before running this middleware.
func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP { func verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
@ -440,7 +402,7 @@ func (h *Handler) verifyAndExtractJWSPayload(next nextHTTP) nextHTTP {
} }
// isPostAsGet asserts that the request is a PostAsGet (empty JWS payload). // isPostAsGet asserts that the request is a PostAsGet (empty JWS payload).
func (h *Handler) isPostAsGet(next nextHTTP) nextHTTP { func isPostAsGet(next nextHTTP) nextHTTP {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
payload, err := payloadFromContext(r.Context()) payload, err := payloadFromContext(r.Context())
if err != nil { if err != nil {
@ -462,16 +424,12 @@ type ContextKey string
const ( const (
// accContextKey account key // accContextKey account key
accContextKey = ContextKey("acc") accContextKey = ContextKey("acc")
// baseURLContextKey baseURL key
baseURLContextKey = ContextKey("baseURL")
// jwsContextKey jws key // jwsContextKey jws key
jwsContextKey = ContextKey("jws") jwsContextKey = ContextKey("jws")
// jwkContextKey jwk key // jwkContextKey jwk key
jwkContextKey = ContextKey("jwk") jwkContextKey = ContextKey("jwk")
// payloadContextKey payload key // payloadContextKey payload key
payloadContextKey = ContextKey("payload") payloadContextKey = ContextKey("payload")
// provisionerContextKey provisioner key
provisionerContextKey = ContextKey("provisioner")
) )
// accountFromContext searches the context for an ACME account. Returns the // accountFromContext searches the context for an ACME account. Returns the
@ -484,15 +442,6 @@ func accountFromContext(ctx context.Context) (*acme.Account, error) {
return val, nil return val, nil
} }
// baseURLFromContext returns the baseURL if one is stored in the context.
func baseURLFromContext(ctx context.Context) *url.URL {
val, ok := ctx.Value(baseURLContextKey).(*url.URL)
if !ok || val == nil {
return nil
}
return val
}
// jwkFromContext searches the context for a JWK. Returns the JWK or an error. // jwkFromContext searches the context for a JWK. Returns the JWK or an error.
func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) { func jwkFromContext(ctx context.Context) (*jose.JSONWebKey, error) {
val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey) val, ok := ctx.Value(jwkContextKey).(*jose.JSONWebKey)
@ -514,29 +463,26 @@ func jwsFromContext(ctx context.Context) (*jose.JSONWebSignature, error) {
// provisionerFromContext searches the context for a provisioner. Returns the // provisionerFromContext searches the context for a provisioner. Returns the
// provisioner or an error. // provisioner or an error.
func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) { func provisionerFromContext(ctx context.Context) (acme.Provisioner, error) {
val := ctx.Value(provisionerContextKey) p, ok := acme.ProvisionerFromContext(ctx)
if val == nil { if !ok || p == nil {
return nil, acme.NewErrorISE("provisioner expected in request context") return nil, acme.NewErrorISE("provisioner expected in request context")
} }
pval, ok := val.(acme.Provisioner) return p, nil
if !ok || pval == nil {
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
}
return pval, nil
} }
// acmeProvisionerFromContext searches the context for an ACME provisioner. Returns // acmeProvisionerFromContext searches the context for an ACME provisioner. Returns
// pointer to an ACME provisioner or an error. // pointer to an ACME provisioner or an error.
func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) { func acmeProvisionerFromContext(ctx context.Context) (*provisioner.ACME, error) {
prov, err := provisionerFromContext(ctx) p, err := provisionerFromContext(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
acmeProv, ok := prov.(*provisioner.ACME) ap, ok := p.(*provisioner.ACME)
if !ok || acmeProv == nil { if !ok {
return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner") return nil, acme.NewErrorISE("provisioner in context is not an ACME provisioner")
} }
return acmeProv, nil
return ap, nil
} }
// payloadFromContext searches the context for a payload. Returns the payload // payloadFromContext searches the context for a payload. Returns the payload

@ -27,83 +27,18 @@ func testNext(w http.ResponseWriter, r *http.Request) {
w.Write(testBody) w.Write(testBody)
} }
func Test_baseURLFromRequest(t *testing.T) { func newBaseContext(ctx context.Context, args ...interface{}) context.Context {
tests := []struct { for _, a := range args {
name string switch v := a.(type) {
targetURL string case acme.DB:
expectedResult *url.URL ctx = acme.NewDatabaseContext(ctx, v)
requestPreparer func(*http.Request) case acme.Linker:
}{ ctx = acme.NewLinkerContext(ctx, v)
{ case acme.PrerequisitesChecker:
"HTTPS host pass-through failed.", ctx = acme.NewPrerequisitesCheckerContext(ctx, v)
"https://my.dummy.host",
&url.URL{Scheme: "https", Host: "my.dummy.host"},
nil,
},
{
"Port pass-through failed",
"https://host.with.port:8080",
&url.URL{Scheme: "https", Host: "host.with.port:8080"},
nil,
},
{
"Explicit host from Request.Host was not used.",
"https://some.target.host:8080",
&url.URL{Scheme: "https", Host: "proxied.host"},
func(r *http.Request) {
r.Host = "proxied.host"
},
},
{
"Missing Request.Host value did not result in empty string result.",
"https://some.host",
nil,
func(r *http.Request) {
r.Host = ""
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
request := httptest.NewRequest("GET", tc.targetURL, nil)
if tc.requestPreparer != nil {
tc.requestPreparer(request)
}
result := baseURLFromRequest(request)
if result == nil || tc.expectedResult == nil {
assert.Equals(t, result, tc.expectedResult)
} else if result.String() != tc.expectedResult.String() {
t.Errorf("Expected %q, but got %q", tc.expectedResult.String(), result.String())
}
})
}
}
func TestHandler_baseURLFromRequest(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest("GET", "/foo", nil)
req.Host = "test.ca.smallstep.com:8080"
w := httptest.NewRecorder()
next := func(w http.ResponseWriter, r *http.Request) {
bu := baseURLFromContext(r.Context())
if assert.NotNil(t, bu) {
assert.Equals(t, bu.Host, "test.ca.smallstep.com:8080")
assert.Equals(t, bu.Scheme, "https")
} }
} }
return ctx
h.baseURLFromRequest(next)(w, req)
req = httptest.NewRequest("GET", "/foo", nil)
req.Host = ""
next = func(w http.ResponseWriter, r *http.Request) {
assert.Equals(t, baseURLFromContext(r.Context()), nil)
}
h.baseURLFromRequest(next)(w, req)
} }
func TestHandler_addNonce(t *testing.T) { func TestHandler_addNonce(t *testing.T) {
@ -139,10 +74,10 @@ func TestHandler_addNonce(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} ctx := newBaseContext(context.Background(), tc.db)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil).WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.addNonce(testNext)(w, req) addNonce(testNext)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -175,17 +110,15 @@ func TestHandler_addDirLink(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
type test struct { type test struct {
link string link string
linker Linker
statusCode int statusCode int
ctx context.Context ctx context.Context
err *acme.Error err *acme.Error
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = acme.NewLinkerContext(ctx, acme.NewLinker("test.ca.smallstep.com", "acme"))
return test{ return test{
linker: NewLinker("dns", "acme"),
ctx: ctx, ctx: ctx,
link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName), link: fmt.Sprintf("%s/acme/%s/directory", baseURL.String(), provName),
statusCode: 200, statusCode: 200,
@ -195,11 +128,10 @@ func TestHandler_addDirLink(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: tc.linker}
req := httptest.NewRequest("GET", "/foo", nil) req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.addDirLink(testNext)(w, req) addDirLink(testNext)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -231,7 +163,6 @@ func TestHandler_verifyContentType(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName) u := fmt.Sprintf("%s/acme/%s/certificate/abc123", baseURL.String(), escProvName)
type test struct { type test struct {
h Handler
ctx context.Context ctx context.Context
contentType string contentType string
err *acme.Error err *acme.Error
@ -241,9 +172,6 @@ func TestHandler_verifyContentType(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/provisioner-not-set": func(t *testing.T) test { "fail/provisioner-not-set": func(t *testing.T) test {
return test{ return test{
h: Handler{
linker: NewLinker("dns", "acme"),
},
url: u, url: u,
ctx: context.Background(), ctx: context.Background(),
contentType: "foo", contentType: "foo",
@ -253,11 +181,8 @@ func TestHandler_verifyContentType(t *testing.T) {
}, },
"fail/general-bad-content-type": func(t *testing.T) test { "fail/general-bad-content-type": func(t *testing.T) test {
return test{ return test{
h: Handler{
linker: NewLinker("dns", "acme"),
},
url: u, url: u,
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), ctx: acme.NewProvisionerContext(context.Background(), prov),
contentType: "foo", contentType: "foo",
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"), err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json], but got foo"),
@ -265,10 +190,7 @@ func TestHandler_verifyContentType(t *testing.T) {
}, },
"fail/certificate-bad-content-type": func(t *testing.T) test { "fail/certificate-bad-content-type": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "foo", contentType: "foo",
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"), err: acme.NewError(acme.ErrorMalformedType, "expected content-type to be in [application/jose+json application/pkix-cert application/pkcs7-mime], but got foo"),
@ -276,40 +198,28 @@ func TestHandler_verifyContentType(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/jose+json", contentType: "application/jose+json",
statusCode: 200, statusCode: 200,
} }
}, },
"ok/certificate/pkix-cert": func(t *testing.T) test { "ok/certificate/pkix-cert": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/pkix-cert", contentType: "application/pkix-cert",
statusCode: 200, statusCode: 200,
} }
}, },
"ok/certificate/jose+json": func(t *testing.T) test { "ok/certificate/jose+json": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/jose+json", contentType: "application/jose+json",
statusCode: 200, statusCode: 200,
} }
}, },
"ok/certificate/pkcs7-mime": func(t *testing.T) test { "ok/certificate/pkcs7-mime": func(t *testing.T) test {
return test{ return test{
h: Handler{ ctx: acme.NewProvisionerContext(context.Background(), prov),
linker: NewLinker("dns", "acme"),
},
ctx: context.WithValue(context.Background(), provisionerContextKey, prov),
contentType: "application/pkcs7-mime", contentType: "application/pkcs7-mime",
statusCode: 200, statusCode: 200,
} }
@ -326,7 +236,7 @@ func TestHandler_verifyContentType(t *testing.T) {
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
req.Header.Add("Content-Type", tc.contentType) req.Header.Add("Content-Type", tc.contentType)
w := httptest.NewRecorder() w := httptest.NewRecorder()
tc.h.verifyContentType(testNext)(w, req) verifyContentType(testNext)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -390,11 +300,11 @@ func TestHandler_isPostAsGet(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{} // h := &Handler{}
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.isPostAsGet(testNext)(w, req) isPostAsGet(testNext)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -481,10 +391,10 @@ func TestHandler_parseJWS(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{} // h := &Handler{}
req := httptest.NewRequest("GET", u, tc.body) req := httptest.NewRequest("GET", u, tc.body)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.parseJWS(tc.next)(w, req) parseJWS(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -679,11 +589,11 @@ func TestHandler_verifyAndExtractJWSPayload(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{} // h := &Handler{}
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.verifyAndExtractJWSPayload(tc.next)(w, req) verifyAndExtractJWSPayload(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -733,7 +643,7 @@ func TestHandler_lookupJWK(t *testing.T) {
parsedJWS, err := jose.ParseJWS(raw) parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
type test struct { type test struct {
linker Linker linker acme.Linker
db acme.DB db acme.DB
ctx context.Context ctx context.Context
next func(http.ResponseWriter, *http.Request) next func(http.ResponseWriter, *http.Request)
@ -743,15 +653,19 @@ func TestHandler_lookupJWK(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-jws": func(t *testing.T) test { "fail/no-jws": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
} }
}, },
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, nil) ctx = context.WithValue(ctx, jwsContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -765,11 +679,11 @@ func TestHandler_lookupJWK(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
_jws, err := _signer.Sign([]byte("baz")) _jws, err := _signer.Sign([]byte("baz"))
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws) ctx = context.WithValue(ctx, jwsContextKey, _jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix), err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got ", prefix),
@ -789,22 +703,21 @@ func TestHandler_lookupJWK(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
_parsed, err := jose.ParseJWS(_raw) _parsed, err := jose.ParseJWS(_raw)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _parsed) ctx = context.WithValue(ctx, jwsContextKey, _parsed)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), db: &acme.MockDB{},
linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix), err: acme.NewError(acme.ErrorMalformedType, "kid does not have required prefix; expected %s, but got foo", prefix),
} }
}, },
"fail/account-not-found": func(t *testing.T) test { "fail/account-not-found": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
assert.Equals(t, accID, accID) assert.Equals(t, accID, accID)
@ -817,11 +730,10 @@ func TestHandler_lookupJWK(t *testing.T) {
} }
}, },
"fail/GetAccount-error": func(t *testing.T) test { "fail/GetAccount-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID) assert.Equals(t, id, accID)
@ -835,11 +747,10 @@ func TestHandler_lookupJWK(t *testing.T) {
}, },
"fail/account-not-valid": func(t *testing.T) test { "fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{Status: "deactivated"} acc := &acme.Account{Status: "deactivated"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID) assert.Equals(t, id, accID)
@ -853,11 +764,10 @@ func TestHandler_lookupJWK(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{Status: "valid", Key: jwk} acc := &acme.Account{Status: "valid", Key: jwk}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, id string) (*acme.Account, error) {
assert.Equals(t, id, accID) assert.Equals(t, id, accID)
@ -881,11 +791,11 @@ func TestHandler_lookupJWK(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: tc.linker} ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.lookupJWK(tc.next)(w, req) lookupJWK(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -945,15 +855,17 @@ func TestHandler_extractJWK(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-jws": func(t *testing.T) test { "fail/no-jws": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
} }
}, },
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, nil) ctx = context.WithValue(ctx, jwsContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -969,9 +881,10 @@ func TestHandler_extractJWK(t *testing.T) {
}, },
}, },
} }
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws) ctx = context.WithValue(ctx, jwsContextKey, _jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"), err: acme.NewError(acme.ErrorMalformedType, "jwk expected in protected header"),
@ -987,16 +900,17 @@ func TestHandler_extractJWK(t *testing.T) {
}, },
}, },
} }
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, _jws) ctx = context.WithValue(ctx, jwsContextKey, _jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"), err: acme.NewError(acme.ErrorMalformedType, "invalid jwk in protected header"),
} }
}, },
"fail/GetAccountByKey-error": func(t *testing.T) test { "fail/GetAccountByKey-error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
@ -1012,7 +926,7 @@ func TestHandler_extractJWK(t *testing.T) {
}, },
"fail/account-not-valid": func(t *testing.T) test { "fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{Status: "deactivated"} acc := &acme.Account{Status: "deactivated"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
@ -1028,7 +942,7 @@ func TestHandler_extractJWK(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{Status: "valid"} acc := &acme.Account{Status: "valid"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
@ -1051,7 +965,7 @@ func TestHandler_extractJWK(t *testing.T) {
} }
}, },
"ok/no-account": func(t *testing.T) test { "ok/no-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
ctx: ctx, ctx: ctx,
@ -1077,11 +991,11 @@ func TestHandler_extractJWK(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} ctx := newBaseContext(tc.ctx, tc.db)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.extractJWK(tc.next)(w, req) extractJWK(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1118,6 +1032,7 @@ func TestHandler_validateJWS(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-jws": func(t *testing.T) test { "fail/no-jws": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.Background(), ctx: context.Background(),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -1125,6 +1040,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, nil), ctx: context.WithValue(context.Background(), jwsContextKey, nil),
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -1132,6 +1048,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
"fail/no-signature": func(t *testing.T) test { "fail/no-signature": func(t *testing.T) test {
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}), ctx: context.WithValue(context.Background(), jwsContextKey, &jose.JSONWebSignature{}),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"), err: acme.NewError(acme.ErrorMalformedType, "request body does not contain a signature"),
@ -1145,6 +1062,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
} }
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"), err: acme.NewError(acme.ErrorMalformedType, "request body contains more than one signature"),
@ -1157,6 +1075,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
} }
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"), err: acme.NewError(acme.ErrorMalformedType, "unprotected header must not be used"),
@ -1169,6 +1088,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
} }
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"), err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: none"),
@ -1181,6 +1101,7 @@ func TestHandler_validateJWS(t *testing.T) {
}, },
} }
return test{ return test{
db: &acme.MockDB{},
ctx: context.WithValue(context.Background(), jwsContextKey, jws), ctx: context.WithValue(context.Background(), jwsContextKey, jws),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256), err: acme.NewError(acme.ErrorBadSignatureAlgorithmType, "unsuitable algorithm: %s", jose.HS256),
@ -1444,11 +1365,11 @@ func TestHandler_validateJWS(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} ctx := newBaseContext(tc.ctx, tc.db)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.validateJWS(tc.next)(w, req) validateJWS(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1542,7 +1463,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
u := "https://ca.smallstep.com/acme/account" u := "https://ca.smallstep.com/acme/account"
type test struct { type test struct {
db acme.DB db acme.DB
linker Linker linker acme.Linker
statusCode int statusCode int
ctx context.Context ctx context.Context
err *acme.Error err *acme.Error
@ -1570,7 +1491,7 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
parsedJWS, err := jose.ParseJWS(raw) parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("dns", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) { MockGetAccountByKeyID: func(ctx context.Context, kid string) (*acme.Account, error) {
assert.Equals(t, kid, pub.KeyID) assert.Equals(t, kid, pub.KeyID)
@ -1606,11 +1527,10 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
parsedJWS, err := jose.ParseJWS(raw) parsedJWS, err := jose.ParseJWS(raw)
assert.FatalError(t, err) assert.FatalError(t, err)
acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"} acc := &acme.Account{ID: "accID", Key: jwk, Status: "valid"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
return test{ return test{
linker: NewLinker("test.ca.smallstep.com", "acme"), linker: acme.NewLinker("test.ca.smallstep.com", "acme"),
db: &acme.MockDB{ db: &acme.MockDB{
MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) { MockGetAccount: func(ctx context.Context, accID string) (*acme.Account, error) {
assert.Equals(t, accID, acc.ID) assert.Equals(t, accID, acc.ID)
@ -1628,11 +1548,11 @@ func TestHandler_extractOrLookupJWK(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db, linker: tc.linker} ctx := newBaseContext(tc.ctx, tc.db, tc.linker)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.extractOrLookupJWK(tc.next)(w, req) extractOrLookupJWK(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1664,7 +1584,7 @@ func TestHandler_checkPrerequisites(t *testing.T) {
u := fmt.Sprintf("%s/acme/%s/account/1234", u := fmt.Sprintf("%s/acme/%s/account/1234",
baseURL, provName) baseURL, provName)
type test struct { type test struct {
linker Linker linker acme.Linker
ctx context.Context ctx context.Context
prerequisitesChecker func(context.Context) (bool, error) prerequisitesChecker func(context.Context) (bool, error)
next func(http.ResponseWriter, *http.Request) next func(http.ResponseWriter, *http.Request)
@ -1673,10 +1593,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/error": func(t *testing.T) test { "fail/error": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("dns", "acme"),
ctx: ctx, ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") }, prerequisitesChecker: func(context.Context) (bool, error) { return false, errors.New("force") },
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
@ -1687,10 +1606,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
} }
}, },
"fail/prerequisites-nok": func(t *testing.T) test { "fail/prerequisites-nok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("dns", "acme"),
ctx: ctx, ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return false, nil }, prerequisitesChecker: func(context.Context) (bool, error) { return false, nil },
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
@ -1701,10 +1619,9 @@ func TestHandler_checkPrerequisites(t *testing.T) {
} }
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
linker: NewLinker("dns", "acme"), linker: acme.NewLinker("dns", "acme"),
ctx: ctx, ctx: ctx,
prerequisitesChecker: func(context.Context) (bool, error) { return true, nil }, prerequisitesChecker: func(context.Context) (bool, error) { return true, nil },
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
@ -1717,11 +1634,11 @@ func TestHandler_checkPrerequisites(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: nil, linker: tc.linker, prerequisitesChecker: tc.prerequisitesChecker} ctx := acme.NewPrerequisitesCheckerContext(tc.ctx, tc.prerequisitesChecker)
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.checkPrerequisites(tc.next)(w, req) checkPrerequisites(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)

@ -16,6 +16,8 @@ import (
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner"
) )
// NewOrderRequest represents the body for a NewOrder request. // NewOrderRequest represents the body for a NewOrder request.
@ -37,6 +39,8 @@ func (n *NewOrderRequest) Validate() error {
if id.Type == acme.IP && net.ParseIP(id.Value) == nil { if id.Type == acme.IP && net.ParseIP(id.Value) == nil {
return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value) return acme.NewError(acme.ErrorMalformedType, "invalid IP address: %s", id.Value)
} }
// TODO(hs): add some validations for DNS domains?
// TODO(hs): combine the errors from this with allow/deny policy, like example error in https://datatracker.ietf.org/doc/html/rfc8555#section-6.7.1
} }
return nil return nil
} }
@ -50,7 +54,13 @@ type FinalizeRequest struct {
// Validate validates a finalize request body. // Validate validates a finalize request body.
func (f *FinalizeRequest) Validate() error { func (f *FinalizeRequest) Validate() error {
var err error var err error
csrBytes, err := base64.RawURLEncoding.DecodeString(f.CSR) // RFC 8555 isn't 100% conclusive about using raw base64-url encoding for the
// CSR specifically, instead of "normal" base64-url encoding (incl. padding).
// By trimming the padding from CSRs submitted by ACME clients that use
// base64-url encoding instead of raw base64-url encoding, these are also
// supported. This was reported in https://github.com/smallstep/certificates/issues/939
// to be the case for a Synology DSM NAS system.
csrBytes, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(f.CSR, "="))
if err != nil { if err != nil {
return acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding csr") return acme.WrapError(acme.ErrorMalformedType, err, "error base64url decoding csr")
} }
@ -68,8 +78,12 @@ var defaultOrderExpiry = time.Hour * 24
var defaultOrderBackdate = time.Minute var defaultOrderBackdate = time.Minute
// NewOrder ACME api for creating a new order. // NewOrder ACME api for creating a new order.
func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) { func NewOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
ca := mustAuthority(ctx)
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -85,6 +99,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
render.Error(w, err) render.Error(w, err)
return return
} }
var nor NewOrderRequest var nor NewOrderRequest
if err := json.Unmarshal(payload.value, &nor); err != nil { if err := json.Unmarshal(payload.value, &nor); err != nil {
render.Error(w, acme.WrapError(acme.ErrorMalformedType, err, render.Error(w, acme.WrapError(acme.ErrorMalformedType, err,
@ -97,6 +112,48 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
return return
} }
// TODO(hs): gather all errors, so that we can build one response with ACME subproblems
// include the nor.Validate() error here too, like in the example in the ACME RFC?
acmeProv, err := acmeProvisionerFromContext(ctx)
if err != nil {
render.Error(w, err)
return
}
var eak *acme.ExternalAccountKey
if acmeProv.RequireEAB {
if eak, err = db.GetExternalAccountKeyByAccountID(ctx, prov.GetID(), acc.ID); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving external account binding key"))
return
}
}
acmePolicy, err := newACMEPolicyEngine(eak)
if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating ACME policy engine"))
return
}
for _, identifier := range nor.Identifiers {
// evaluate the ACME account level policy
if err = isIdentifierAllowed(acmePolicy, identifier); err != nil {
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
return
}
// evaluate the provisioner level policy
orderIdentifier := provisioner.ACMEIdentifier{Type: provisioner.ACMEIdentifierType(identifier.Type), Value: identifier.Value}
if err = prov.AuthorizeOrderIdentifier(ctx, orderIdentifier); err != nil {
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
return
}
// evaluate the authority level policy
if err = ca.AreSANsAllowed(ctx, []string{identifier.Value}); err != nil {
render.Error(w, acme.WrapError(acme.ErrorRejectedIdentifierType, err, "not authorized"))
return
}
}
now := clock.Now() now := clock.Now()
// New order. // New order.
o := &acme.Order{ o := &acme.Order{
@ -117,7 +174,7 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
ExpiresAt: o.ExpiresAt, ExpiresAt: o.ExpiresAt,
Status: acme.StatusPending, Status: acme.StatusPending,
} }
if err := h.newAuthorization(ctx, az); err != nil { if err := newAuthorization(ctx, az); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
@ -136,18 +193,32 @@ func (h *Handler) NewOrder(w http.ResponseWriter, r *http.Request) {
o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate) o.NotBefore = o.NotBefore.Add(-defaultOrderBackdate)
} }
if err := h.db.CreateOrder(ctx, o); err != nil { if err := db.CreateOrder(ctx, o); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error creating order")) render.Error(w, acme.WrapErrorISE(err, "error creating order"))
return return
} }
h.linker.LinkOrder(ctx, o) linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSONStatus(w, o, http.StatusCreated) render.JSONStatus(w, o, http.StatusCreated)
} }
func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization) error { func isIdentifierAllowed(acmePolicy policy.X509Policy, identifier acme.Identifier) error {
if acmePolicy == nil {
return nil
}
return acmePolicy.AreSANsAllowed([]string{identifier.Value})
}
func newACMEPolicyEngine(eak *acme.ExternalAccountKey) (policy.X509Policy, error) {
if eak == nil {
return nil, nil
}
return policy.NewX509PolicyEngine(eak.Policy)
}
func newAuthorization(ctx context.Context, az *acme.Authorization) error {
if strings.HasPrefix(az.Identifier.Value, "*.") { if strings.HasPrefix(az.Identifier.Value, "*.") {
az.Wildcard = true az.Wildcard = true
az.Identifier = acme.Identifier{ az.Identifier = acme.Identifier{
@ -163,6 +234,8 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization)
if err != nil { if err != nil {
return acme.WrapErrorISE(err, "error generating random alphanumeric ID") return acme.WrapErrorISE(err, "error generating random alphanumeric ID")
} }
db := acme.MustDatabaseFromContext(ctx)
az.Challenges = make([]*acme.Challenge, len(chTypes)) az.Challenges = make([]*acme.Challenge, len(chTypes))
for i, typ := range chTypes { for i, typ := range chTypes {
ch := &acme.Challenge{ ch := &acme.Challenge{
@ -172,20 +245,23 @@ func (h *Handler) newAuthorization(ctx context.Context, az *acme.Authorization)
Token: az.Token, Token: az.Token,
Status: acme.StatusPending, Status: acme.StatusPending,
} }
if err := h.db.CreateChallenge(ctx, ch); err != nil { if err := db.CreateChallenge(ctx, ch); err != nil {
return acme.WrapErrorISE(err, "error creating challenge") return acme.WrapErrorISE(err, "error creating challenge")
} }
az.Challenges[i] = ch az.Challenges[i] = ch
} }
if err = h.db.CreateAuthorization(ctx, az); err != nil { if err = db.CreateAuthorization(ctx, az); err != nil {
return acme.WrapErrorISE(err, "error creating authorization") return acme.WrapErrorISE(err, "error creating authorization")
} }
return nil return nil
} }
// GetOrder ACME api for retrieving an order. // GetOrder ACME api for retrieving an order.
func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) { func GetOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -196,7 +272,8 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
render.Error(w, err) render.Error(w, err)
return return
} }
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID"))
o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return return
@ -211,20 +288,23 @@ func (h *Handler) GetOrder(w http.ResponseWriter, r *http.Request) {
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return return
} }
if err = o.UpdateStatus(ctx, h.db); err != nil { if err = o.UpdateStatus(ctx, db); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error updating order status")) render.Error(w, acme.WrapErrorISE(err, "error updating order status"))
return return
} }
h.linker.LinkOrder(ctx, o) linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSON(w, o) render.JSON(w, o)
} }
// FinalizeOrder attemptst to finalize an order and create a certificate. // FinalizeOrder attempts to finalize an order and create a certificate.
func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) { func FinalizeOrder(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
acc, err := accountFromContext(ctx) acc, err := accountFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -251,7 +331,7 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
return return
} }
o, err := h.db.GetOrder(ctx, chi.URLParam(r, "ordID")) o, err := db.GetOrder(ctx, chi.URLParam(r, "ordID"))
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving order")) render.Error(w, acme.WrapErrorISE(err, "error retrieving order"))
return return
@ -266,14 +346,16 @@ func (h *Handler) FinalizeOrder(w http.ResponseWriter, r *http.Request) {
"provisioner '%s' does not own order '%s'", prov.GetID(), o.ID)) "provisioner '%s' does not own order '%s'", prov.GetID(), o.ID))
return return
} }
if err = o.Finalize(ctx, h.db, fr.csr, h.ca, prov); err != nil {
ca := mustAuthority(ctx)
if err = o.Finalize(ctx, db, fr.csr, ca, prov); err != nil {
render.Error(w, acme.WrapErrorISE(err, "error finalizing order")) render.Error(w, acme.WrapErrorISE(err, "error finalizing order"))
return return
} }
h.linker.LinkOrder(ctx, o) linker.LinkOrder(ctx, o)
w.Header().Set("Location", h.linker.GetLink(ctx, OrderLinkType, o.ID)) w.Header().Set("Location", linker.GetLink(ctx, acme.OrderLinkType, o.ID))
render.JSON(w, o) render.JSON(w, o)
} }

@ -16,9 +16,13 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/crypto/pemutil"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"go.step.sm/crypto/pemutil" "github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner"
) )
func TestNewOrderRequest_Validate(t *testing.T) { func TestNewOrderRequest_Validate(t *testing.T) {
@ -206,6 +210,13 @@ func TestFinalizeRequestValidate(t *testing.T) {
}, },
} }
}, },
"ok/padding": func(t *testing.T) test {
return test{
fr: &FinalizeRequest{
CSR: base64.RawURLEncoding.EncodeToString(csr.Raw) + "==", // add intentional padding
},
}
},
} }
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
@ -276,15 +287,17 @@ func TestHandler_GetOrder(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -294,6 +307,7 @@ func TestHandler_GetOrder(t *testing.T) {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -301,9 +315,10 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/nil-provisioner": func(t *testing.T) test { "fail/nil-provisioner": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil) ctx := acme.NewProvisionerContext(context.Background(), nil)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -311,7 +326,7 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/db.GetOrder-error": func(t *testing.T) test { "fail/db.GetOrder-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
@ -325,7 +340,7 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/account-id-mismatch": func(t *testing.T) test { "fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
@ -341,7 +356,7 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/provisioner-id-mismatch": func(t *testing.T) test { "fail/provisioner-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
@ -357,7 +372,7 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"fail/order-update-error": func(t *testing.T) test { "fail/order-update-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
@ -381,10 +396,9 @@ func TestHandler_GetOrder(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) { MockGetOrder: func(ctx context.Context, id string) (*acme.Order, error) {
@ -421,11 +435,11 @@ func TestHandler_GetOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetOrder(w, req) GetOrder(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -636,8 +650,8 @@ func TestHandler_newAuthorization(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
h := &Handler{db: tc.db} ctx := newBaseContext(context.Background(), tc.db)
if err := h.newAuthorization(context.Background(), tc.az); err != nil { if err := newAuthorization(ctx, tc.az); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *acme.Error: case *acme.Error:
@ -667,6 +681,7 @@ func TestHandler_NewOrder(t *testing.T) {
baseURL.String(), escProvName) baseURL.String(), escProvName)
type test struct { type test struct {
ca acme.CertificateAuthority
db acme.DB db acme.DB
ctx context.Context ctx context.Context
nor *NewOrderRequest nor *NewOrderRequest
@ -677,15 +692,17 @@ func TestHandler_NewOrder(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -695,6 +712,7 @@ func TestHandler_NewOrder(t *testing.T) {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -702,9 +720,10 @@ func TestHandler_NewOrder(t *testing.T) {
}, },
"fail/nil-provisioner": func(t *testing.T) test { "fail/nil-provisioner": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -713,8 +732,9 @@ func TestHandler_NewOrder(t *testing.T) {
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload does not exist"), err: acme.NewErrorISE("payload does not exist"),
@ -722,21 +742,23 @@ func TestHandler_NewOrder(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("paylod does not exist"), err: acme.NewErrorISE("payload does not exist"),
} }
}, },
"fail/unmarshal-payload-error": func(t *testing.T) test { "fail/unmarshal-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"), err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal new-order request payload: unexpected end of JSON input"),
@ -747,15 +769,232 @@ func TestHandler_NewOrder(t *testing.T) {
fr := &NewOrderRequest{} fr := &NewOrderRequest{}
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"), err: acme.NewError(acme.ErrorMalformedType, "identifiers list cannot be empty"),
} }
}, },
"fail/acmeProvisionerFromContext-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), &acme.MockProvisioner{})
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, errors.New("force")
},
},
err: acme.NewErrorISE("error retrieving external account binding key: force"),
}
},
"fail/db.GetExternalAccountKeyByAccountID-error": func(t *testing.T) test {
acmeProv := newACMEProv(t)
acmeProv.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, errors.New("force")
},
},
err: acme.NewErrorISE("error retrieving external account binding key: force"),
}
},
"fail/newACMEPolicyEngine-error": func(t *testing.T) test {
acmeProv := newACMEProv(t)
acmeProv.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return &acme.ExternalAccountKey{
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"**.local"},
},
},
},
}, nil
},
},
err: acme.NewErrorISE("error creating ACME policy engine"),
}
},
"fail/isIdentifierAllowed-error": func(t *testing.T) test {
acmeProv := newACMEProv(t)
acmeProv.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), acmeProv)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return &acme.ExternalAccountKey{
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.local"},
},
},
},
}, nil
},
},
err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"),
}
},
"fail/prov.AuthorizeOrderIdentifier-error": func(t *testing.T) test {
options := &provisioner.Options{
X509: &provisioner.X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"*.local"},
},
},
}
provWithPolicy := newACMEProvWithOptions(t, options)
provWithPolicy.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
ca: &mockCA{},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return &acme.ExternalAccountKey{
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.internal"},
},
},
},
}, nil
},
},
err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"),
}
},
"fail/ca.AreSANsAllowed-error": func(t *testing.T) test {
options := &provisioner.Options{
X509: &provisioner.X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"*.internal"},
},
},
}
provWithPolicy := newACMEProvWithOptions(t, options)
provWithPolicy.RequireEAB = true
acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(fr)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{
ctx: ctx,
statusCode: 400,
ca: &mockCA{
MockAreSANsallowed: func(ctx context.Context, sans []string) error {
return errors.New("force: not authorized by authority")
},
},
db: &acme.MockDB{
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return &acme.ExternalAccountKey{
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.internal"},
},
},
},
}, nil
},
},
err: acme.NewError(acme.ErrorRejectedIdentifierType, "not authorized"),
}
},
"fail/error-h.newAuthorization": func(t *testing.T) test { "fail/error-h.newAuthorization": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
fr := &NewOrderRequest{ fr := &NewOrderRequest{
@ -765,12 +1004,13 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{ db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
assert.Equals(t, ch.AccountID, "accID") assert.Equals(t, ch.AccountID, "accID")
@ -780,6 +1020,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, ch.Value, "zap.internal") assert.Equals(t, ch.Value, "zap.internal")
return errors.New("force") return errors.New("force")
}, },
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
}, },
err: acme.NewErrorISE("error creating challenge: force"), err: acme.NewErrorISE("error creating challenge: force"),
} }
@ -793,7 +1038,7 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
var ( var (
@ -804,6 +1049,7 @@ func TestHandler_NewOrder(t *testing.T) {
return test{ return test{
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
ca: &mockCA{},
db: &acme.MockDB{ db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count { switch count {
@ -849,6 +1095,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return errors.New("force") return errors.New("force")
}, },
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
}, },
err: acme.NewErrorISE("error creating order: force"), err: acme.NewErrorISE("error creating order: force"),
} }
@ -863,10 +1114,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3, ch4 **acme.Challenge ch1, ch2, ch3, ch4 **acme.Challenge
az1ID, az2ID *string az1ID, az2ID *string
@ -876,6 +1126,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx, ctx: ctx,
statusCode: 201, statusCode: 201,
nor: nor, nor: nor,
ca: &mockCA{},
db: &acme.MockDB{ db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch chCount { switch chCount {
@ -945,6 +1196,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID, *az2ID}) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID, *az2ID})
return nil return nil
}, },
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
}, },
vr: func(t *testing.T, o *acme.Order) { vr: func(t *testing.T, o *acme.Order) {
now := clock.Now() now := clock.Now()
@ -978,10 +1234,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3 **acme.Challenge ch1, ch2, ch3 **acme.Challenge
az1ID *string az1ID *string
@ -991,6 +1246,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx, ctx: ctx,
statusCode: 201, statusCode: 201,
nor: nor, nor: nor,
ca: &mockCA{},
db: &acme.MockDB{ db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count { switch count {
@ -1037,6 +1293,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil return nil
}, },
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
}, },
vr: func(t *testing.T, o *acme.Order) { vr: func(t *testing.T, o *acme.Order) {
now := clock.Now() now := clock.Now()
@ -1070,10 +1331,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3 **acme.Challenge ch1, ch2, ch3 **acme.Challenge
az1ID *string az1ID *string
@ -1083,6 +1343,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx, ctx: ctx,
statusCode: 201, statusCode: 201,
nor: nor, nor: nor,
ca: &mockCA{},
db: &acme.MockDB{ db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count { switch count {
@ -1129,6 +1390,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil return nil
}, },
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
}, },
vr: func(t *testing.T, o *acme.Order) { vr: func(t *testing.T, o *acme.Order) {
now := clock.Now() now := clock.Now()
@ -1161,10 +1427,9 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3 **acme.Challenge ch1, ch2, ch3 **acme.Challenge
az1ID *string az1ID *string
@ -1174,6 +1439,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx, ctx: ctx,
statusCode: 201, statusCode: 201,
nor: nor, nor: nor,
ca: &mockCA{},
db: &acme.MockDB{ db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count { switch count {
@ -1220,6 +1486,11 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil return nil
}, },
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
}, },
vr: func(t *testing.T, o *acme.Order) { vr: func(t *testing.T, o *acme.Order) {
testBufferDur := 5 * time.Second testBufferDur := 5 * time.Second
@ -1253,10 +1524,109 @@ func TestHandler_NewOrder(t *testing.T) {
} }
b, err := json.Marshal(nor) b, err := json.Marshal(nor)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
var (
ch1, ch2, ch3 **acme.Challenge
az1ID *string
count = 0
)
return test{
ctx: ctx,
statusCode: 201,
nor: nor,
ca: &mockCA{},
db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count {
case 0:
ch.ID = "dns"
assert.Equals(t, ch.Type, acme.DNS01)
ch1 = &ch
case 1:
ch.ID = "http"
assert.Equals(t, ch.Type, acme.HTTP01)
ch2 = &ch
case 2:
ch.ID = "tls"
assert.Equals(t, ch.Type, acme.TLSALPN01)
ch3 = &ch
default:
assert.FatalError(t, errors.New("test logic error"))
return errors.New("force")
}
count++
assert.Equals(t, ch.AccountID, "accID")
assert.NotEquals(t, ch.Token, "")
assert.Equals(t, ch.Status, acme.StatusPending)
assert.Equals(t, ch.Value, "zap.internal")
return nil
},
MockCreateAuthorization: func(ctx context.Context, az *acme.Authorization) error {
az.ID = "az1ID"
az1ID = &az.ID
assert.Equals(t, az.AccountID, "accID")
assert.NotEquals(t, az.Token, "")
assert.Equals(t, az.Status, acme.StatusPending)
assert.Equals(t, az.Identifier, nor.Identifiers[0])
assert.Equals(t, az.Challenges, []*acme.Challenge{*ch1, *ch2, *ch3})
assert.Equals(t, az.Wildcard, false)
return nil
},
MockCreateOrder: func(ctx context.Context, o *acme.Order) error {
o.ID = "ordID"
assert.Equals(t, o.AccountID, "accID")
assert.Equals(t, o.ProvisionerID, prov.GetID())
assert.Equals(t, o.Status, acme.StatusPending)
assert.Equals(t, o.Identifiers, nor.Identifiers)
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil
},
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
},
vr: func(t *testing.T, o *acme.Order) {
testBufferDur := 5 * time.Second
orderExpiry := now.Add(defaultOrderExpiry)
assert.Equals(t, o.ID, "ordID")
assert.Equals(t, o.Status, acme.StatusPending)
assert.Equals(t, o.Identifiers, nor.Identifiers)
assert.Equals(t, o.AuthorizationURLs, []string{fmt.Sprintf("%s/acme/%s/authz/az1ID", baseURL.String(), escProvName)})
assert.True(t, o.NotBefore.Add(-testBufferDur).Before(expNbf))
assert.True(t, o.NotBefore.Add(testBufferDur).After(expNbf))
assert.True(t, o.NotAfter.Add(-testBufferDur).Before(expNaf))
assert.True(t, o.NotAfter.Add(testBufferDur).After(expNaf))
assert.True(t, o.ExpiresAt.Add(-testBufferDur).Before(orderExpiry))
assert.True(t, o.ExpiresAt.Add(testBufferDur).After(orderExpiry))
},
}
},
"ok/default-naf-nbf-with-policy": func(t *testing.T) test {
options := &provisioner.Options{
X509: &provisioner.X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"*.internal"},
},
},
}
provWithPolicy := newACMEProvWithOptions(t, options)
provWithPolicy.RequireEAB = true
acc := &acme.Account{ID: "accID"}
nor := &NewOrderRequest{
Identifiers: []acme.Identifier{
{Type: "dns", Value: "zap.internal"},
},
}
b, err := json.Marshal(nor)
assert.FatalError(t, err)
ctx := acme.NewProvisionerContext(context.Background(), provWithPolicy)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
var ( var (
ch1, ch2, ch3 **acme.Challenge ch1, ch2, ch3 **acme.Challenge
az1ID *string az1ID *string
@ -1266,6 +1636,7 @@ func TestHandler_NewOrder(t *testing.T) {
ctx: ctx, ctx: ctx,
statusCode: 201, statusCode: 201,
nor: nor, nor: nor,
ca: &mockCA{},
db: &acme.MockDB{ db: &acme.MockDB{
MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error { MockCreateChallenge: func(ctx context.Context, ch *acme.Challenge) error {
switch count { switch count {
@ -1312,10 +1683,18 @@ func TestHandler_NewOrder(t *testing.T) {
assert.Equals(t, o.AuthorizationIDs, []string{*az1ID}) assert.Equals(t, o.AuthorizationIDs, []string{*az1ID})
return nil return nil
}, },
MockGetExternalAccountKeyByAccountID: func(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, prov.GetID(), provisionerID)
assert.Equals(t, "accID", accountID)
return nil, nil
},
}, },
vr: func(t *testing.T, o *acme.Order) { vr: func(t *testing.T, o *acme.Order) {
now := clock.Now()
testBufferDur := 5 * time.Second testBufferDur := 5 * time.Second
orderExpiry := now.Add(defaultOrderExpiry) orderExpiry := now.Add(defaultOrderExpiry)
expNbf := now.Add(-defaultOrderBackdate)
expNaf := now.Add(prov.DefaultTLSCertDuration())
assert.Equals(t, o.ID, "ordID") assert.Equals(t, o.ID, "ordID")
assert.Equals(t, o.Status, acme.StatusPending) assert.Equals(t, o.Status, acme.StatusPending)
@ -1334,11 +1713,12 @@ func TestHandler_NewOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} mockMustAuthority(t, tc.ca)
ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.NewOrder(w, req) NewOrder(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1371,6 +1751,7 @@ func TestHandler_NewOrder(t *testing.T) {
} }
func TestHandler_FinalizeOrder(t *testing.T) { func TestHandler_FinalizeOrder(t *testing.T) {
mockMustAuthority(t, &mockCA{})
prov := newProv() prov := newProv()
escProvName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
@ -1429,15 +1810,17 @@ func TestHandler_FinalizeOrder(t *testing.T) {
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
return test{ return test{
ctx: context.WithValue(context.Background(), provisionerContextKey, prov), db: &acme.MockDB{},
ctx: acme.NewProvisionerContext(context.Background(), prov),
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"), err: acme.NewError(acme.ErrorAccountDoesNotExistType, "account does not exist"),
@ -1447,6 +1830,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -1454,9 +1838,10 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/nil-provisioner": func(t *testing.T) test { "fail/nil-provisioner": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, nil) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -1465,8 +1850,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), accContextKey, acc) ctx := context.WithValue(context.Background(), accContextKey, acc)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload does not exist"), err: acme.NewErrorISE("payload does not exist"),
@ -1474,21 +1860,23 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("paylod does not exist"), err: acme.NewErrorISE("payload does not exist"),
} }
}, },
"fail/unmarshal-payload-error": func(t *testing.T) test { "fail/unmarshal-payload-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accID"} acc := &acme.Account{ID: "accID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"), err: acme.NewError(acme.ErrorMalformedType, "failed to unmarshal finalize-order request payload: unexpected end of JSON input"),
@ -1499,10 +1887,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
fr := &FinalizeRequest{} fr := &FinalizeRequest{}
b, err := json.Marshal(fr) b, err := json.Marshal(fr)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: b})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"), err: acme.NewError(acme.ErrorMalformedType, "unable to parse csr: asn1: syntax error: sequence truncated"),
@ -1511,7 +1900,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
"fail/db.GetOrder-error": func(t *testing.T) test { "fail/db.GetOrder-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1526,7 +1915,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/account-id-mismatch": func(t *testing.T) test { "fail/account-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1543,7 +1932,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/provisioner-id-mismatch": func(t *testing.T) test { "fail/provisioner-id-mismatch": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1560,7 +1949,7 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"fail/order-finalize-error": func(t *testing.T) test { "fail/order-finalize-error": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
@ -1585,10 +1974,9 @@ func TestHandler_FinalizeOrder(t *testing.T) {
}, },
"ok": func(t *testing.T) test { "ok": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID"} acc := &acme.Account{ID: "accountID"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
return test{ return test{
db: &acme.MockDB{ db: &acme.MockDB{
@ -1624,11 +2012,11 @@ func TestHandler_FinalizeOrder(t *testing.T) {
for name, run := range tests { for name, run := range tests {
tc := run(t) tc := run(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db} ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
req := httptest.NewRequest("GET", u, nil) req := httptest.NewRequest("GET", u, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.FinalizeOrder(w, req) FinalizeOrder(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)

@ -26,9 +26,11 @@ type revokePayload struct {
} }
// RevokeCert attempts to revoke a certificate. // RevokeCert attempts to revoke a certificate.
func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) { func RevokeCert(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
db := acme.MustDatabaseFromContext(ctx)
linker := acme.MustLinkerFromContext(ctx)
jws, err := jwsFromContext(ctx) jws, err := jwsFromContext(ctx)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
@ -69,7 +71,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
} }
serial := certToBeRevoked.SerialNumber.String() serial := certToBeRevoked.SerialNumber.String()
dbCert, err := h.db.GetCertificateBySerial(ctx, serial) dbCert, err := db.GetCertificateBySerial(ctx, serial)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial")) render.Error(w, acme.WrapErrorISE(err, "error retrieving certificate by serial"))
return return
@ -87,7 +89,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
render.Error(w, err) render.Error(w, err)
return return
} }
acmeErr := h.isAccountAuthorized(ctx, dbCert, certToBeRevoked, account) acmeErr := isAccountAuthorized(ctx, dbCert, certToBeRevoked, account)
if acmeErr != nil { if acmeErr != nil {
render.Error(w, acmeErr) render.Error(w, acmeErr)
return return
@ -103,7 +105,8 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
} }
} }
hasBeenRevokedBefore, err := h.ca.IsRevoked(serial) ca := mustAuthority(ctx)
hasBeenRevokedBefore, err := ca.IsRevoked(serial)
if err != nil { if err != nil {
render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate")) render.Error(w, acme.WrapErrorISE(err, "error retrieving revocation status of certificate"))
return return
@ -130,14 +133,14 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
} }
options := revokeOptions(serial, certToBeRevoked, reasonCode) options := revokeOptions(serial, certToBeRevoked, reasonCode)
err = h.ca.Revoke(ctx, options) err = ca.Revoke(ctx, options)
if err != nil { if err != nil {
render.Error(w, wrapRevokeErr(err)) render.Error(w, wrapRevokeErr(err))
return return
} }
logRevoke(w, options) logRevoke(w, options)
w.Header().Add("Link", link(h.linker.GetLink(ctx, DirectoryLinkType), "index")) w.Header().Add("Link", link(linker.GetLink(ctx, acme.DirectoryLinkType), "index"))
w.Write(nil) w.Write(nil)
} }
@ -148,7 +151,7 @@ func (h *Handler) RevokeCert(w http.ResponseWriter, r *http.Request) {
// the identifiers in the certificate are extracted and compared against the (valid) Authorizations // the identifiers in the certificate are extracted and compared against the (valid) Authorizations
// that are stored for the ACME Account. If these sets match, the Account is considered authorized // that are stored for the ACME Account. If these sets match, the Account is considered authorized
// to revoke the certificate. If this check fails, the client will receive an unauthorized error. // to revoke the certificate. If this check fails, the client will receive an unauthorized error.
func (h *Handler) isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error { func isAccountAuthorized(ctx context.Context, dbCert *acme.Certificate, certToBeRevoked *x509.Certificate, account *acme.Account) *acme.Error {
if !account.IsValid() { if !account.IsValid() {
return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil) return wrapUnauthorizedError(certToBeRevoked, nil, fmt.Sprintf("account '%s' has status '%s'", account.ID, account.Status), nil)
} }

@ -24,14 +24,16 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/crypto/ocsp"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority" "github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil"
"go.step.sm/crypto/x509util"
"golang.org/x/crypto/ocsp"
) )
// v is a utility function to return the pointer to an integer // v is a utility function to return the pointer to an integer
@ -274,14 +276,22 @@ func jwsFinal(sha crypto.Hash, sig []byte, phead, payload string) ([]byte, error
} }
type mockCA struct { type mockCA struct {
MockIsRevoked func(sn string) (bool, error) MockIsRevoked func(sn string) (bool, error)
MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error MockRevoke func(ctx context.Context, opts *authority.RevokeOptions) error
MockAreSANsallowed func(ctx context.Context, sans []string) error
} }
func (m *mockCA) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { func (m *mockCA) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return nil, nil return nil, nil
} }
func (m *mockCA) AreSANsAllowed(ctx context.Context, sans []string) error {
if m.MockAreSANsallowed != nil {
return m.MockAreSANsallowed(ctx, sans)
}
return nil
}
func (m *mockCA) IsRevoked(sn string) (bool, error) { func (m *mockCA) IsRevoked(sn string) (bool, error) {
if m.MockIsRevoked != nil { if m.MockIsRevoked != nil {
return m.MockIsRevoked(sn) return m.MockIsRevoked(sn)
@ -511,6 +521,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/no-jws": func(t *testing.T) test { "fail/no-jws": func(t *testing.T) test {
ctx := context.Background() ctx := context.Background()
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -519,6 +530,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/nil-jws": func(t *testing.T) test { "fail/nil-jws": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, nil) ctx := context.WithValue(context.Background(), jwsContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("jws expected in request context"), err: acme.NewErrorISE("jws expected in request context"),
@ -527,6 +539,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/no-provisioner": func(t *testing.T) test { "fail/no-provisioner": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -534,8 +547,9 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/nil-provisioner": func(t *testing.T) test { "fail/nil-provisioner": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, nil) ctx = acme.NewProvisionerContext(ctx, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("provisioner does not exist"), err: acme.NewErrorISE("provisioner does not exist"),
@ -543,8 +557,9 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/no-payload": func(t *testing.T) test { "fail/no-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload does not exist"), err: acme.NewErrorISE("payload does not exist"),
@ -552,9 +567,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/nil-payload": func(t *testing.T) test { "fail/nil-payload": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, payloadContextKey, nil) ctx = context.WithValue(ctx, payloadContextKey, nil)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("payload does not exist"), err: acme.NewErrorISE("payload does not exist"),
@ -563,9 +579,10 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/unmarshal-payload": func(t *testing.T) test { "fail/unmarshal-payload": func(t *testing.T) test {
malformedPayload := []byte(`{"payload":malformed?}`) malformedPayload := []byte(`{"payload":malformed?}`)
ctx := context.WithValue(context.Background(), jwsContextKey, jws) ctx := context.WithValue(context.Background(), jwsContextKey, jws)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = acme.NewProvisionerContext(ctx, prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: malformedPayload})
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 500, statusCode: 500,
err: acme.NewErrorISE("error unmarshaling payload"), err: acme.NewErrorISE("error unmarshaling payload"),
@ -577,10 +594,11 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload) wronglyEncodedPayloadBytes, err := json.Marshal(wrongPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: wronglyEncodedPayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: &acme.Error{ err: &acme.Error{
@ -596,10 +614,11 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
emptyPayloadBytes, err := json.Marshal(emptyPayload) emptyPayloadBytes, err := json.Marshal(emptyPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: emptyPayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
return test{ return test{
db: &acme.MockDB{},
ctx: ctx, ctx: ctx,
statusCode: 400, statusCode: 400,
err: &acme.Error{ err: &acme.Error{
@ -610,7 +629,7 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
}, },
"fail/db.GetCertificateBySerial": func(t *testing.T) test { "fail/db.GetCertificateBySerial": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
db := &acme.MockDB{ db := &acme.MockDB{
@ -628,7 +647,7 @@ func TestHandler_RevokeCert(t *testing.T) {
"fail/different-certificate-contents": func(t *testing.T) test { "fail/different-certificate-contents": func(t *testing.T) test {
aDifferentCert, _, err := generateCertKeyPair() aDifferentCert, _, err := generateCertKeyPair()
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
db := &acme.MockDB{ db := &acme.MockDB{
@ -647,7 +666,7 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
}, },
"fail/no-account": func(t *testing.T) test { "fail/no-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
db := &acme.MockDB{ db := &acme.MockDB{
@ -666,7 +685,7 @@ func TestHandler_RevokeCert(t *testing.T) {
} }
}, },
"fail/nil-account": func(t *testing.T) test { "fail/nil-account": func(t *testing.T) test {
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, accContextKey, nil) ctx = context.WithValue(ctx, accContextKey, nil)
@ -687,11 +706,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/account-not-valid": func(t *testing.T) test { "fail/account-not-valid": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid} acc := &acme.Account{ID: "accountID", Status: acme.StatusInvalid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -717,11 +735,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/account-not-authorized": func(t *testing.T) test { "fail/account-not-authorized": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -771,10 +788,9 @@ func TestHandler_RevokeCert(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
unauthorizedPayloadBytes, err := json.Marshal(jwsPayload) unauthorizedPayloadBytes, err := json.Marshal(jwsPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: unauthorizedPayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, parsedJWS) ctx = context.WithValue(ctx, jwsContextKey, parsedJWS)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -798,11 +814,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/certificate-revoked-check-fails": func(t *testing.T) test { "fail/certificate-revoked-check-fails": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -832,7 +847,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/certificate-already-revoked": func(t *testing.T) test { "fail/certificate-already-revoked": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -870,7 +885,7 @@ func TestHandler_RevokeCert(t *testing.T) {
invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload) invalidReasonCodePayloadBytes, err := json.Marshal(invalidReasonPayload)
assert.FatalError(t, err) assert.FatalError(t, err)
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: invalidReasonCodePayloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -908,7 +923,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
} }
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, mockACMEProv) ctx := acme.NewProvisionerContext(context.Background(), mockACMEProv)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -940,7 +955,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/ca.Revoke": func(t *testing.T) test { "fail/ca.Revoke": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -972,7 +987,7 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"fail/ca.Revoke-already-revoked": func(t *testing.T) test { "fail/ca.Revoke-already-revoked": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
@ -1003,11 +1018,10 @@ func TestHandler_RevokeCert(t *testing.T) {
}, },
"ok/using-account-key": func(t *testing.T) test { "ok/using-account-key": func(t *testing.T) test {
acc := &acme.Account{ID: "accountID", Status: acme.StatusValid} acc := &acme.Account{ID: "accountID", Status: acme.StatusValid}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, accContextKey, acc) ctx = context.WithValue(ctx, accContextKey, acc)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -1031,10 +1045,9 @@ func TestHandler_RevokeCert(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
jws, err := jose.ParseJWS(string(jwsBytes)) jws, err := jose.ParseJWS(string(jwsBytes))
assert.FatalError(t, err) assert.FatalError(t, err)
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := acme.NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes}) ctx = context.WithValue(ctx, payloadContextKey, &payloadInfo{value: payloadBytes})
ctx = context.WithValue(ctx, jwsContextKey, jws) ctx = context.WithValue(ctx, jwsContextKey, jws)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL)
ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx) ctx = context.WithValue(ctx, chi.RouteCtxKey, chiCtx)
db := &acme.MockDB{ db := &acme.MockDB{
MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) { MockGetCertificateBySerial: func(ctx context.Context, serial string) (*acme.Certificate, error) {
@ -1057,11 +1070,12 @@ func TestHandler_RevokeCert(t *testing.T) {
for name, setup := range tests { for name, setup := range tests {
tc := setup(t) tc := setup(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{linker: NewLinker("dns", "acme"), db: tc.db, ca: tc.ca} ctx := newBaseContext(tc.ctx, tc.db, acme.NewLinker("test.ca.smallstep.com", "acme"))
mockMustAuthority(t, tc.ca)
req := httptest.NewRequest("POST", revokeURL, nil) req := httptest.NewRequest("POST", revokeURL, nil)
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.RevokeCert(w, req) RevokeCert(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, res.StatusCode, tc.statusCode) assert.Equals(t, res.StatusCode, tc.statusCode)
@ -1198,8 +1212,8 @@ func TestHandler_isAccountAuthorized(t *testing.T) {
for name, setup := range tests { for name, setup := range tests {
tc := setup(t) tc := setup(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{db: tc.db} // h := &Handler{db: tc.db}
acmeErr := h.isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account) acmeErr := isAccountAuthorized(tc.ctx, tc.existingCert, tc.certToBeRevoked, tc.account)
expectError := tc.err != nil expectError := tc.err != nil
gotError := acmeErr != nil gotError := acmeErr != nil

@ -14,7 +14,6 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http"
"net/url" "net/url"
"reflect" "reflect"
"strings" "strings"
@ -61,27 +60,28 @@ func (ch *Challenge) ToLog() (interface{}, error) {
// type using the DB interface. // type using the DB interface.
// satisfactorily validated, the 'status' and 'validated' attributes are // satisfactorily validated, the 'status' and 'validated' attributes are
// updated. // updated.
func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func (ch *Challenge) Validate(ctx context.Context, db DB, jwk *jose.JSONWebKey) error {
// If already valid or invalid then return without performing validation. // If already valid or invalid then return without performing validation.
if ch.Status != StatusPending { if ch.Status != StatusPending {
return nil return nil
} }
switch ch.Type { switch ch.Type {
case HTTP01: case HTTP01:
return http01Validate(ctx, ch, db, jwk, vo) return http01Validate(ctx, ch, db, jwk)
case DNS01: case DNS01:
return dns01Validate(ctx, ch, db, jwk, vo) return dns01Validate(ctx, ch, db, jwk)
case TLSALPN01: case TLSALPN01:
return tlsalpn01Validate(ctx, ch, db, jwk, vo) return tlsalpn01Validate(ctx, ch, db, jwk)
default: default:
return NewErrorISE("unexpected challenge type '%s'", ch.Type) return NewErrorISE("unexpected challenge type '%s'", ch.Type)
} }
} }
func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)} u := &url.URL{Scheme: "http", Host: http01ChallengeHost(ch.Value), Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
resp, err := vo.HTTPGet(u.String()) vc := MustClientFromContext(ctx)
resp, err := vc.Get(u.String())
if err != nil { if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err, return storeError(ctx, db, ch, false, WrapError(ErrorConnectionType, err,
"error doing http GET for url %s", u)) "error doing http GET for url %s", u))
@ -119,6 +119,17 @@ func http01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWeb
return nil return nil
} }
// http01ChallengeHost checks if a Challenge value is an IPv6 address
// and adds square brackets if that's the case, so that it can be used
// as a hostname. Returns the original Challenge value as the host to
// use in other cases.
func http01ChallengeHost(value string) string {
if ip := net.ParseIP(value); ip != nil && ip.To4() == nil {
value = "[" + value + "]"
}
return value
}
func tlsAlert(err error) uint8 { func tlsAlert(err error) uint8 {
var opErr *net.OpError var opErr *net.OpError
if errors.As(err, &opErr) { if errors.As(err, &opErr) {
@ -130,7 +141,7 @@ func tlsAlert(err error) uint8 {
return 0 return 0
} }
func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
config := &tls.Config{ config := &tls.Config{
NextProtos: []string{"acme-tls/1"}, NextProtos: []string{"acme-tls/1"},
// https://tools.ietf.org/html/rfc8737#section-4 // https://tools.ietf.org/html/rfc8737#section-4
@ -143,7 +154,8 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
hostPort := net.JoinHostPort(ch.Value, "443") hostPort := net.JoinHostPort(ch.Value, "443")
conn, err := vo.TLSDial("tcp", hostPort, config) vc := MustClientFromContext(ctx)
conn, err := vc.TLSDial("tcp", hostPort, config)
if err != nil { if err != nil {
// With Go 1.17+ tls.Dial fails if there's no overlap between configured // With Go 1.17+ tls.Dial fails if there's no overlap between configured
// client and server protocols. When this happens the connection is // client and server protocols. When this happens the connection is
@ -242,14 +254,15 @@ func tlsalpn01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSON
"incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension")) "incorrect certificate for tls-alpn-01 challenge: missing acmeValidationV1 extension"))
} }
func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey, vo *ValidateChallengeOptions) error { func dns01Validate(ctx context.Context, ch *Challenge, db DB, jwk *jose.JSONWebKey) error {
// Normalize domain for wildcard DNS names // Normalize domain for wildcard DNS names
// This is done to avoid making TXT lookups for domains like // This is done to avoid making TXT lookups for domains like
// _acme-challenge.*.example.com // _acme-challenge.*.example.com
// Instead perform txt lookup for _acme-challenge.example.com // Instead perform txt lookup for _acme-challenge.example.com
domain := strings.TrimPrefix(ch.Value, "*.") domain := strings.TrimPrefix(ch.Value, "*.")
txtRecords, err := vo.LookupTxt("_acme-challenge." + domain) vc := MustClientFromContext(ctx)
txtRecords, err := vc.LookupTxt("_acme-challenge." + domain)
if err != nil { if err != nil {
return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err, return storeError(ctx, db, ch, false, WrapError(ErrorDNSType, err,
"error looking up TXT records for domain %s", domain)) "error looking up TXT records for domain %s", domain))
@ -365,14 +378,3 @@ func storeError(ctx context.Context, db DB, ch *Challenge, markInvalid bool, err
} }
return nil return nil
} }
type httpGetter func(string) (*http.Response, error)
type lookupTxt func(string) ([]string, error)
type tlsDialer func(network, addr string, config *tls.Config) (*tls.Conn, error)
// ValidateChallengeOptions are ACME challenge validator functions.
type ValidateChallengeOptions struct {
HTTPGet httpGetter
LookupTxt lookupTxt
TLSDial tlsDialer
}

@ -13,6 +13,7 @@ import (
"encoding/asn1" "encoding/asn1"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"io" "io"
"math/big" "math/big"
@ -23,11 +24,23 @@ import (
"testing" "testing"
"time" "time"
"github.com/pkg/errors"
"github.com/smallstep/assert"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"github.com/smallstep/assert"
) )
type mockClient struct {
get func(url string) (*http.Response, error)
lookupTxt func(name string) ([]string, error)
tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error)
}
func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) }
func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) }
func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
return m.tlsDial(network, addr, config)
}
func Test_storeError(t *testing.T) { func Test_storeError(t *testing.T) {
type test struct { type test struct {
ch *Challenge ch *Challenge
@ -228,7 +241,7 @@ func TestKeyAuthorization(t *testing.T) {
func TestChallenge_Validate(t *testing.T) { func TestChallenge_Validate(t *testing.T) {
type test struct { type test struct {
ch *Challenge ch *Challenge
vo *ValidateChallengeOptions vc Client
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
db DB db DB
srv *httptest.Server srv *httptest.Server
@ -272,8 +285,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -308,8 +321,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -343,8 +356,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -380,8 +393,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -415,8 +428,8 @@ func TestChallenge_Validate(t *testing.T) {
} }
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -465,8 +478,8 @@ func TestChallenge_Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -492,7 +505,8 @@ func TestChallenge_Validate(t *testing.T) {
defer tc.srv.Close() defer tc.srv.Close()
} }
if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil { ctx := NewClientContext(context.Background(), tc.vc)
if err := tc.ch.Validate(ctx, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *Error: case *Error:
@ -523,7 +537,7 @@ func (errReader) Close() error {
func TestHTTP01Validate(t *testing.T) { func TestHTTP01Validate(t *testing.T) {
type test struct { type test struct {
vo *ValidateChallengeOptions vc Client
ch *Challenge ch *Challenge
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
db DB db DB
@ -540,8 +554,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -574,8 +588,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -607,8 +621,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
StatusCode: http.StatusBadRequest, StatusCode: http.StatusBadRequest,
Body: errReader(0), Body: errReader(0),
@ -644,8 +658,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
StatusCode: http.StatusBadRequest, StatusCode: http.StatusBadRequest,
Body: errReader(0), Body: errReader(0),
@ -680,8 +694,8 @@ func TestHTTP01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: errReader(0), Body: errReader(0),
}, nil }, nil
@ -703,8 +717,8 @@ func TestHTTP01Validate(t *testing.T) {
jwk.Key = "foo" jwk.Key = "foo"
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
@ -729,8 +743,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
@ -771,8 +785,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString("foo")), Body: io.NopCloser(bytes.NewBufferString("foo")),
}, nil }, nil
@ -814,8 +828,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil }, nil
@ -856,8 +870,8 @@ func TestHTTP01Validate(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
HTTPGet: func(url string) (*http.Response, error) { get: func(url string) (*http.Response, error) {
return &http.Response{ return &http.Response{
Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)),
}, nil }, nil
@ -886,7 +900,8 @@ func TestHTTP01Validate(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { ctx := NewClientContext(context.Background(), tc.vc)
if err := http01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *Error: case *Error:
@ -910,7 +925,7 @@ func TestDNS01Validate(t *testing.T) {
fulldomain := "*.zap.internal" fulldomain := "*.zap.internal"
domain := strings.TrimPrefix(fulldomain, "*.") domain := strings.TrimPrefix(fulldomain, "*.")
type test struct { type test struct {
vo *ValidateChallengeOptions vc Client
ch *Challenge ch *Challenge
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
db DB db DB
@ -927,8 +942,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -962,8 +977,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -1000,8 +1015,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo"}, nil return []string{"foo"}, nil
}, },
}, },
@ -1025,8 +1040,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo", "bar"}, nil return []string{"foo", "bar"}, nil
}, },
}, },
@ -1067,8 +1082,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo", "bar"}, nil return []string{"foo", "bar"}, nil
}, },
}, },
@ -1110,8 +1125,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo", expected}, nil return []string{"foo", expected}, nil
}, },
}, },
@ -1155,8 +1170,8 @@ func TestDNS01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
LookupTxt: func(url string) ([]string, error) { lookupTxt: func(url string) ([]string, error) {
return []string{"foo", expected}, nil return []string{"foo", expected}, nil
}, },
}, },
@ -1185,7 +1200,8 @@ func TestDNS01Validate(t *testing.T) {
for name, run := range tests { for name, run := range tests {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
tc := run(t) tc := run(t)
if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { ctx := NewClientContext(context.Background(), tc.vc)
if err := dns01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *Error: case *Error:
@ -1205,6 +1221,8 @@ func TestDNS01Validate(t *testing.T) {
} }
} }
type tlsDialer func(network, addr string, config *tls.Config) (conn *tls.Conn, err error)
func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) { func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) {
srv := httptest.NewUnstartedServer(http.NewServeMux()) srv := httptest.NewUnstartedServer(http.NewServeMux())
@ -1308,7 +1326,7 @@ func TestTLSALPN01Validate(t *testing.T) {
} }
} }
type test struct { type test struct {
vo *ValidateChallengeOptions vc Client
ch *Challenge ch *Challenge
jwk *jose.JSONWebKey jwk *jose.JSONWebKey
db DB db DB
@ -1320,8 +1338,8 @@ func TestTLSALPN01Validate(t *testing.T) {
ch := makeTLSCh() ch := makeTLSCh()
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -1350,8 +1368,8 @@ func TestTLSALPN01Validate(t *testing.T) {
ch := makeTLSCh() ch := makeTLSCh()
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return nil, errors.New("force") return nil, errors.New("force")
}, },
}, },
@ -1383,8 +1401,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1412,8 +1430,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.Client(&noopConn{}, config), nil return tls.Client(&noopConn{}, config), nil
}, },
}, },
@ -1442,8 +1460,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.Client(&noopConn{}, config), nil return tls.Client(&noopConn{}, config), nil
}, },
}, },
@ -1478,8 +1496,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
}, },
}, },
@ -1515,8 +1533,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config)
}, },
}, },
@ -1561,8 +1579,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1604,8 +1622,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1648,8 +1666,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1691,8 +1709,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1735,8 +1753,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
srv: srv, srv: srv,
jwk: jwk, jwk: jwk,
@ -1757,8 +1775,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1796,8 +1814,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1840,8 +1858,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1883,8 +1901,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1923,8 +1941,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -1962,8 +1980,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2007,8 +2025,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2053,8 +2071,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2099,8 +2117,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2143,8 +2161,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2188,8 +2206,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2225,8 +2243,8 @@ func TestTLSALPN01Validate(t *testing.T) {
return test{ return test{
ch: ch, ch: ch,
vo: &ValidateChallengeOptions{ vc: &mockClient{
TLSDial: tlsDial, tlsDial: tlsDial,
}, },
db: &MockDB{ db: &MockDB{
MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error {
@ -2252,7 +2270,8 @@ func TestTLSALPN01Validate(t *testing.T) {
defer tc.srv.Close() defer tc.srv.Close()
} }
if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { ctx := NewClientContext(context.Background(), tc.vc)
if err := tlsalpn01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil {
if assert.NotNil(t, tc.err) { if assert.NotNil(t, tc.err) {
switch k := err.(type) { switch k := err.(type) {
case *Error: case *Error:
@ -2350,3 +2369,34 @@ func Test_serverName(t *testing.T) {
}) })
} }
} }
func Test_http01ChallengeHost(t *testing.T) {
tests := []struct {
name string
value string
want string
}{
{
name: "dns",
value: "www.example.com",
want: "www.example.com",
},
{
name: "ipv4",
value: "127.0.0.1",
want: "127.0.0.1",
},
{
name: "ipv6",
value: "::1",
want: "[::1]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := http01ChallengeHost(tt.value); got != tt.want {
t.Errorf("http01ChallengeHost() = %v, want %v", got, tt.want)
}
})
}
}

@ -0,0 +1,79 @@
package acme
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
)
// Client is the interface used to verify ACME challenges.
type Client interface {
// Get issues an HTTP GET to the specified URL.
Get(url string) (*http.Response, error)
// LookupTXT returns the DNS TXT records for the given domain name.
LookupTxt(name string) ([]string, error)
// TLSDial connects to the given network address using net.Dialer and then
// initiates a TLS handshake, returning the resulting TLS connection.
TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error)
}
type clientKey struct{}
// NewClientContext adds the given client to the context.
func NewClientContext(ctx context.Context, c Client) context.Context {
return context.WithValue(ctx, clientKey{}, c)
}
// ClientFromContext returns the current client from the given context.
func ClientFromContext(ctx context.Context) (c Client, ok bool) {
c, ok = ctx.Value(clientKey{}).(Client)
return
}
// MustClientFromContext returns the current client from the given context. It will
// return a new instance of the client if it does not exist.
func MustClientFromContext(ctx context.Context) Client {
c, ok := ClientFromContext(ctx)
if !ok {
return NewClient()
}
return c
}
type client struct {
http *http.Client
dialer *net.Dialer
}
// NewClient returns an implementation of Client for verifying ACME challenges.
func NewClient() Client {
return &client{
http: &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
},
dialer: &net.Dialer{
Timeout: 30 * time.Second,
},
}
}
func (c *client) Get(url string) (*http.Response, error) {
return c.http.Get(url)
}
func (c *client) LookupTxt(name string) ([]string, error) {
return net.LookupTXT(name)
}
func (c *client) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.DialWithDialer(c.dialer, network, addr, config)
}

@ -9,27 +9,66 @@ import (
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
) )
// Clock that returns time in UTC rounded to seconds.
type Clock struct{}
// Now returns the UTC time rounded to seconds.
func (c *Clock) Now() time.Time {
return time.Now().UTC().Truncate(time.Second)
}
var clock Clock
// CertificateAuthority is the interface implemented by a CA authority. // CertificateAuthority is the interface implemented by a CA authority.
type CertificateAuthority interface { type CertificateAuthority interface {
Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) Sign(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
AreSANsAllowed(ctx context.Context, sans []string) error
IsRevoked(sn string) (bool, error) IsRevoked(sn string) (bool, error)
Revoke(context.Context, *authority.RevokeOptions) error Revoke(context.Context, *authority.RevokeOptions) error
LoadProvisionerByName(string) (provisioner.Interface, error) LoadProvisionerByName(string) (provisioner.Interface, error)
} }
// Clock that returns time in UTC rounded to seconds. // NewContext adds the given acme components to the context.
type Clock struct{} func NewContext(ctx context.Context, db DB, client Client, linker Linker, fn PrerequisitesChecker) context.Context {
ctx = NewDatabaseContext(ctx, db)
ctx = NewClientContext(ctx, client)
ctx = NewLinkerContext(ctx, linker)
// Prerequisite checker is optional.
if fn != nil {
ctx = NewPrerequisitesCheckerContext(ctx, fn)
}
return ctx
}
// Now returns the UTC time rounded to seconds. // PrerequisitesChecker is a function that checks if all prerequisites for
func (c *Clock) Now() time.Time { // serving ACME are met by the CA configuration.
return time.Now().UTC().Truncate(time.Second) type PrerequisitesChecker func(ctx context.Context) (bool, error)
// DefaultPrerequisitesChecker is the default PrerequisiteChecker and returns
// always true.
func DefaultPrerequisitesChecker(ctx context.Context) (bool, error) {
return true, nil
} }
var clock Clock type prerequisitesKey struct{}
// NewPrerequisitesCheckerContext adds the given PrerequisitesChecker to the
// context.
func NewPrerequisitesCheckerContext(ctx context.Context, fn PrerequisitesChecker) context.Context {
return context.WithValue(ctx, prerequisitesKey{}, fn)
}
// PrerequisitesCheckerFromContext returns the PrerequisitesChecker in the
// context.
func PrerequisitesCheckerFromContext(ctx context.Context) (PrerequisitesChecker, bool) {
fn, ok := ctx.Value(prerequisitesKey{}).(PrerequisitesChecker)
return fn, ok && fn != nil
}
// Provisioner is an interface that implements a subset of the provisioner.Interface -- // Provisioner is an interface that implements a subset of the provisioner.Interface --
// only those methods required by the ACME api/authority. // only those methods required by the ACME api/authority.
type Provisioner interface { type Provisioner interface {
AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error
AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error) AuthorizeSign(ctx context.Context, token string) ([]provisioner.SignOption, error)
AuthorizeRevoke(ctx context.Context, token string) error AuthorizeRevoke(ctx context.Context, token string) error
GetID() string GetID() string
@ -38,16 +77,40 @@ type Provisioner interface {
GetOptions() *provisioner.Options GetOptions() *provisioner.Options
} }
type provisionerKey struct{}
// NewProvisionerContext adds the given provisioner to the context.
func NewProvisionerContext(ctx context.Context, v Provisioner) context.Context {
return context.WithValue(ctx, provisionerKey{}, v)
}
// ProvisionerFromContext returns the current provisioner from the given context.
func ProvisionerFromContext(ctx context.Context) (v Provisioner, ok bool) {
v, ok = ctx.Value(provisionerKey{}).(Provisioner)
return
}
// MustLinkerFromContext returns the current provisioner from the given context.
// It will panic if it's not in the context.
func MustProvisionerFromContext(ctx context.Context) Provisioner {
if v, ok := ProvisionerFromContext(ctx); !ok {
panic("acme provisioner is not the context")
} else {
return v
}
}
// MockProvisioner for testing // MockProvisioner for testing
type MockProvisioner struct { type MockProvisioner struct {
Mret1 interface{} Mret1 interface{}
Merr error Merr error
MgetID func() string MgetID func() string
MgetName func() string MgetName func() string
MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error) MauthorizeOrderIdentifier func(ctx context.Context, identifier provisioner.ACMEIdentifier) error
MauthorizeRevoke func(ctx context.Context, token string) error MauthorizeSign func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
MdefaultTLSCertDuration func() time.Duration MauthorizeRevoke func(ctx context.Context, token string) error
MgetOptions func() *provisioner.Options MdefaultTLSCertDuration func() time.Duration
MgetOptions func() *provisioner.Options
} }
// GetName mock // GetName mock
@ -58,6 +121,14 @@ func (m *MockProvisioner) GetName() string {
return m.Mret1.(string) return m.Mret1.(string)
} }
// AuthorizeOrderIdentifiers mock
func (m *MockProvisioner) AuthorizeOrderIdentifier(ctx context.Context, identifier provisioner.ACMEIdentifier) error {
if m.MauthorizeOrderIdentifier != nil {
return m.MauthorizeOrderIdentifier(ctx, identifier)
}
return m.Merr
}
// AuthorizeSign mock // AuthorizeSign mock
func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) { func (m *MockProvisioner) AuthorizeSign(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
if m.MauthorizeSign != nil { if m.MauthorizeSign != nil {

@ -23,6 +23,7 @@ type DB interface {
GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error) GetExternalAccountKey(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error)
GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error)
GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) GetExternalAccountKeyByReference(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error)
GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error)
DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error
UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error
@ -48,6 +49,29 @@ type DB interface {
UpdateOrder(ctx context.Context, o *Order) error UpdateOrder(ctx context.Context, o *Order) error
} }
type dbKey struct{}
// NewDatabaseContext adds the given acme database to the context.
func NewDatabaseContext(ctx context.Context, db DB) context.Context {
return context.WithValue(ctx, dbKey{}, db)
}
// DatabaseFromContext returns the current acme database from the given context.
func DatabaseFromContext(ctx context.Context) (db DB, ok bool) {
db, ok = ctx.Value(dbKey{}).(DB)
return
}
// MustDatabaseFromContext returns the current database from the given context.
// It will panic if it's not in the context.
func MustDatabaseFromContext(ctx context.Context) DB {
if db, ok := DatabaseFromContext(ctx); !ok {
panic("acme database is not in the context")
} else {
return db
}
}
// MockDB is an implementation of the DB interface that should only be used as // MockDB is an implementation of the DB interface that should only be used as
// a mock in tests. // a mock in tests.
type MockDB struct { type MockDB struct {
@ -60,6 +84,7 @@ type MockDB struct {
MockGetExternalAccountKey func(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error) MockGetExternalAccountKey func(ctx context.Context, provisionerID, keyID string) (*ExternalAccountKey, error)
MockGetExternalAccountKeys func(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error) MockGetExternalAccountKeys func(ctx context.Context, provisionerID, cursor string, limit int) ([]*ExternalAccountKey, string, error)
MockGetExternalAccountKeyByReference func(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error) MockGetExternalAccountKeyByReference func(ctx context.Context, provisionerID, reference string) (*ExternalAccountKey, error)
MockGetExternalAccountKeyByAccountID func(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error)
MockDeleteExternalAccountKey func(ctx context.Context, provisionerID, keyID string) error MockDeleteExternalAccountKey func(ctx context.Context, provisionerID, keyID string) error
MockUpdateExternalAccountKey func(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error MockUpdateExternalAccountKey func(ctx context.Context, provisionerID string, eak *ExternalAccountKey) error
@ -168,6 +193,16 @@ func (m *MockDB) GetExternalAccountKeyByReference(ctx context.Context, provision
return m.MockRet1.(*ExternalAccountKey), m.MockError return m.MockRet1.(*ExternalAccountKey), m.MockError
} }
// GetExternalAccountKeyByAccountID mock
func (m *MockDB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*ExternalAccountKey, error) {
if m.MockGetExternalAccountKeyByAccountID != nil {
return m.MockGetExternalAccountKeyByAccountID(ctx, provisionerID, accountID)
} else if m.MockError != nil {
return nil, m.MockError
}
return m.MockRet1.(*ExternalAccountKey), m.MockError
}
// DeleteExternalAccountKey mock // DeleteExternalAccountKey mock
func (m *MockDB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error { func (m *MockDB) DeleteExternalAccountKey(ctx context.Context, provisionerID, keyID string) error {
if m.MockDeleteExternalAccountKey != nil { if m.MockDeleteExternalAccountKey != nil {

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
nosqlDB "github.com/smallstep/nosql" nosqlDB "github.com/smallstep/nosql"
) )
@ -23,7 +24,7 @@ type dbExternalAccountKey struct {
ProvisionerID string `json:"provisionerID"` ProvisionerID string `json:"provisionerID"`
Reference string `json:"reference"` Reference string `json:"reference"`
AccountID string `json:"accountID,omitempty"` AccountID string `json:"accountID,omitempty"`
KeyBytes []byte `json:"key"` HmacKey []byte `json:"key"`
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
BoundAt time.Time `json:"boundAt"` BoundAt time.Time `json:"boundAt"`
} }
@ -72,7 +73,7 @@ func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, refer
ID: keyID, ID: keyID,
ProvisionerID: provisionerID, ProvisionerID: provisionerID,
Reference: reference, Reference: reference,
KeyBytes: random, HmacKey: random,
CreatedAt: clock.Now(), CreatedAt: clock.Now(),
} }
@ -99,7 +100,7 @@ func (db *DB) CreateExternalAccountKey(ctx context.Context, provisionerID, refer
ProvisionerID: dbeak.ProvisionerID, ProvisionerID: dbeak.ProvisionerID,
Reference: dbeak.Reference, Reference: dbeak.Reference,
AccountID: dbeak.AccountID, AccountID: dbeak.AccountID,
KeyBytes: dbeak.KeyBytes, HmacKey: dbeak.HmacKey,
CreatedAt: dbeak.CreatedAt, CreatedAt: dbeak.CreatedAt,
BoundAt: dbeak.BoundAt, BoundAt: dbeak.BoundAt,
}, nil }, nil
@ -124,7 +125,7 @@ func (db *DB) GetExternalAccountKey(ctx context.Context, provisionerID, keyID st
ProvisionerID: dbeak.ProvisionerID, ProvisionerID: dbeak.ProvisionerID,
Reference: dbeak.Reference, Reference: dbeak.Reference,
AccountID: dbeak.AccountID, AccountID: dbeak.AccountID,
KeyBytes: dbeak.KeyBytes, HmacKey: dbeak.HmacKey,
CreatedAt: dbeak.CreatedAt, CreatedAt: dbeak.CreatedAt,
BoundAt: dbeak.BoundAt, BoundAt: dbeak.BoundAt,
}, nil }, nil
@ -191,7 +192,7 @@ func (db *DB) GetExternalAccountKeys(ctx context.Context, provisionerID, cursor
} }
keys = append(keys, &acme.ExternalAccountKey{ keys = append(keys, &acme.ExternalAccountKey{
ID: eak.ID, ID: eak.ID,
KeyBytes: eak.KeyBytes, HmacKey: eak.HmacKey,
ProvisionerID: eak.ProvisionerID, ProvisionerID: eak.ProvisionerID,
Reference: eak.Reference, Reference: eak.Reference,
AccountID: eak.AccountID, AccountID: eak.AccountID,
@ -226,6 +227,10 @@ func (db *DB) GetExternalAccountKeyByReference(ctx context.Context, provisionerI
return db.GetExternalAccountKey(ctx, provisionerID, dbExternalAccountKeyReference.ExternalAccountKeyID) return db.GetExternalAccountKey(ctx, provisionerID, dbExternalAccountKeyReference.ExternalAccountKeyID)
} }
func (db *DB) GetExternalAccountKeyByAccountID(ctx context.Context, provisionerID, accountID string) (*acme.ExternalAccountKey, error) {
return nil, nil
}
func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error { func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string, eak *acme.ExternalAccountKey) error {
externalAccountKeyMutex.Lock() externalAccountKeyMutex.Lock()
defer externalAccountKeyMutex.Unlock() defer externalAccountKeyMutex.Unlock()
@ -252,7 +257,7 @@ func (db *DB) UpdateExternalAccountKey(ctx context.Context, provisionerID string
ProvisionerID: eak.ProvisionerID, ProvisionerID: eak.ProvisionerID,
Reference: eak.Reference, Reference: eak.Reference,
AccountID: eak.AccountID, AccountID: eak.AccountID,
KeyBytes: eak.KeyBytes, HmacKey: eak.HmacKey,
CreatedAt: eak.CreatedAt, CreatedAt: eak.CreatedAt,
BoundAt: eak.BoundAt, BoundAt: eak.BoundAt,
} }

@ -8,6 +8,7 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
certdb "github.com/smallstep/certificates/db" certdb "github.com/smallstep/certificates/db"
@ -32,7 +33,7 @@ func TestDB_getDBExternalAccountKey(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: "ref", Reference: "ref",
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
b, err := json.Marshal(dbeak) b, err := json.Marshal(dbeak)
@ -108,7 +109,7 @@ func TestDB_getDBExternalAccountKey(t *testing.T) {
} }
} else if assert.Nil(t, tc.err) { } else if assert.Nil(t, tc.err) {
assert.Equals(t, dbeak.ID, tc.dbeak.ID) assert.Equals(t, dbeak.ID, tc.dbeak.ID)
assert.Equals(t, dbeak.KeyBytes, tc.dbeak.KeyBytes) assert.Equals(t, dbeak.HmacKey, tc.dbeak.HmacKey)
assert.Equals(t, dbeak.ProvisionerID, tc.dbeak.ProvisionerID) assert.Equals(t, dbeak.ProvisionerID, tc.dbeak.ProvisionerID)
assert.Equals(t, dbeak.Reference, tc.dbeak.Reference) assert.Equals(t, dbeak.Reference, tc.dbeak.Reference)
assert.Equals(t, dbeak.CreatedAt, tc.dbeak.CreatedAt) assert.Equals(t, dbeak.CreatedAt, tc.dbeak.CreatedAt)
@ -136,7 +137,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: "ref", Reference: "ref",
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
b, err := json.Marshal(dbeak) b, err := json.Marshal(dbeak)
@ -154,7 +155,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: "ref", Reference: "ref",
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
}, },
} }
@ -179,7 +180,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
ProvisionerID: "aDifferentProvID", ProvisionerID: "aDifferentProvID",
Reference: "ref", Reference: "ref",
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
b, err := json.Marshal(dbeak) b, err := json.Marshal(dbeak)
@ -197,7 +198,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: "ref", Reference: "ref",
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
}, },
acmeErr: acme.NewError(acme.ErrorUnauthorizedType, "provisioner does not match provisioner for which the EAB key was created"), acmeErr: acme.NewError(acme.ErrorUnauthorizedType, "provisioner does not match provisioner for which the EAB key was created"),
@ -225,7 +226,7 @@ func TestDB_GetExternalAccountKey(t *testing.T) {
} }
} else if assert.Nil(t, tc.err) { } else if assert.Nil(t, tc.err) {
assert.Equals(t, eak.ID, tc.eak.ID) assert.Equals(t, eak.ID, tc.eak.ID)
assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes) assert.Equals(t, eak.HmacKey, tc.eak.HmacKey)
assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID) assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID)
assert.Equals(t, eak.Reference, tc.eak.Reference) assert.Equals(t, eak.Reference, tc.eak.Reference)
assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt) assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt)
@ -255,7 +256,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
dbref := &dbExternalAccountKeyReference{ dbref := &dbExternalAccountKeyReference{
@ -288,7 +289,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
}, },
err: nil, err: nil,
@ -392,7 +393,7 @@ func TestDB_GetExternalAccountKeyByReference(t *testing.T) {
assert.Equals(t, eak.AccountID, tc.eak.AccountID) assert.Equals(t, eak.AccountID, tc.eak.AccountID)
assert.Equals(t, eak.BoundAt, tc.eak.BoundAt) assert.Equals(t, eak.BoundAt, tc.eak.BoundAt)
assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt) assert.Equals(t, eak.CreatedAt, tc.eak.CreatedAt)
assert.Equals(t, eak.KeyBytes, tc.eak.KeyBytes) assert.Equals(t, eak.HmacKey, tc.eak.HmacKey)
assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID) assert.Equals(t, eak.ProvisionerID, tc.eak.ProvisionerID)
assert.Equals(t, eak.Reference, tc.eak.Reference) assert.Equals(t, eak.Reference, tc.eak.Reference)
} }
@ -420,7 +421,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
b1, err := json.Marshal(dbeak1) b1, err := json.Marshal(dbeak1)
@ -430,7 +431,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
b2, err := json.Marshal(dbeak2) b2, err := json.Marshal(dbeak2)
@ -440,7 +441,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: "aDifferentProvID", ProvisionerID: "aDifferentProvID",
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
b3, err := json.Marshal(dbeak3) b3, err := json.Marshal(dbeak3)
@ -513,7 +514,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
}, },
{ {
@ -521,7 +522,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
}, },
}, },
@ -598,7 +599,7 @@ func TestDB_GetExternalAccountKeys(t *testing.T) {
assert.Equals(t, "", nextCursor) assert.Equals(t, "", nextCursor)
for i, eak := range eaks { for i, eak := range eaks {
assert.Equals(t, eak.ID, tc.eaks[i].ID) assert.Equals(t, eak.ID, tc.eaks[i].ID)
assert.Equals(t, eak.KeyBytes, tc.eaks[i].KeyBytes) assert.Equals(t, eak.HmacKey, tc.eaks[i].HmacKey)
assert.Equals(t, eak.ProvisionerID, tc.eaks[i].ProvisionerID) assert.Equals(t, eak.ProvisionerID, tc.eaks[i].ProvisionerID)
assert.Equals(t, eak.Reference, tc.eaks[i].Reference) assert.Equals(t, eak.Reference, tc.eaks[i].Reference)
assert.Equals(t, eak.CreatedAt, tc.eaks[i].CreatedAt) assert.Equals(t, eak.CreatedAt, tc.eaks[i].CreatedAt)
@ -627,7 +628,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
dbref := &dbExternalAccountKeyReference{ dbref := &dbExternalAccountKeyReference{
@ -707,7 +708,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: "aDifferentProvID", ProvisionerID: "aDifferentProvID",
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
b, err := json.Marshal(dbeak) b, err := json.Marshal(dbeak)
@ -730,7 +731,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
dbref := &dbExternalAccountKeyReference{ dbref := &dbExternalAccountKeyReference{
@ -780,7 +781,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
dbref := &dbExternalAccountKeyReference{ dbref := &dbExternalAccountKeyReference{
@ -830,7 +831,7 @@ func TestDB_DeleteExternalAccountKey(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
dbref := &dbExternalAccountKeyReference{ dbref := &dbExternalAccountKeyReference{
@ -953,7 +954,7 @@ func TestDB_CreateExternalAccountKey(t *testing.T) {
assert.Equals(t, string(key), dbeak.ID) assert.Equals(t, string(key), dbeak.ID)
assert.Equals(t, eak.ProvisionerID, dbeak.ProvisionerID) assert.Equals(t, eak.ProvisionerID, dbeak.ProvisionerID)
assert.Equals(t, eak.Reference, dbeak.Reference) assert.Equals(t, eak.Reference, dbeak.Reference)
assert.Equals(t, 32, len(dbeak.KeyBytes)) assert.Equals(t, 32, len(dbeak.HmacKey))
assert.False(t, dbeak.CreatedAt.IsZero()) assert.False(t, dbeak.CreatedAt.IsZero())
assert.Equals(t, dbeak.AccountID, eak.AccountID) assert.Equals(t, dbeak.AccountID, eak.AccountID)
assert.True(t, dbeak.BoundAt.IsZero()) assert.True(t, dbeak.BoundAt.IsZero())
@ -1078,7 +1079,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
b, err := json.Marshal(dbeak) b, err := json.Marshal(dbeak)
@ -1096,7 +1097,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
return test{ return test{
@ -1120,7 +1121,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
assert.Equals(t, dbNew.AccountID, dbeak.AccountID) assert.Equals(t, dbNew.AccountID, dbeak.AccountID)
assert.Equals(t, dbNew.CreatedAt, dbeak.CreatedAt) assert.Equals(t, dbNew.CreatedAt, dbeak.CreatedAt)
assert.Equals(t, dbNew.BoundAt, dbeak.BoundAt) assert.Equals(t, dbNew.BoundAt, dbeak.BoundAt)
assert.Equals(t, dbNew.KeyBytes, dbeak.KeyBytes) assert.Equals(t, dbNew.HmacKey, dbeak.HmacKey)
return nu, true, nil return nu, true, nil
}, },
}, },
@ -1148,7 +1149,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: "aDifferentProvID", ProvisionerID: "aDifferentProvID",
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
b, err := json.Marshal(newDBEAK) b, err := json.Marshal(newDBEAK)
@ -1174,7 +1175,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
b, err := json.Marshal(newDBEAK) b, err := json.Marshal(newDBEAK)
@ -1200,7 +1201,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
ProvisionerID: provID, ProvisionerID: provID,
Reference: ref, Reference: ref,
AccountID: "", AccountID: "",
KeyBytes: []byte{1, 3, 3, 7}, HmacKey: []byte{1, 3, 3, 7},
CreatedAt: now, CreatedAt: now,
} }
b, err := json.Marshal(newDBEAK) b, err := json.Marshal(newDBEAK)
@ -1237,7 +1238,7 @@ func TestDB_UpdateExternalAccountKey(t *testing.T) {
assert.Equals(t, dbeak.AccountID, tc.eak.AccountID) assert.Equals(t, dbeak.AccountID, tc.eak.AccountID)
assert.Equals(t, dbeak.CreatedAt, tc.eak.CreatedAt) assert.Equals(t, dbeak.CreatedAt, tc.eak.CreatedAt)
assert.Equals(t, dbeak.BoundAt, tc.eak.BoundAt) assert.Equals(t, dbeak.BoundAt, tc.eak.BoundAt)
assert.Equals(t, dbeak.KeyBytes, tc.eak.KeyBytes) assert.Equals(t, dbeak.HmacKey, tc.eak.HmacKey)
} }
}) })
} }

@ -1,100 +1,19 @@
package api package acme
import ( import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/http"
"net/url" "net/url"
"strings" "strings"
"github.com/smallstep/certificates/acme" "github.com/go-chi/chi"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/provisioner"
) )
// NewLinker returns a new Directory type.
func NewLinker(dns, prefix string) Linker {
_, _, err := net.SplitHostPort(dns)
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
// these cases, then the input dns is not changed.
lastIndex := strings.LastIndex(dns, ":")
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
if ip := net.ParseIP(hostPart); ip != nil {
dns = "[" + hostPart + "]:" + portPart
} else if ip := net.ParseIP(dns); ip != nil {
dns = "[" + dns + "]"
}
}
return &linker{prefix: prefix, dns: dns}
}
// Linker interface for generating links for ACME resources.
type Linker interface {
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
GetUnescapedPathSuffix(typ LinkType, provName string, inputs ...string) string
LinkOrder(ctx context.Context, o *acme.Order)
LinkAccount(ctx context.Context, o *acme.Account)
LinkChallenge(ctx context.Context, o *acme.Challenge, azID string)
LinkAuthorization(ctx context.Context, o *acme.Authorization)
LinkOrdersByAccountID(ctx context.Context, orders []string)
}
// linker generates ACME links.
type linker struct {
prefix string
dns string
}
func (l *linker) GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
switch typ {
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
return fmt.Sprintf("/%s/%s", provisionerName, typ)
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType:
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
case OrdersByAccountLinkType:
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType:
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
default:
return ""
}
}
// GetLink is a helper for GetLinkExplicit
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
var (
provName string
baseURL = baseURLFromContext(ctx)
u = url.URL{}
)
if p, err := provisionerFromContext(ctx); err == nil && p != nil {
provName = p.GetName()
}
// Copy the baseURL value from the pointer. https://github.com/golang/go/issues/38351
if baseURL != nil {
u = *baseURL
}
u.Path = l.GetUnescapedPathSuffix(typ, provName, inputs...)
// If no Scheme is set, then default to https.
if u.Scheme == "" {
u.Scheme = "https"
}
// If no Host is set, then use the default (first DNS attr in the ca.json).
if u.Host == "" {
u.Host = l.dns
}
u.Path = l.prefix + u.Path
return u.String()
}
// LinkType captures the link type. // LinkType captures the link type.
type LinkType int type LinkType int
@ -160,8 +79,155 @@ func (l LinkType) String() string {
} }
} }
func GetUnescapedPathSuffix(typ LinkType, provisionerName string, inputs ...string) string {
switch typ {
case NewNonceLinkType, NewAccountLinkType, NewOrderLinkType, NewAuthzLinkType, DirectoryLinkType, KeyChangeLinkType, RevokeCertLinkType:
return fmt.Sprintf("/%s/%s", provisionerName, typ)
case AccountLinkType, OrderLinkType, AuthzLinkType, CertificateLinkType:
return fmt.Sprintf("/%s/%s/%s", provisionerName, typ, inputs[0])
case ChallengeLinkType:
return fmt.Sprintf("/%s/%s/%s/%s", provisionerName, typ, inputs[0], inputs[1])
case OrdersByAccountLinkType:
return fmt.Sprintf("/%s/%s/%s/orders", provisionerName, AccountLinkType, inputs[0])
case FinalizeLinkType:
return fmt.Sprintf("/%s/%s/%s/finalize", provisionerName, OrderLinkType, inputs[0])
default:
return ""
}
}
// NewLinker returns a new Directory type.
func NewLinker(dns, prefix string) Linker {
_, _, err := net.SplitHostPort(dns)
if err != nil && strings.Contains(err.Error(), "too many colons in address") {
// this is most probably an IPv6 without brackets, e.g. ::1, 2001:0db8:85a3:0000:0000:8a2e:0370:7334
// in case a port was appended to this wrong format, we try to extract the port, then check if it's
// still a valid IPv6: 2001:0db8:85a3:0000:0000:8a2e:0370:7334:8443 (8443 is the port). If none of
// these cases, then the input dns is not changed.
lastIndex := strings.LastIndex(dns, ":")
hostPart, portPart := dns[:lastIndex], dns[lastIndex+1:]
if ip := net.ParseIP(hostPart); ip != nil {
dns = "[" + hostPart + "]:" + portPart
} else if ip := net.ParseIP(dns); ip != nil {
dns = "[" + dns + "]"
}
}
return &linker{prefix: prefix, dns: dns}
}
// Linker interface for generating links for ACME resources.
type Linker interface {
GetLink(ctx context.Context, typ LinkType, inputs ...string) string
Middleware(http.Handler) http.Handler
LinkOrder(ctx context.Context, o *Order)
LinkAccount(ctx context.Context, o *Account)
LinkChallenge(ctx context.Context, o *Challenge, azID string)
LinkAuthorization(ctx context.Context, o *Authorization)
LinkOrdersByAccountID(ctx context.Context, orders []string)
}
type linkerKey struct{}
// NewLinkerContext adds the given linker to the context.
func NewLinkerContext(ctx context.Context, v Linker) context.Context {
return context.WithValue(ctx, linkerKey{}, v)
}
// LinkerFromContext returns the current linker from the given context.
func LinkerFromContext(ctx context.Context) (v Linker, ok bool) {
v, ok = ctx.Value(linkerKey{}).(Linker)
return
}
// MustLinkerFromContext returns the current linker from the given context. It
// will panic if it's not in the context.
func MustLinkerFromContext(ctx context.Context) Linker {
if v, ok := LinkerFromContext(ctx); !ok {
panic("acme linker is not the context")
} else {
return v
}
}
type baseURLKey struct{}
func newBaseURLContext(ctx context.Context, r *http.Request) context.Context {
var u *url.URL
if r.Host != "" {
u = &url.URL{Scheme: "https", Host: r.Host}
}
return context.WithValue(ctx, baseURLKey{}, u)
}
func baseURLFromContext(ctx context.Context) *url.URL {
if u, ok := ctx.Value(baseURLKey{}).(*url.URL); ok {
return u
}
return nil
}
// linker generates ACME links.
type linker struct {
prefix string
dns string
}
// Middleware gets the provisioner and current url from the request and sets
// them in the context so we can use the linker to create ACME links.
func (l *linker) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add base url to the context.
ctx := newBaseURLContext(r.Context(), r)
// Add provisioner to the context.
nameEscaped := chi.URLParam(r, "provisionerID")
name, err := url.PathUnescape(nameEscaped)
if err != nil {
render.Error(w, WrapErrorISE(err, "error url unescaping provisioner name '%s'", nameEscaped))
return
}
p, err := authority.MustFromContext(ctx).LoadProvisionerByName(name)
if err != nil {
render.Error(w, err)
return
}
acmeProv, ok := p.(*provisioner.ACME)
if !ok {
render.Error(w, NewError(ErrorAccountDoesNotExistType, "provisioner must be of type ACME"))
return
}
ctx = NewProvisionerContext(ctx, Provisioner(acmeProv))
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// GetLink is a helper for GetLinkExplicit.
func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string {
var name string
if p, ok := ProvisionerFromContext(ctx); ok {
name = p.GetName()
}
var u url.URL
if baseURL := baseURLFromContext(ctx); baseURL != nil {
u = *baseURL
}
if u.Scheme == "" {
u.Scheme = "https"
}
if u.Host == "" {
u.Host = l.dns
}
u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...)
return u.String()
}
// LinkOrder sets the ACME links required by an ACME order. // LinkOrder sets the ACME links required by an ACME order.
func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) { func (l *linker) LinkOrder(ctx context.Context, o *Order) {
o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs)) o.AuthorizationURLs = make([]string, len(o.AuthorizationIDs))
for i, azID := range o.AuthorizationIDs { for i, azID := range o.AuthorizationIDs {
o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID) o.AuthorizationURLs[i] = l.GetLink(ctx, AuthzLinkType, azID)
@ -173,17 +239,17 @@ func (l *linker) LinkOrder(ctx context.Context, o *acme.Order) {
} }
// LinkAccount sets the ACME links required by an ACME account. // LinkAccount sets the ACME links required by an ACME account.
func (l *linker) LinkAccount(ctx context.Context, acc *acme.Account) { func (l *linker) LinkAccount(ctx context.Context, acc *Account) {
acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID) acc.OrdersURL = l.GetLink(ctx, OrdersByAccountLinkType, acc.ID)
} }
// LinkChallenge sets the ACME links required by an ACME challenge. // LinkChallenge sets the ACME links required by an ACME challenge.
func (l *linker) LinkChallenge(ctx context.Context, ch *acme.Challenge, azID string) { func (l *linker) LinkChallenge(ctx context.Context, ch *Challenge, azID string) {
ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID) ch.URL = l.GetLink(ctx, ChallengeLinkType, azID, ch.ID)
} }
// LinkAuthorization sets the ACME links required by an ACME authorization. // LinkAuthorization sets the ACME links required by an ACME authorization.
func (l *linker) LinkAuthorization(ctx context.Context, az *acme.Authorization) { func (l *linker) LinkAuthorization(ctx context.Context, az *Authorization) {
for _, ch := range az.Challenges { for _, ch := range az.Challenges {
l.LinkChallenge(ctx, ch, az.ID) l.LinkChallenge(ctx, ch, az.ID)
} }

@ -1,21 +1,38 @@
package api package acme
import ( import (
"context" "context"
"fmt" "fmt"
"net/url" "net/url"
"testing" "testing"
"time"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/authority/provisioner"
) )
func TestLinker_GetUnescapedPathSuffix(t *testing.T) { func mockProvisioner(t *testing.T) Provisioner {
dns := "ca.smallstep.com" t.Helper()
prefix := "acme" var defaultDisableRenewal = false
linker := NewLinker(dns, prefix)
// Initialize provisioners
p := &provisioner.ACME{
Type: "ACME",
Name: "test@acme-<test>provisioner.com",
}
if err := p.Init(provisioner.Config{Claims: provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute},
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DisableRenewal: &defaultDisableRenewal,
}}); err != nil {
fmt.Printf("%v", err)
}
return p
}
getPath := linker.GetUnescapedPathSuffix func TestGetUnescapedPathSuffix(t *testing.T) {
getPath := GetUnescapedPathSuffix
assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce") assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce")
assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory") assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory")
@ -32,9 +49,9 @@ func TestLinker_GetUnescapedPathSuffix(t *testing.T) {
} }
func TestLinker_DNS(t *testing.T) { func TestLinker_DNS(t *testing.T) {
prov := newProv() prov := mockProvisioner(t)
escProvName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := NewProvisionerContext(context.Background(), prov)
type test struct { type test struct {
name string name string
dns string dns string
@ -117,19 +134,19 @@ func TestLinker_GetLink(t *testing.T) {
linker := NewLinker(dns, prefix) linker := NewLinker(dns, prefix)
id := "1234" id := "1234"
prov := newProv() prov := mockProvisioner(t)
escProvName := url.PathEscape(prov.GetName()) escProvName := url.PathEscape(prov.GetName())
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
ctx := context.WithValue(context.Background(), provisionerContextKey, prov) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, baseURLContextKey, baseURL) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
// No provisioner and no BaseURL from request // No provisioner and no BaseURL from request
assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", "")) assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", ""))
// Provisioner: yes, BaseURL: no // Provisioner: yes, BaseURL: no
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName))
// Provisioner: no, BaseURL: yes // Provisioner: no, BaseURL: yes
assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", "")) assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", ""))
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName))
@ -163,37 +180,37 @@ func TestLinker_GetLink(t *testing.T) {
func TestLinker_LinkOrder(t *testing.T) { func TestLinker_LinkOrder(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
oid := "orderID" oid := "orderID"
certID := "certID" certID := "certID"
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)
type test struct { type test struct {
o *acme.Order o *Order
validate func(o *acme.Order) validate func(o *Order)
} }
var tests = map[string]test{ var tests = map[string]test{
"no-authz-and-no-cert": { "no-authz-and-no-cert": {
o: &acme.Order{ o: &Order{
ID: oid, ID: oid,
}, },
validate: func(o *acme.Order) { validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{}) assert.Equals(t, o.AuthorizationURLs, []string{})
assert.Equals(t, o.CertificateURL, "") assert.Equals(t, o.CertificateURL, "")
}, },
}, },
"one-authz-and-cert": { "one-authz-and-cert": {
o: &acme.Order{ o: &Order{
ID: oid, ID: oid,
CertificateID: certID, CertificateID: certID,
AuthorizationIDs: []string{"foo"}, AuthorizationIDs: []string{"foo"},
}, },
validate: func(o *acme.Order) { validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{ assert.Equals(t, o.AuthorizationURLs, []string{
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
@ -202,12 +219,12 @@ func TestLinker_LinkOrder(t *testing.T) {
}, },
}, },
"many-authz": { "many-authz": {
o: &acme.Order{ o: &Order{
ID: oid, ID: oid,
CertificateID: certID, CertificateID: certID,
AuthorizationIDs: []string{"foo", "bar", "zap"}, AuthorizationIDs: []string{"foo", "bar", "zap"},
}, },
validate: func(o *acme.Order) { validate: func(o *Order) {
assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid)) assert.Equals(t, o.FinalizeURL, fmt.Sprintf("%s/%s/%s/order/%s/finalize", baseURL, linkerPrefix, provName, oid))
assert.Equals(t, o.AuthorizationURLs, []string{ assert.Equals(t, o.AuthorizationURLs, []string{
fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"), fmt.Sprintf("%s/%s/%s/authz/%s", baseURL, linkerPrefix, provName, "foo"),
@ -228,24 +245,24 @@ func TestLinker_LinkOrder(t *testing.T) {
func TestLinker_LinkAccount(t *testing.T) { func TestLinker_LinkAccount(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
accID := "accountID" accID := "accountID"
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)
type test struct { type test struct {
a *acme.Account a *Account
validate func(o *acme.Account) validate func(o *Account)
} }
var tests = map[string]test{ var tests = map[string]test{
"ok": { "ok": {
a: &acme.Account{ a: &Account{
ID: accID, ID: accID,
}, },
validate: func(a *acme.Account) { validate: func(a *Account) {
assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID)) assert.Equals(t, a.OrdersURL, fmt.Sprintf("%s/%s/%s/account/%s/orders", baseURL, linkerPrefix, provName, accID))
}, },
}, },
@ -260,25 +277,25 @@ func TestLinker_LinkAccount(t *testing.T) {
func TestLinker_LinkChallenge(t *testing.T) { func TestLinker_LinkChallenge(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
chID := "chID" chID := "chID"
azID := "azID" azID := "azID"
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)
type test struct { type test struct {
ch *acme.Challenge ch *Challenge
validate func(o *acme.Challenge) validate func(o *Challenge)
} }
var tests = map[string]test{ var tests = map[string]test{
"ok": { "ok": {
ch: &acme.Challenge{ ch: &Challenge{
ID: chID, ID: chID,
}, },
validate: func(ch *acme.Challenge) { validate: func(ch *Challenge) {
assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID)) assert.Equals(t, ch.URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, azID, ch.ID))
}, },
}, },
@ -293,10 +310,10 @@ func TestLinker_LinkChallenge(t *testing.T) {
func TestLinker_LinkAuthorization(t *testing.T) { func TestLinker_LinkAuthorization(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
chID0 := "chID-0" chID0 := "chID-0"
chID1 := "chID-1" chID1 := "chID-1"
@ -305,20 +322,20 @@ func TestLinker_LinkAuthorization(t *testing.T) {
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)
type test struct { type test struct {
az *acme.Authorization az *Authorization
validate func(o *acme.Authorization) validate func(o *Authorization)
} }
var tests = map[string]test{ var tests = map[string]test{
"ok": { "ok": {
az: &acme.Authorization{ az: &Authorization{
ID: azID, ID: azID,
Challenges: []*acme.Challenge{ Challenges: []*Challenge{
{ID: chID0}, {ID: chID0},
{ID: chID1}, {ID: chID1},
{ID: chID2}, {ID: chID2},
}, },
}, },
validate: func(az *acme.Authorization) { validate: func(az *Authorization) {
assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0)) assert.Equals(t, az.Challenges[0].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID0))
assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1)) assert.Equals(t, az.Challenges[1].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID1))
assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2)) assert.Equals(t, az.Challenges[2].URL, fmt.Sprintf("%s/%s/%s/challenge/%s/%s", baseURL, linkerPrefix, provName, az.ID, chID2))
@ -335,10 +352,10 @@ func TestLinker_LinkAuthorization(t *testing.T) {
func TestLinker_LinkOrdersByAccountID(t *testing.T) { func TestLinker_LinkOrdersByAccountID(t *testing.T) {
baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
prov := newProv() prov := mockProvisioner(t)
provName := url.PathEscape(prov.GetName()) provName := url.PathEscape(prov.GetName())
ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) ctx := NewProvisionerContext(context.Background(), prov)
ctx = context.WithValue(ctx, provisionerContextKey, prov) ctx = context.WithValue(ctx, baseURLKey{}, baseURL)
linkerPrefix := "acme" linkerPrefix := "acme"
l := NewLinker("dns", linkerPrefix) l := NewLinker("dns", linkerPrefix)

@ -268,6 +268,7 @@ func TestOrder_UpdateStatus(t *testing.T) {
type mockSignAuth struct { type mockSignAuth struct {
sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) sign func(csr *x509.CertificateRequest, signOpts provisioner.SignOptions, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
areSANsAllowed func(ctx context.Context, sans []string) error
loadProvisionerByName func(string) (provisioner.Interface, error) loadProvisionerByName func(string) (provisioner.Interface, error)
ret1, ret2 interface{} ret1, ret2 interface{}
err error err error
@ -282,6 +283,13 @@ func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.S
return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, m.err
} }
func (m *mockSignAuth) AreSANsAllowed(ctx context.Context, sans []string) error {
if m.areSANsAllowed != nil {
return m.areSANsAllowed(ctx, sans)
}
return m.err
}
func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface, error) { func (m *mockSignAuth) LoadProvisionerByName(name string) (provisioner.Interface, error) {
if m.loadProvisionerByName != nil { if m.loadProvisionerByName != nil {
return m.loadProvisionerByName(name) return m.loadProvisionerByName(name)

@ -35,7 +35,6 @@ type Authority interface {
SSHAuthority SSHAuthority
// context specifies the Authorize[Sign|Revoke|etc.] method. // context specifies the Authorize[Sign|Revoke|etc.] method.
Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error)
AuthorizeSign(ott string) ([]provisioner.SignOption, error)
AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.Certificate, error)
GetTLSOptions() *config.TLSOptions GetTLSOptions() *config.TLSOptions
Root(shasum string) (*x509.Certificate, error) Root(shasum string) (*x509.Certificate, error)
@ -53,6 +52,11 @@ type Authority interface {
GetCertificateRevocationList() ([]byte, error) GetCertificateRevocationList() ([]byte, error)
} }
// mustAuthority will be replaced on unit tests.
var mustAuthority = func(ctx context.Context) Authority {
return authority.MustFromContext(ctx)
}
// TimeDuration is an alias of provisioner.TimeDuration // TimeDuration is an alias of provisioner.TimeDuration
type TimeDuration = provisioner.TimeDuration type TimeDuration = provisioner.TimeDuration
@ -244,49 +248,54 @@ type caHandler struct {
Authority Authority Authority Authority
} }
// New creates a new RouterHandler with the CA endpoints. // Route configures the http request router.
func New(auth Authority) RouterHandler { func (h *caHandler) Route(r Router) {
return &caHandler{ Route(r)
Authority: auth,
}
} }
func (h *caHandler) Route(r Router) { // New creates a new RouterHandler with the CA endpoints.
r.MethodFunc("GET", "/version", h.Version) //
r.MethodFunc("GET", "/health", h.Health) // Deprecated: Use api.Route(r Router)
r.MethodFunc("GET", "/root/{sha}", h.Root) func New(auth Authority) RouterHandler {
r.MethodFunc("POST", "/sign", h.Sign) return &caHandler{}
r.MethodFunc("POST", "/renew", h.Renew) }
r.MethodFunc("POST", "/rekey", h.Rekey)
r.MethodFunc("POST", "/revoke", h.Revoke) func Route(r Router) {
r.MethodFunc("GET", "/crl", h.CRL) r.MethodFunc("GET", "/version", Version)
r.MethodFunc("GET", "/provisioners", h.Provisioners) r.MethodFunc("GET", "/health", Health)
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", h.ProvisionerKey) r.MethodFunc("GET", "/root/{sha}", Root)
r.MethodFunc("GET", "/roots", h.Roots) r.MethodFunc("POST", "/sign", Sign)
r.MethodFunc("GET", "/roots.pem", h.RootsPEM) r.MethodFunc("POST", "/renew", Renew)
r.MethodFunc("GET", "/federation", h.Federation) r.MethodFunc("POST", "/rekey", Rekey)
r.MethodFunc("POST", "/revoke", Revoke)
r.MethodFunc("GET", "/crl", CRL)
r.MethodFunc("GET", "/provisioners", Provisioners)
r.MethodFunc("GET", "/provisioners/{kid}/encrypted-key", ProvisionerKey)
r.MethodFunc("GET", "/roots", Roots)
r.MethodFunc("GET", "/roots.pem", RootsPEM)
r.MethodFunc("GET", "/federation", Federation)
// SSH CA // SSH CA
r.MethodFunc("POST", "/ssh/sign", h.SSHSign) r.MethodFunc("POST", "/ssh/sign", SSHSign)
r.MethodFunc("POST", "/ssh/renew", h.SSHRenew) r.MethodFunc("POST", "/ssh/renew", SSHRenew)
r.MethodFunc("POST", "/ssh/revoke", h.SSHRevoke) r.MethodFunc("POST", "/ssh/revoke", SSHRevoke)
r.MethodFunc("POST", "/ssh/rekey", h.SSHRekey) r.MethodFunc("POST", "/ssh/rekey", SSHRekey)
r.MethodFunc("GET", "/ssh/roots", h.SSHRoots) r.MethodFunc("GET", "/ssh/roots", SSHRoots)
r.MethodFunc("GET", "/ssh/federation", h.SSHFederation) r.MethodFunc("GET", "/ssh/federation", SSHFederation)
r.MethodFunc("POST", "/ssh/config", h.SSHConfig) r.MethodFunc("POST", "/ssh/config", SSHConfig)
r.MethodFunc("POST", "/ssh/config/{type}", h.SSHConfig) r.MethodFunc("POST", "/ssh/config/{type}", SSHConfig)
r.MethodFunc("POST", "/ssh/check-host", h.SSHCheckHost) r.MethodFunc("POST", "/ssh/check-host", SSHCheckHost)
r.MethodFunc("GET", "/ssh/hosts", h.SSHGetHosts) r.MethodFunc("GET", "/ssh/hosts", SSHGetHosts)
r.MethodFunc("POST", "/ssh/bastion", h.SSHBastion) r.MethodFunc("POST", "/ssh/bastion", SSHBastion)
// For compatibility with old code: // For compatibility with old code:
r.MethodFunc("POST", "/re-sign", h.Renew) r.MethodFunc("POST", "/re-sign", Renew)
r.MethodFunc("POST", "/sign-ssh", h.SSHSign) r.MethodFunc("POST", "/sign-ssh", SSHSign)
r.MethodFunc("GET", "/ssh/get-hosts", h.SSHGetHosts) r.MethodFunc("GET", "/ssh/get-hosts", SSHGetHosts)
} }
// Version is an HTTP handler that returns the version of the server. // Version is an HTTP handler that returns the version of the server.
func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) { func Version(w http.ResponseWriter, r *http.Request) {
v := h.Authority.Version() v := mustAuthority(r.Context()).Version()
render.JSON(w, VersionResponse{ render.JSON(w, VersionResponse{
Version: v.Version, Version: v.Version,
RequireClientAuthentication: v.RequireClientAuthentication, RequireClientAuthentication: v.RequireClientAuthentication,
@ -294,17 +303,17 @@ func (h *caHandler) Version(w http.ResponseWriter, r *http.Request) {
} }
// Health is an HTTP handler that returns the status of the server. // Health is an HTTP handler that returns the status of the server.
func (h *caHandler) Health(w http.ResponseWriter, r *http.Request) { func Health(w http.ResponseWriter, r *http.Request) {
render.JSON(w, HealthResponse{Status: "ok"}) render.JSON(w, HealthResponse{Status: "ok"})
} }
// Root is an HTTP handler that using the SHA256 from the URL, returns the root // Root is an HTTP handler that using the SHA256 from the URL, returns the root
// certificate for the given SHA256. // certificate for the given SHA256.
func (h *caHandler) Root(w http.ResponseWriter, r *http.Request) { func Root(w http.ResponseWriter, r *http.Request) {
sha := chi.URLParam(r, "sha") sha := chi.URLParam(r, "sha")
sum := strings.ToLower(strings.ReplaceAll(sha, "-", "")) sum := strings.ToLower(strings.ReplaceAll(sha, "-", ""))
// Load root certificate with the // Load root certificate with the
cert, err := h.Authority.Root(sum) cert, err := mustAuthority(r.Context()).Root(sum)
if err != nil { if err != nil {
render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI)) render.Error(w, errs.Wrapf(http.StatusNotFound, err, "%s was not found", r.RequestURI))
return return
@ -322,18 +331,19 @@ func certChainToPEM(certChain []*x509.Certificate) []Certificate {
} }
// Provisioners returns the list of provisioners configured in the authority. // Provisioners returns the list of provisioners configured in the authority.
func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) { func Provisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := ParseCursor(r) cursor, limit, err := ParseCursor(r)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
p, next, err := h.Authority.GetProvisioners(cursor, limit) p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
} }
render.JSON(w, &ProvisionersResponse{ render.JSON(w, &ProvisionersResponse{
Provisioners: p, Provisioners: p,
NextCursor: next, NextCursor: next,
@ -341,19 +351,20 @@ func (h *caHandler) Provisioners(w http.ResponseWriter, r *http.Request) {
} }
// ProvisionerKey returns the encrypted key of a provisioner by it's key id. // ProvisionerKey returns the encrypted key of a provisioner by it's key id.
func (h *caHandler) ProvisionerKey(w http.ResponseWriter, r *http.Request) { func ProvisionerKey(w http.ResponseWriter, r *http.Request) {
kid := chi.URLParam(r, "kid") kid := chi.URLParam(r, "kid")
key, err := h.Authority.GetEncryptedKey(kid) key, err := mustAuthority(r.Context()).GetEncryptedKey(kid)
if err != nil { if err != nil {
render.Error(w, errs.NotFoundErr(err)) render.Error(w, errs.NotFoundErr(err))
return return
} }
render.JSON(w, &ProvisionerKeyResponse{key}) render.JSON(w, &ProvisionerKeyResponse{key})
} }
// Roots returns all the root certificates for the CA. // Roots returns all the root certificates for the CA.
func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) { func Roots(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots() roots, err := mustAuthority(r.Context()).GetRoots()
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error getting roots")) render.Error(w, errs.ForbiddenErr(err, "error getting roots"))
return return
@ -370,8 +381,8 @@ func (h *caHandler) Roots(w http.ResponseWriter, r *http.Request) {
} }
// RootsPEM returns all the root certificates for the CA in PEM format. // RootsPEM returns all the root certificates for the CA in PEM format.
func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) { func RootsPEM(w http.ResponseWriter, r *http.Request) {
roots, err := h.Authority.GetRoots() roots, err := mustAuthority(r.Context()).GetRoots()
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
@ -393,8 +404,8 @@ func (h *caHandler) RootsPEM(w http.ResponseWriter, r *http.Request) {
} }
// Federation returns all the public certificates in the federation. // Federation returns all the public certificates in the federation.
func (h *caHandler) Federation(w http.ResponseWriter, r *http.Request) { func Federation(w http.ResponseWriter, r *http.Request) {
federated, err := h.Authority.GetFederation() federated, err := mustAuthority(r.Context()).GetFederation()
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error getting federated roots")) render.Error(w, errs.ForbiddenErr(err, "error getting federated roots"))
return return

@ -171,10 +171,21 @@ func parseCertificateRequest(data string) *x509.CertificateRequest {
return csr return csr
} }
func mockMustAuthority(t *testing.T, a Authority) {
t.Helper()
fn := mustAuthority
t.Cleanup(func() {
mustAuthority = fn
})
mustAuthority = func(ctx context.Context) Authority {
return a
}
}
type mockAuthority struct { type mockAuthority struct {
ret1, ret2 interface{} ret1, ret2 interface{}
err error err error
authorizeSign func(ott string) ([]provisioner.SignOption, error) authorize func(ctx context.Context, ott string) ([]provisioner.SignOption, error)
authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error) authorizeRenewToken func(ctx context.Context, ott string) (*x509.Certificate, error)
getTLSOptions func() *authority.TLSOptions getTLSOptions func() *authority.TLSOptions
root func(shasum string) (*x509.Certificate, error) root func(shasum string) (*x509.Certificate, error)
@ -207,12 +218,8 @@ func (m *mockAuthority) GetCertificateRevocationList() ([]byte, error) {
// TODO: remove once Authorize is deprecated. // TODO: remove once Authorize is deprecated.
func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) { func (m *mockAuthority) Authorize(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return m.AuthorizeSign(ott) if m.authorize != nil {
} return m.authorize(ctx, ott)
func (m *mockAuthority) AuthorizeSign(ott string) ([]provisioner.SignOption, error) {
if m.authorizeSign != nil {
return m.authorizeSign(ott)
} }
return m.ret1.([]provisioner.SignOption), m.err return m.ret1.([]provisioner.SignOption), m.err
} }
@ -793,11 +800,10 @@ func Test_caHandler_Route(t *testing.T) {
} }
} }
func Test_caHandler_Health(t *testing.T) { func Test_Health(t *testing.T) {
req := httptest.NewRequest("GET", "http://example.com/health", nil) req := httptest.NewRequest("GET", "http://example.com/health", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h := New(&mockAuthority{}).(*caHandler) Health(w, req)
h.Health(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != 200 { if res.StatusCode != 200 {
@ -815,7 +821,7 @@ func Test_caHandler_Health(t *testing.T) {
} }
} }
func Test_caHandler_Root(t *testing.T) { func Test_Root(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
root *x509.Certificate root *x509.Certificate
@ -836,9 +842,9 @@ func Test_caHandler_Root(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: tt.root, err: tt.err}).(*caHandler) mockMustAuthority(t, &mockAuthority{ret1: tt.root, err: tt.err})
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Root(w, req) Root(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -859,7 +865,7 @@ func Test_caHandler_Root(t *testing.T) {
} }
} }
func Test_caHandler_Sign(t *testing.T) { func Test_Sign(t *testing.T) {
csr := parseCertificateRequest(csrPEM) csr := parseCertificateRequest(csrPEM)
valid, err := json.Marshal(SignRequest{ valid, err := json.Marshal(SignRequest{
CsrPEM: CertificateRequest{csr}, CsrPEM: CertificateRequest{csr},
@ -900,18 +906,18 @@ func Test_caHandler_Sign(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.signErr, ret1: tt.cert, ret2: tt.root, err: tt.signErr,
authorizeSign: func(ott string) ([]provisioner.SignOption, error) { authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return tt.certAttrOpts, tt.autherr return tt.certAttrOpts, tt.autherr
}, },
getTLSOptions: func() *authority.TLSOptions { getTLSOptions: func() *authority.TLSOptions {
return nil return nil
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input)) req := httptest.NewRequest("POST", "http://example.com/sign", strings.NewReader(tt.input))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Sign(logging.NewResponseLogger(w), req) Sign(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -932,7 +938,7 @@ func Test_caHandler_Sign(t *testing.T) {
} }
} }
func Test_caHandler_Renew(t *testing.T) { func Test_Renew(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
@ -1022,7 +1028,7 @@ func Test_caHandler_Renew(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.err, ret1: tt.cert, ret2: tt.root, err: tt.err,
authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) { authorizeRenewToken: func(ctx context.Context, ott string) (*x509.Certificate, error) {
jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root}) jwt, chain, err := jose.ParseX5cInsecure(ott, []*x509.Certificate{tt.root})
@ -1043,12 +1049,12 @@ func Test_caHandler_Renew(t *testing.T) {
getTLSOptions: func() *authority.TLSOptions { getTLSOptions: func() *authority.TLSOptions {
return nil return nil
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/renew", nil) req := httptest.NewRequest("POST", "http://example.com/renew", nil)
req.TLS = tt.tls req.TLS = tt.tls
req.Header = tt.header req.Header = tt.header
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Renew(logging.NewResponseLogger(w), req) Renew(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
defer res.Body.Close() defer res.Body.Close()
@ -1077,7 +1083,7 @@ func Test_caHandler_Renew(t *testing.T) {
} }
} }
func Test_caHandler_Rekey(t *testing.T) { func Test_Rekey(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
@ -1108,16 +1114,16 @@ func Test_caHandler_Rekey(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
ret1: tt.cert, ret2: tt.root, err: tt.err, ret1: tt.cert, ret2: tt.root, err: tt.err,
getTLSOptions: func() *authority.TLSOptions { getTLSOptions: func() *authority.TLSOptions {
return nil return nil
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input)) req := httptest.NewRequest("POST", "http://example.com/rekey", strings.NewReader(tt.input))
req.TLS = tt.tls req.TLS = tt.tls
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Rekey(logging.NewResponseLogger(w), req) Rekey(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -1138,7 +1144,7 @@ func Test_caHandler_Rekey(t *testing.T) {
} }
} }
func Test_caHandler_Provisioners(t *testing.T) { func Test_Provisioners(t *testing.T) {
type fields struct { type fields struct {
Authority Authority Authority Authority
} }
@ -1204,10 +1210,8 @@ func Test_caHandler_Provisioners(t *testing.T) {
assert.FatalError(t, err) assert.FatalError(t, err)
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := &caHandler{ mockMustAuthority(t, tt.fields.Authority)
Authority: tt.fields.Authority, Provisioners(tt.args.w, tt.args.r)
}
h.Provisioners(tt.args.w, tt.args.r)
rec := tt.args.w.(*httptest.ResponseRecorder) rec := tt.args.w.(*httptest.ResponseRecorder)
res := rec.Result() res := rec.Result()
@ -1242,7 +1246,7 @@ func Test_caHandler_Provisioners(t *testing.T) {
} }
} }
func Test_caHandler_ProvisionerKey(t *testing.T) { func Test_ProvisionerKey(t *testing.T) {
type fields struct { type fields struct {
Authority Authority Authority Authority
} }
@ -1274,10 +1278,8 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := &caHandler{ mockMustAuthority(t, tt.fields.Authority)
Authority: tt.fields.Authority, ProvisionerKey(tt.args.w, tt.args.r)
}
h.ProvisionerKey(tt.args.w, tt.args.r)
rec := tt.args.w.(*httptest.ResponseRecorder) rec := tt.args.w.(*httptest.ResponseRecorder)
res := rec.Result() res := rec.Result()
@ -1302,7 +1304,7 @@ func Test_caHandler_ProvisionerKey(t *testing.T) {
} }
} }
func Test_caHandler_Roots(t *testing.T) { func Test_Roots(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
@ -1323,11 +1325,11 @@ func Test_caHandler_Roots(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
req := httptest.NewRequest("GET", "http://example.com/roots", nil) req := httptest.NewRequest("GET", "http://example.com/roots", nil)
req.TLS = tt.tls req.TLS = tt.tls
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Roots(w, req) Roots(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -1364,10 +1366,10 @@ func Test_caHandler_RootsPEM(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: tt.roots, err: tt.err}).(*caHandler) mockMustAuthority(t, &mockAuthority{ret1: tt.roots, err: tt.err})
req := httptest.NewRequest("GET", "https://example.com/roots", nil) req := httptest.NewRequest("GET", "https://example.com/roots", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.RootsPEM(w, req) RootsPEM(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -1388,7 +1390,7 @@ func Test_caHandler_RootsPEM(t *testing.T) {
} }
} }
func Test_caHandler_Federation(t *testing.T) { func Test_Federation(t *testing.T) {
cs := &tls.ConnectionState{ cs := &tls.ConnectionState{
PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)}, PeerCertificates: []*x509.Certificate{parseCertificate(certPEM)},
} }
@ -1409,11 +1411,11 @@ func Test_caHandler_Federation(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err}).(*caHandler) mockMustAuthority(t, &mockAuthority{ret1: []*x509.Certificate{tt.root}, err: tt.err})
req := httptest.NewRequest("GET", "http://example.com/federation", nil) req := httptest.NewRequest("GET", "http://example.com/federation", nil)
req.TLS = tt.tls req.TLS = tt.tls
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Federation(w, req) Federation(w, req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {

@ -3,16 +3,20 @@ package read
import ( import (
"encoding/json" "encoding/json"
"errors"
"io" "io"
"net/http"
"strings"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
// JSON reads JSON from the request body and stores it in the value // JSON reads JSON from the request body and stores it in the value
// pointed by v. // pointed to by v.
func JSON(r io.Reader, v interface{}) error { func JSON(r io.Reader, v interface{}) error {
if err := json.NewDecoder(r).Decode(v); err != nil { if err := json.NewDecoder(r).Decode(v); err != nil {
return errs.BadRequestErr(err, "error decoding json") return errs.BadRequestErr(err, "error decoding json")
@ -21,11 +25,42 @@ func JSON(r io.Reader, v interface{}) error {
} }
// ProtoJSON reads JSON from the request body and stores it in the value // ProtoJSON reads JSON from the request body and stores it in the value
// pointed by v. // pointed to by m.
func ProtoJSON(r io.Reader, m proto.Message) error { func ProtoJSON(r io.Reader, m proto.Message) error {
data, err := io.ReadAll(r) data, err := io.ReadAll(r)
if err != nil { if err != nil {
return errs.BadRequestErr(err, "error reading request body") return errs.BadRequestErr(err, "error reading request body")
} }
return protojson.Unmarshal(data, m)
switch err := protojson.Unmarshal(data, m); {
case errors.Is(err, proto.Error):
return badProtoJSONError(err.Error())
default:
return err
}
}
// badProtoJSONError is an error type that is returned by ProtoJSON
// when a proto message cannot be unmarshaled. Usually this is caused
// by an error in the request body.
type badProtoJSONError string
// Error implements error for badProtoJSONError
func (e badProtoJSONError) Error() string {
return string(e)
}
// Render implements render.RenderableError for badProtoJSONError
func (e badProtoJSONError) Render(w http.ResponseWriter) {
v := struct {
Type string `json:"type"`
Detail string `json:"detail"`
Message string `json:"message"`
}{
Type: "badRequest",
Detail: "bad request",
// trim the proto prefix for the message
Message: strings.TrimSpace(strings.TrimPrefix(e.Error(), "proto:")),
}
render.JSONStatus(w, v, http.StatusBadRequest)
} }

@ -1,10 +1,21 @@
package read package read
import ( import (
"encoding/json"
"errors"
"io" "io"
"net/http"
"net/http/httptest"
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"testing/iotest"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/errs" "github.com/smallstep/certificates/errs"
) )
@ -44,3 +55,110 @@ func TestJSON(t *testing.T) {
}) })
} }
} }
func TestProtoJSON(t *testing.T) {
p := new(linkedca.Policy) // TODO(hs): can we use something different, so we don't need the import?
type args struct {
r io.Reader
m proto.Message
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "fail/io.ReadAll",
args: args{
r: iotest.ErrReader(errors.New("read error")),
m: p,
},
wantErr: true,
},
{
name: "fail/proto",
args: args{
r: strings.NewReader(`{?}`),
m: p,
},
wantErr: true,
},
{
name: "ok",
args: args{
r: strings.NewReader(`{"x509":{}}`),
m: p,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ProtoJSON(tt.args.r, tt.args.m)
if (err != nil) != tt.wantErr {
t.Errorf("ProtoJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
switch err.(type) {
case badProtoJSONError:
assert.Contains(t, err.Error(), "syntax error")
case *errs.Error:
var ee *errs.Error
if errors.As(err, &ee) {
assert.Equal(t, http.StatusBadRequest, ee.Status)
}
}
return
}
assert.Equal(t, protoreflect.FullName("linkedca.Policy"), proto.MessageName(tt.args.m))
assert.True(t, proto.Equal(&linkedca.Policy{X509: &linkedca.X509Policy{}}, tt.args.m))
})
}
}
func Test_badProtoJSONError_Render(t *testing.T) {
tests := []struct {
name string
e badProtoJSONError
expected string
}{
{
name: "bad proto normal space",
e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"),
expected: "syntax error (line 1:2): invalid value ?",
},
{
name: "bad proto non breaking space",
e: badProtoJSONError("proto: syntax error (line 1:2): invalid value ?"),
expected: "syntax error (line 1:2): invalid value ?",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
tt.e.Render(w)
res := w.Result()
defer res.Body.Close()
data, err := io.ReadAll(res.Body)
assert.NoError(t, err)
v := struct {
Type string `json:"type"`
Detail string `json:"detail"`
Message string `json:"message"`
}{}
assert.NoError(t, json.Unmarshal(data, &v))
assert.Equal(t, "badRequest", v.Type)
assert.Equal(t, "bad request", v.Detail)
assert.Equal(t, "syntax error (line 1:2): invalid value ?", v.Message)
})
}
}

@ -27,7 +27,7 @@ func (s *RekeyRequest) Validate() error {
} }
// Rekey is similar to renew except that the certificate will be renewed with new key from csr. // Rekey is similar to renew except that the certificate will be renewed with new key from csr.
func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) { func Rekey(w http.ResponseWriter, r *http.Request) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
render.Error(w, errs.BadRequest("missing client certificate")) render.Error(w, errs.BadRequest("missing client certificate"))
return return
@ -44,7 +44,8 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
return return
} }
certChain, err := h.Authority.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey) a := mustAuthority(r.Context())
certChain, err := a.Rekey(r.TLS.PeerCertificates[0], body.CsrPEM.CertificateRequest.PublicKey)
if err != nil { if err != nil {
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey")) render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Rekey"))
return return
@ -60,6 +61,6 @@ func (h *caHandler) Rekey(w http.ResponseWriter, r *http.Request) {
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(), TLSOptions: a.GetTLSOptions(),
}, http.StatusCreated) }, http.StatusCreated)
} }

@ -16,14 +16,15 @@ const (
// Renew uses the information of certificate in the TLS connection to create a // Renew uses the information of certificate in the TLS connection to create a
// new one. // new one.
func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) { func Renew(w http.ResponseWriter, r *http.Request) {
cert, err := h.getPeerCertificate(r) cert, err := getPeerCertificate(r)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
certChain, err := h.Authority.Renew(cert) a := mustAuthority(r.Context())
certChain, err := a.Renew(cert)
if err != nil { if err != nil {
render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew")) render.Error(w, errs.Wrap(http.StatusInternalServerError, err, "cahandler.Renew"))
return return
@ -39,17 +40,18 @@ func (h *caHandler) Renew(w http.ResponseWriter, r *http.Request) {
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(), TLSOptions: a.GetTLSOptions(),
}, http.StatusCreated) }, http.StatusCreated)
} }
func (h *caHandler) getPeerCertificate(r *http.Request) (*x509.Certificate, error) { func getPeerCertificate(r *http.Request) (*x509.Certificate, error) {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
return r.TLS.PeerCertificates[0], nil return r.TLS.PeerCertificates[0], nil
} }
if s := r.Header.Get(authorizationHeader); s != "" { if s := r.Header.Get(authorizationHeader); s != "" {
if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 { if parts := strings.SplitN(s, bearerScheme+" ", 2); len(parts) == 2 {
return h.Authority.AuthorizeRenewToken(r.Context(), parts[1]) ctx := r.Context()
return mustAuthority(ctx).AuthorizeRenewToken(ctx, parts[1])
} }
} }
return nil, errs.BadRequest("missing client certificate") return nil, errs.BadRequest("missing client certificate")

@ -1,7 +1,6 @@
package api package api
import ( import (
"context"
"net/http" "net/http"
"golang.org/x/crypto/ocsp" "golang.org/x/crypto/ocsp"
@ -49,7 +48,7 @@ func (r *RevokeRequest) Validate() (err error) {
// NOTE: currently only Passive revocation is supported. // NOTE: currently only Passive revocation is supported.
// //
// TODO: Add CRL and OCSP support. // TODO: Add CRL and OCSP support.
func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) { func Revoke(w http.ResponseWriter, r *http.Request) {
var body RevokeRequest var body RevokeRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -68,12 +67,14 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
PassiveOnly: body.Passive, PassiveOnly: body.Passive,
} }
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.RevokeMethod)
a := mustAuthority(ctx)
// A token indicates that we are using the api via a provisioner token, // A token indicates that we are using the api via a provisioner token,
// otherwise it is assumed that the certificate is revoking itself over mTLS. // otherwise it is assumed that the certificate is revoking itself over mTLS.
if len(body.OTT) > 0 { if len(body.OTT) > 0 {
logOtt(w, body.OTT) logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil { if _, err := a.Authorize(ctx, body.OTT); err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
@ -98,7 +99,7 @@ func (h *caHandler) Revoke(w http.ResponseWriter, r *http.Request) {
opts.MTLS = true opts.MTLS = true
} }
if err := h.Authority.Revoke(ctx, opts); err != nil { if err := a.Revoke(ctx, opts); err != nil {
render.Error(w, errs.ForbiddenErr(err, "error revoking certificate")) render.Error(w, errs.ForbiddenErr(err, "error revoking certificate"))
return return
} }

@ -108,7 +108,7 @@ func Test_caHandler_Revoke(t *testing.T) {
input: string(input), input: string(input),
statusCode: http.StatusOK, statusCode: http.StatusOK,
auth: &mockAuthority{ auth: &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) { authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return nil, nil return nil, nil
}, },
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
@ -152,7 +152,7 @@ func Test_caHandler_Revoke(t *testing.T) {
statusCode: http.StatusOK, statusCode: http.StatusOK,
tls: cs, tls: cs,
auth: &mockAuthority{ auth: &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) { authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return nil, nil return nil, nil
}, },
revoke: func(ctx context.Context, ri *authority.RevokeOptions) error { revoke: func(ctx context.Context, ri *authority.RevokeOptions) error {
@ -187,7 +187,7 @@ func Test_caHandler_Revoke(t *testing.T) {
input: string(input), input: string(input),
statusCode: http.StatusInternalServerError, statusCode: http.StatusInternalServerError,
auth: &mockAuthority{ auth: &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) { authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return nil, nil return nil, nil
}, },
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
@ -209,7 +209,7 @@ func Test_caHandler_Revoke(t *testing.T) {
input: string(input), input: string(input),
statusCode: http.StatusForbidden, statusCode: http.StatusForbidden,
auth: &mockAuthority{ auth: &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) { authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return nil, nil return nil, nil
}, },
revoke: func(ctx context.Context, opts *authority.RevokeOptions) error { revoke: func(ctx context.Context, opts *authority.RevokeOptions) error {
@ -223,13 +223,13 @@ func Test_caHandler_Revoke(t *testing.T) {
for name, _tc := range tests { for name, _tc := range tests {
tc := _tc(t) tc := _tc(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := New(tc.auth).(*caHandler) mockMustAuthority(t, tc.auth)
req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input)) req := httptest.NewRequest("POST", "http://example.com/revoke", strings.NewReader(tc.input))
if tc.tls != nil { if tc.tls != nil {
req.TLS = tc.tls req.TLS = tc.tls
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.Revoke(logging.NewResponseLogger(w), req) Revoke(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)

@ -49,7 +49,7 @@ type SignResponse struct {
// Sign is an HTTP handler that reads a certificate request and an // Sign is an HTTP handler that reads a certificate request and an
// one-time-token (ott) from the body and creates a new certificate with the // one-time-token (ott) from the body and creates a new certificate with the
// information in the certificate request. // information in the certificate request.
func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) { func Sign(w http.ResponseWriter, r *http.Request) {
var body SignRequest var body SignRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -68,13 +68,17 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
TemplateData: body.TemplateData, TemplateData: body.TemplateData,
} }
signOpts, err := h.Authority.AuthorizeSign(body.OTT) ctx := r.Context()
a := mustAuthority(ctx)
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
certChain, err := h.Authority.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...) certChain, err := a.Sign(body.CsrPEM.CertificateRequest, opts, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing certificate"))
return return
@ -89,6 +93,6 @@ func (h *caHandler) Sign(w http.ResponseWriter, r *http.Request) {
ServerPEM: certChainPEM[0], ServerPEM: certChainPEM[0],
CaPEM: caPEM, CaPEM: caPEM,
CertChainPEM: certChainPEM, CertChainPEM: certChainPEM,
TLSOptions: h.Authority.GetTLSOptions(), TLSOptions: a.GetTLSOptions(),
}, http.StatusCreated) }, http.StatusCreated)
} }

@ -250,7 +250,7 @@ type SSHBastionResponse struct {
// SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token // SSHSign is an HTTP handler that reads an SignSSHRequest with a one-time-token
// (ott) from the body and creates a new SSH certificate with the information in // (ott) from the body and creates a new SSH certificate with the information in
// the request. // the request.
func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) { func SSHSign(w http.ResponseWriter, r *http.Request) {
var body SSHSignRequest var body SSHSignRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -288,13 +288,16 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
} }
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHSignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) ctx = provisioner.NewContextWithToken(ctx, body.OTT)
a := mustAuthority(ctx)
signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
cert, err := h.Authority.SignSSH(ctx, publicKey, opts, signOpts...) cert, err := a.SignSSH(ctx, publicKey, opts, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return return
@ -302,7 +305,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
var addUserCertificate *SSHCertificate var addUserCertificate *SSHCertificate
if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil { if addUserPublicKey != nil && authority.IsValidForAddUser(cert) == nil {
addUserCert, err := h.Authority.SignSSHAddUser(ctx, addUserPublicKey, cert) addUserCert, err := a.SignSSHAddUser(ctx, addUserPublicKey, cert)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing ssh certificate"))
return return
@ -315,7 +318,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
if cr := body.IdentityCSR.CertificateRequest; cr != nil { if cr := body.IdentityCSR.CertificateRequest; cr != nil {
ctx := authority.NewContextWithSkipTokenReuse(r.Context()) ctx := authority.NewContextWithSkipTokenReuse(r.Context())
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
@ -327,7 +330,7 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
NotAfter: time.Unix(int64(cert.ValidBefore), 0), NotAfter: time.Unix(int64(cert.ValidBefore), 0),
}) })
certChain, err := h.Authority.Sign(cr, provisioner.SignOptions{}, signOpts...) certChain, err := a.Sign(cr, provisioner.SignOptions{}, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate")) render.Error(w, errs.ForbiddenErr(err, "error signing identity certificate"))
return return
@ -344,8 +347,9 @@ func (h *caHandler) SSHSign(w http.ResponseWriter, r *http.Request) {
// SSHRoots is an HTTP handler that returns the SSH public keys for user and host // SSHRoots is an HTTP handler that returns the SSH public keys for user and host
// certificates. // certificates.
func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) { func SSHRoots(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHRoots(r.Context()) ctx := r.Context()
keys, err := mustAuthority(ctx).GetSSHRoots(ctx)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
@ -369,8 +373,9 @@ func (h *caHandler) SSHRoots(w http.ResponseWriter, r *http.Request) {
// SSHFederation is an HTTP handler that returns the federated SSH public keys // SSHFederation is an HTTP handler that returns the federated SSH public keys
// for user and host certificates. // for user and host certificates.
func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) { func SSHFederation(w http.ResponseWriter, r *http.Request) {
keys, err := h.Authority.GetSSHFederation(r.Context()) ctx := r.Context()
keys, err := mustAuthority(ctx).GetSSHFederation(ctx)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
@ -394,7 +399,7 @@ func (h *caHandler) SSHFederation(w http.ResponseWriter, r *http.Request) {
// SSHConfig is an HTTP handler that returns rendered templates for ssh clients // SSHConfig is an HTTP handler that returns rendered templates for ssh clients
// and servers. // and servers.
func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) { func SSHConfig(w http.ResponseWriter, r *http.Request) {
var body SSHConfigRequest var body SSHConfigRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -405,7 +410,8 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
ts, err := h.Authority.GetSSHConfig(r.Context(), body.Type, body.Data) ctx := r.Context()
ts, err := mustAuthority(ctx).GetSSHConfig(ctx, body.Type, body.Data)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
@ -426,7 +432,7 @@ func (h *caHandler) SSHConfig(w http.ResponseWriter, r *http.Request) {
} }
// SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not. // SSHCheckHost is the HTTP handler that returns if a hosts certificate exists or not.
func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) { func SSHCheckHost(w http.ResponseWriter, r *http.Request) {
var body SSHCheckPrincipalRequest var body SSHCheckPrincipalRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -437,7 +443,8 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
return return
} }
exists, err := h.Authority.CheckSSHHost(r.Context(), body.Principal, body.Token) ctx := r.Context()
exists, err := mustAuthority(ctx).CheckSSHHost(ctx, body.Principal, body.Token)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
@ -448,13 +455,14 @@ func (h *caHandler) SSHCheckHost(w http.ResponseWriter, r *http.Request) {
} }
// SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts. // SSHGetHosts is the HTTP handler that returns a list of valid ssh hosts.
func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) { func SSHGetHosts(w http.ResponseWriter, r *http.Request) {
var cert *x509.Certificate var cert *x509.Certificate
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 { if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
cert = r.TLS.PeerCertificates[0] cert = r.TLS.PeerCertificates[0]
} }
hosts, err := h.Authority.GetSSHHosts(r.Context(), cert) ctx := r.Context()
hosts, err := mustAuthority(ctx).GetSSHHosts(ctx, cert)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
@ -465,7 +473,7 @@ func (h *caHandler) SSHGetHosts(w http.ResponseWriter, r *http.Request) {
} }
// SSHBastion provides returns the bastion configured if any. // SSHBastion provides returns the bastion configured if any.
func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) { func SSHBastion(w http.ResponseWriter, r *http.Request) {
var body SSHBastionRequest var body SSHBastionRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -476,7 +484,8 @@ func (h *caHandler) SSHBastion(w http.ResponseWriter, r *http.Request) {
return return
} }
bastion, err := h.Authority.GetSSHBastion(r.Context(), body.User, body.Hostname) ctx := r.Context()
bastion, err := mustAuthority(ctx).GetSSHBastion(ctx, body.User, body.Hostname)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return

@ -39,7 +39,7 @@ type SSHRekeyResponse struct {
// SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token // SSHRekey is an HTTP handler that reads an RekeySSHRequest with a one-time-token
// (ott) from the body and creates a new SSH certificate with the information in // (ott) from the body and creates a new SSH certificate with the information in
// the request. // the request.
func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) { func SSHRekey(w http.ResponseWriter, r *http.Request) {
var body SSHRekeyRequest var body SSHRekeyRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -59,7 +59,10 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
} }
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRekeyMethod)
signOpts, err := h.Authority.Authorize(ctx, body.OTT) ctx = provisioner.NewContextWithToken(ctx, body.OTT)
a := mustAuthority(ctx)
signOpts, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
@ -70,7 +73,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
return return
} }
newCert, err := h.Authority.RekeySSH(ctx, oldCert, publicKey, signOpts...) newCert, err := a.RekeySSH(ctx, oldCert, publicKey, signOpts...)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error rekeying ssh certificate"))
return return
@ -80,7 +83,7 @@ func (h *caHandler) SSHRekey(w http.ResponseWriter, r *http.Request) {
notBefore := time.Unix(int64(oldCert.ValidAfter), 0) notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
notAfter := time.Unix(int64(oldCert.ValidBefore), 0) notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) identity, err := renewIdentityCertificate(r, notBefore, notAfter)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return return

@ -37,7 +37,7 @@ type SSHRenewResponse struct {
// SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token // SSHRenew is an HTTP handler that reads an RenewSSHRequest with a one-time-token
// (ott) from the body and creates a new SSH certificate with the information in // (ott) from the body and creates a new SSH certificate with the information in
// the request. // the request.
func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) { func SSHRenew(w http.ResponseWriter, r *http.Request) {
var body SSHRenewRequest var body SSHRenewRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -51,7 +51,10 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
} }
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRenewMethod)
_, err := h.Authority.Authorize(ctx, body.OTT) ctx = provisioner.NewContextWithToken(ctx, body.OTT)
a := mustAuthority(ctx)
_, err := a.Authorize(ctx, body.OTT)
if err != nil { if err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
@ -62,7 +65,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
return return
} }
newCert, err := h.Authority.RenewSSH(ctx, oldCert) newCert, err := a.RenewSSH(ctx, oldCert)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error renewing ssh certificate"))
return return
@ -72,7 +75,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
notBefore := time.Unix(int64(oldCert.ValidAfter), 0) notBefore := time.Unix(int64(oldCert.ValidAfter), 0)
notAfter := time.Unix(int64(oldCert.ValidBefore), 0) notAfter := time.Unix(int64(oldCert.ValidBefore), 0)
identity, err := h.renewIdentityCertificate(r, notBefore, notAfter) identity, err := renewIdentityCertificate(r, notBefore, notAfter)
if err != nil { if err != nil {
render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate")) render.Error(w, errs.ForbiddenErr(err, "error renewing identity certificate"))
return return
@ -85,7 +88,7 @@ func (h *caHandler) SSHRenew(w http.ResponseWriter, r *http.Request) {
} }
// renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the // renewIdentityCertificate request the client TLS certificate if present. If notBefore and notAfter are passed the
func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) { func renewIdentityCertificate(r *http.Request, notBefore, notAfter time.Time) ([]Certificate, error) {
if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 { if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
return nil, nil return nil, nil
} }
@ -105,7 +108,7 @@ func (h *caHandler) renewIdentityCertificate(r *http.Request, notBefore, notAfte
cert.NotAfter = notAfter cert.NotAfter = notAfter
} }
certChain, err := h.Authority.Renew(cert) certChain, err := mustAuthority(r.Context()).Renew(cert)
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -48,7 +48,7 @@ func (r *SSHRevokeRequest) Validate() (err error) {
// Revoke supports handful of different methods that revoke a Certificate. // Revoke supports handful of different methods that revoke a Certificate.
// //
// NOTE: currently only Passive revocation is supported. // NOTE: currently only Passive revocation is supported.
func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) { func SSHRevoke(w http.ResponseWriter, r *http.Request) {
var body SSHRevokeRequest var body SSHRevokeRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, errs.BadRequestErr(err, "error reading request body")) render.Error(w, errs.BadRequestErr(err, "error reading request body"))
@ -68,16 +68,19 @@ func (h *caHandler) SSHRevoke(w http.ResponseWriter, r *http.Request) {
} }
ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod) ctx := provisioner.NewContextWithMethod(r.Context(), provisioner.SSHRevokeMethod)
a := mustAuthority(ctx)
// A token indicates that we are using the api via a provisioner token, // A token indicates that we are using the api via a provisioner token,
// otherwise it is assumed that the certificate is revoking itself over mTLS. // otherwise it is assumed that the certificate is revoking itself over mTLS.
logOtt(w, body.OTT) logOtt(w, body.OTT)
if _, err := h.Authority.Authorize(ctx, body.OTT); err != nil {
if _, err := a.Authorize(ctx, body.OTT); err != nil {
render.Error(w, errs.UnauthorizedErr(err)) render.Error(w, errs.UnauthorizedErr(err))
return return
} }
opts.OTT = body.OTT opts.OTT = body.OTT
if err := h.Authority.Revoke(ctx, opts); err != nil { if err := a.Revoke(ctx, opts); err != nil {
render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate")) render.Error(w, errs.ForbiddenErr(err, "error revoking ssh certificate"))
return return
} }

@ -251,7 +251,7 @@ func TestSignSSHRequest_Validate(t *testing.T) {
} }
} }
func Test_caHandler_SSHSign(t *testing.T) { func Test_SSHSign(t *testing.T) {
user, err := getSignedUserCertificate() user, err := getSignedUserCertificate()
assert.FatalError(t, err) assert.FatalError(t, err)
host, err := getSignedHostCertificate() host, err := getSignedHostCertificate()
@ -315,8 +315,8 @@ func Test_caHandler_SSHSign(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
authorizeSign: func(ott string) ([]provisioner.SignOption, error) { authorize: func(ctx context.Context, ott string) ([]provisioner.SignOption, error) {
return []provisioner.SignOption{}, tt.authErr return []provisioner.SignOption{}, tt.authErr
}, },
signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) { signSSH: func(ctx context.Context, key ssh.PublicKey, opts provisioner.SignSSHOptions, signOpts ...provisioner.SignOption) (*ssh.Certificate, error) {
@ -328,11 +328,11 @@ func Test_caHandler_SSHSign(t *testing.T) {
sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) { sign: func(cr *x509.CertificateRequest, opts provisioner.SignOptions, signOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
return tt.tlsSignCerts, tt.tlsSignErr return tt.tlsSignCerts, tt.tlsSignErr
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req)) req := httptest.NewRequest("POST", "http://example.com/ssh/sign", bytes.NewReader(tt.req))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHSign(logging.NewResponseLogger(w), req) SSHSign(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -353,7 +353,7 @@ func Test_caHandler_SSHSign(t *testing.T) {
} }
} }
func Test_caHandler_SSHRoots(t *testing.T) { func Test_SSHRoots(t *testing.T) {
user, err := ssh.NewPublicKey(sshUserKey.Public()) user, err := ssh.NewPublicKey(sshUserKey.Public())
assert.FatalError(t, err) assert.FatalError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
@ -378,15 +378,15 @@ func Test_caHandler_SSHRoots(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) { getSSHRoots: func(ctx context.Context) (*authority.SSHKeys, error) {
return tt.keys, tt.keysErr return tt.keys, tt.keysErr
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody) req := httptest.NewRequest("GET", "http://example.com/ssh/roots", http.NoBody)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHRoots(logging.NewResponseLogger(w), req) SSHRoots(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -407,7 +407,7 @@ func Test_caHandler_SSHRoots(t *testing.T) {
} }
} }
func Test_caHandler_SSHFederation(t *testing.T) { func Test_SSHFederation(t *testing.T) {
user, err := ssh.NewPublicKey(sshUserKey.Public()) user, err := ssh.NewPublicKey(sshUserKey.Public())
assert.FatalError(t, err) assert.FatalError(t, err)
userB64 := base64.StdEncoding.EncodeToString(user.Marshal()) userB64 := base64.StdEncoding.EncodeToString(user.Marshal())
@ -432,15 +432,15 @@ func Test_caHandler_SSHFederation(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) { getSSHFederation: func(ctx context.Context) (*authority.SSHKeys, error) {
return tt.keys, tt.keysErr return tt.keys, tt.keysErr
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody) req := httptest.NewRequest("GET", "http://example.com/ssh/federation", http.NoBody)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHFederation(logging.NewResponseLogger(w), req) SSHFederation(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -461,7 +461,7 @@ func Test_caHandler_SSHFederation(t *testing.T) {
} }
} }
func Test_caHandler_SSHConfig(t *testing.T) { func Test_SSHConfig(t *testing.T) {
userOutput := []templates.Output{ userOutput := []templates.Output{
{Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")}, {Name: "config.tpl", Type: templates.File, Comment: "#", Path: "ssh/config", Content: []byte("UserKnownHostsFile /home/user/.step/ssh/known_hosts")},
{Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")}, {Name: "known_host.tpl", Type: templates.File, Comment: "#", Path: "ssh/known_host", Content: []byte("@cert-authority * ecdsa-sha2-nistp256 AAAA...=")},
@ -492,15 +492,15 @@ func Test_caHandler_SSHConfig(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) { getSSHConfig: func(ctx context.Context, typ string, data map[string]string) ([]templates.Output, error) {
return tt.output, tt.err return tt.output, tt.err
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req)) req := httptest.NewRequest("GET", "http://example.com/ssh/config", strings.NewReader(tt.req))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHConfig(logging.NewResponseLogger(w), req) SSHConfig(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -521,7 +521,7 @@ func Test_caHandler_SSHConfig(t *testing.T) {
} }
} }
func Test_caHandler_SSHCheckHost(t *testing.T) { func Test_SSHCheckHost(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
req string req string
@ -539,15 +539,15 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) { checkSSHHost: func(ctx context.Context, principal, token string) (bool, error) {
return tt.exists, tt.err return tt.exists, tt.err
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req)) req := httptest.NewRequest("GET", "http://example.com/ssh/check-host", strings.NewReader(tt.req))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHCheckHost(logging.NewResponseLogger(w), req) SSHCheckHost(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -568,7 +568,7 @@ func Test_caHandler_SSHCheckHost(t *testing.T) {
} }
} }
func Test_caHandler_SSHGetHosts(t *testing.T) { func Test_SSHGetHosts(t *testing.T) {
hosts := []authority.Host{ hosts := []authority.Host{
{HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"}, {HostID: "1", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}}, Hostname: "host1"},
{HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"}, {HostID: "2", HostTags: []authority.HostTag{{ID: "1", Name: "group", Value: "1"}, {ID: "2", Name: "group", Value: "2"}}, Hostname: "host2"},
@ -590,15 +590,15 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) { getSSHHosts: func(context.Context, *x509.Certificate) ([]authority.Host, error) {
return tt.hosts, tt.err return tt.hosts, tt.err
}, },
}).(*caHandler) })
req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody) req := httptest.NewRequest("GET", "http://example.com/ssh/host", http.NoBody)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHGetHosts(logging.NewResponseLogger(w), req) SSHGetHosts(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {
@ -619,7 +619,7 @@ func Test_caHandler_SSHGetHosts(t *testing.T) {
} }
} }
func Test_caHandler_SSHBastion(t *testing.T) { func Test_SSHBastion(t *testing.T) {
bastion := &authority.Bastion{ bastion := &authority.Bastion{
Hostname: "bastion.local", Hostname: "bastion.local",
} }
@ -645,15 +645,15 @@ func Test_caHandler_SSHBastion(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
h := New(&mockAuthority{ mockMustAuthority(t, &mockAuthority{
getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) { getSSHBastion: func(ctx context.Context, user, hostname string) (*authority.Bastion, error) {
return tt.bastion, tt.bastionErr return tt.bastion, tt.bastionErr
}, },
}).(*caHandler) })
req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req)) req := httptest.NewRequest("POST", "http://example.com/ssh/bastion", bytes.NewReader(tt.req))
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.SSHBastion(logging.NewResponseLogger(w), req) SSHBastion(logging.NewResponseLogger(w), req)
res := w.Result() res := w.Result()
if res.StatusCode != tt.statusCode { if res.StatusCode != tt.statusCode {

@ -1,22 +1,15 @@
package api package api
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"github.com/go-chi/chi"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
)
const (
// provisionerContextKey provisioner key
provisionerContextKey = ContextKey("provisioner")
) )
// CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests // CreateExternalAccountKeyRequest is the type for POST /admin/acme/eab requests
@ -40,78 +33,121 @@ type GetExternalAccountKeysResponse struct {
// requireEABEnabled is a middleware that ensures ACME EAB is enabled // requireEABEnabled is a middleware that ensures ACME EAB is enabled
// before serving requests that act on ACME EAB credentials. // before serving requests that act on ACME EAB credentials.
func (h *Handler) requireEABEnabled(next nextHTTP) nextHTTP { func requireEABEnabled(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
provName := chi.URLParam(r, "provisionerName") prov := linkedca.MustProvisionerFromContext(ctx)
eabEnabled, prov, err := h.provisionerHasEABEnabled(ctx, provName)
if err != nil { acmeProvisioner := prov.GetDetails().GetACME()
render.Error(w, err) if acmeProvisioner == nil {
render.Error(w, admin.NewErrorISE("error getting ACME details for provisioner '%s'", prov.GetName()))
return return
} }
if !eabEnabled {
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner %s", prov.GetName())) if !acmeProvisioner.RequireEab {
render.Error(w, admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner '%s'", prov.GetName()))
return return
} }
ctx = context.WithValue(ctx, provisionerContextKey, prov)
next(w, r.WithContext(ctx))
}
}
// provisionerHasEABEnabled determines if the "requireEAB" setting for an ACME
// provisioner is set to true and thus has EAB enabled.
func (h *Handler) provisionerHasEABEnabled(ctx context.Context, provisionerName string) (bool, *linkedca.Provisioner, error) {
var (
p provisioner.Interface
err error
)
if p, err = h.auth.LoadProvisionerByName(provisionerName); err != nil {
return false, nil, admin.WrapErrorISE(err, "error loading provisioner %s", provisionerName)
}
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) next(w, r)
if err != nil {
return false, nil, admin.WrapErrorISE(err, "error getting provisioner with ID: %s", p.GetID())
} }
details := prov.GetDetails()
if details == nil {
return false, nil, admin.NewErrorISE("error getting details for provisioner with ID: %s", p.GetID())
}
acmeProvisioner := details.GetACME()
if acmeProvisioner == nil {
return false, nil, admin.NewErrorISE("error getting ACME details for provisioner with ID: %s", p.GetID())
}
return acmeProvisioner.GetRequireEab(), prov, nil
} }
type acmeAdminResponderInterface interface { // ACMEAdminResponder is responsible for writing ACME admin responses
type ACMEAdminResponder interface {
GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request)
CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request)
DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request)
} }
// ACMEAdminResponder is responsible for writing ACME admin responses // acmeAdminResponder implements ACMEAdminResponder.
type ACMEAdminResponder struct{} type acmeAdminResponder struct{}
// NewACMEAdminResponder returns a new ACMEAdminResponder // NewACMEAdminResponder returns a new ACMEAdminResponder
func NewACMEAdminResponder() *ACMEAdminResponder { func NewACMEAdminResponder() ACMEAdminResponder {
return &ACMEAdminResponder{} return &acmeAdminResponder{}
} }
// GetExternalAccountKeys writes the response for the EAB keys GET endpoint // GetExternalAccountKeys writes the response for the EAB keys GET endpoint
func (h *ACMEAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) { func (h *acmeAdminResponder) GetExternalAccountKeys(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }
// CreateExternalAccountKey writes the response for the EAB key POST endpoint // CreateExternalAccountKey writes the response for the EAB key POST endpoint
func (h *ACMEAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) { func (h *acmeAdminResponder) CreateExternalAccountKey(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }
// DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint // DeleteExternalAccountKey writes the response for the EAB key DELETE endpoint
func (h *ACMEAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) { func (h *acmeAdminResponder) DeleteExternalAccountKey(w http.ResponseWriter, r *http.Request) {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm")) render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "this functionality is currently only available in Certificate Manager: https://u.step.sm/cm"))
} }
func eakToLinked(k *acme.ExternalAccountKey) *linkedca.EABKey {
if k == nil {
return nil
}
eak := &linkedca.EABKey{
Id: k.ID,
HmacKey: k.HmacKey,
Provisioner: k.ProvisionerID,
Reference: k.Reference,
Account: k.AccountID,
CreatedAt: timestamppb.New(k.CreatedAt),
BoundAt: timestamppb.New(k.BoundAt),
}
if k.Policy != nil {
eak.Policy = &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{},
Deny: &linkedca.X509Names{},
},
}
eak.Policy.X509.Allow.Dns = k.Policy.X509.Allowed.DNSNames
eak.Policy.X509.Allow.Ips = k.Policy.X509.Allowed.IPRanges
eak.Policy.X509.Deny.Dns = k.Policy.X509.Denied.DNSNames
eak.Policy.X509.Deny.Ips = k.Policy.X509.Denied.IPRanges
eak.Policy.X509.AllowWildcardNames = k.Policy.X509.AllowWildcardNames
}
return eak
}
func linkedEAKToCertificates(k *linkedca.EABKey) *acme.ExternalAccountKey {
if k == nil {
return nil
}
eak := &acme.ExternalAccountKey{
ID: k.Id,
ProvisionerID: k.Provisioner,
Reference: k.Reference,
AccountID: k.Account,
HmacKey: k.HmacKey,
CreatedAt: k.CreatedAt.AsTime(),
BoundAt: k.BoundAt.AsTime(),
}
if policy := k.GetPolicy(); policy != nil {
eak.Policy = &acme.Policy{}
if x509 := policy.GetX509(); x509 != nil {
eak.Policy.X509 = acme.X509Policy{}
if allow := x509.GetAllow(); allow != nil {
eak.Policy.X509.Allowed = acme.PolicyNames{}
eak.Policy.X509.Allowed.DNSNames = allow.Dns
eak.Policy.X509.Allowed.IPRanges = allow.Ips
}
if deny := x509.GetDeny(); deny != nil {
eak.Policy.X509.Denied = acme.PolicyNames{}
eak.Policy.X509.Denied.DNSNames = deny.Dns
eak.Policy.X509.Denied.IPRanges = deny.Ips
}
eak.Policy.X509.AllowWildcardNames = x509.AllowWildcardNames
}
}
return eak
}

@ -4,20 +4,24 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"strings" "strings"
"testing" "testing"
"time"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/admin"
) )
func readProtoJSON(r io.ReadCloser, m proto.Message) error { func readProtoJSON(r io.ReadCloser, m proto.Message) error {
@ -29,109 +33,90 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error {
return protojson.Unmarshal(data, m) return protojson.Unmarshal(data, m)
} }
func mockMustAuthority(t *testing.T, a adminAuthority) {
t.Helper()
fn := mustAuthority
t.Cleanup(func() {
mustAuthority = fn
})
mustAuthority = func(ctx context.Context) adminAuthority {
return a
}
}
func TestHandler_requireEABEnabled(t *testing.T) { func TestHandler_requireEABEnabled(t *testing.T) {
type test struct { type test struct {
ctx context.Context ctx context.Context
adminDB admin.DB next http.HandlerFunc
auth adminAuthority
next nextHTTP
err *admin.Error err *admin.Error
statusCode int statusCode int
} }
var tests = map[string]func(t *testing.T) test{ var tests = map[string]func(t *testing.T) test{
"fail/h.provisionerHasEABEnabled": func(t *testing.T) test { "fail/prov.GetDetails": func(t *testing.T) test {
chiCtx := chi.NewRouteContext() prov := &linkedca.Provisioner{
chiCtx.URLParams.Add("provisionerName", "provName") Id: "provID",
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) Name: "provName",
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return nil, errors.New("force")
},
} }
err := admin.NewErrorISE("error loading provisioner provName: force") ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
err.Message = "error loading provisioner provName: force" err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'")
err.Message = "error getting ACME details for provisioner 'provName'"
return test{ return test{
ctx: ctx, ctx: ctx,
auth: auth,
err: err, err: err,
statusCode: 500, statusCode: 500,
} }
}, },
"ok/eab-disabled": func(t *testing.T) test { "fail/prov.GetDetails.GetACME": func(t *testing.T) test {
chiCtx := chi.NewRouteContext() prov := &linkedca.Provisioner{
chiCtx.URLParams.Add("provisionerName", "provName") Id: "provID",
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) Name: "provName",
auth := &mockAdminAuthority{ Details: &linkedca.ProvisionerDetails{},
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { }
assert.Equals(t, "provName", name) ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
return &provisioner.MockProvisioner{ err := admin.NewErrorISE("error getting ACME details for provisioner 'provName'")
MgetID: func() string { err.Message = "error getting ACME details for provisioner 'provName'"
return "provID" return test{
}, ctx: ctx,
}, nil err: err,
}, statusCode: 500,
} }
db := &admin.MockDB{ },
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) { "ok/eab-disabled": func(t *testing.T) test {
assert.Equals(t, "provID", id) prov := &linkedca.Provisioner{
return &linkedca.Provisioner{ Id: "provID",
Id: "provID", Name: "provName",
Name: "provName", Details: &linkedca.ProvisionerDetails{
Details: &linkedca.ProvisionerDetails{ Data: &linkedca.ProvisionerDetails_ACME{
Data: &linkedca.ProvisionerDetails_ACME{ ACME: &linkedca.ACMEProvisioner{
ACME: &linkedca.ACMEProvisioner{ RequireEab: false,
RequireEab: false,
},
},
}, },
}, nil },
}, },
} }
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName") err := admin.NewError(admin.ErrorBadRequestType, "ACME EAB not enabled for provisioner provName")
err.Message = "ACME EAB not enabled for provisioner provName" err.Message = "ACME EAB not enabled for provisioner 'provName'"
return test{ return test{
ctx: ctx, ctx: ctx,
auth: auth,
adminDB: db,
err: err, err: err,
statusCode: 400, statusCode: 400,
} }
}, },
"ok/eab-enabled": func(t *testing.T) test { "ok/eab-enabled": func(t *testing.T) test {
chiCtx := chi.NewRouteContext() prov := &linkedca.Provisioner{
chiCtx.URLParams.Add("provisionerName", "provName") Id: "provID",
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx) Name: "provName",
auth := &mockAdminAuthority{ Details: &linkedca.ProvisionerDetails{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) { Data: &linkedca.ProvisionerDetails_ACME{
assert.Equals(t, "provName", name) ACME: &linkedca.ACMEProvisioner{
return &provisioner.MockProvisioner{ RequireEab: true,
MgetID: func() string {
return "provID"
}, },
}, nil },
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: true,
},
},
},
}, nil
}, },
} }
ctx := linkedca.NewContextWithProvisioner(context.Background(), prov)
return test{ return test{
ctx: ctx, ctx: ctx,
auth: auth,
adminDB: db,
next: func(w http.ResponseWriter, r *http.Request) { next: func(w http.ResponseWriter, r *http.Request) {
w.Write(nil) // mock response with status 200 w.Write(nil) // mock response with status 200
}, },
@ -143,16 +128,9 @@ func TestHandler_requireEABEnabled(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ req := httptest.NewRequest("GET", "/foo", nil).WithContext(tc.ctx)
auth: tc.auth,
adminDB: tc.adminDB,
acmeDB: nil,
}
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.requireEABEnabled(tc.next)(w, req) requireEABEnabled(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -176,216 +154,6 @@ func TestHandler_requireEABEnabled(t *testing.T) {
} }
} }
func TestHandler_provisionerHasEABEnabled(t *testing.T) {
type test struct {
adminDB admin.DB
auth adminAuthority
provisionerName string
want bool
err *admin.Error
}
var tests = map[string]func(t *testing.T) test{
"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return nil, errors.New("force")
},
}
return test{
auth: auth,
provisionerName: "provName",
want: false,
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
}
},
"fail/db.GetProvisioner": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return nil, errors.New("force")
},
}
return test{
auth: auth,
adminDB: db,
provisionerName: "provName",
want: false,
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
}
},
"fail/prov.GetDetails": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: nil,
}, nil
},
}
return test{
auth: auth,
adminDB: db,
provisionerName: "provName",
want: false,
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
}
},
"fail/details.GetACME": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "provName",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: nil,
},
},
}, nil
},
}
return test{
auth: auth,
adminDB: db,
provisionerName: "provName",
want: false,
err: admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName"),
}
},
"ok/eab-disabled": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "eab-disabled", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "eab-disabled",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: false,
},
},
},
}, nil
},
}
return test{
adminDB: db,
auth: auth,
provisionerName: "eab-disabled",
want: false,
}
},
"ok/eab-enabled": func(t *testing.T) test {
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "eab-enabled", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "eab-enabled",
Details: &linkedca.ProvisionerDetails{
Data: &linkedca.ProvisionerDetails_ACME{
ACME: &linkedca.ACMEProvisioner{
RequireEab: true,
},
},
},
}, nil
},
}
return test{
adminDB: db,
auth: auth,
provisionerName: "eab-enabled",
want: true,
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
h := &Handler{
auth: tc.auth,
adminDB: tc.adminDB,
acmeDB: nil,
}
got, prov, err := h.provisionerHasEABEnabled(context.TODO(), tc.provisionerName)
if (err != nil) != (tc.err != nil) {
t.Errorf("Handler.provisionerHasEABEnabled() error = %v, want err %v", err, tc.err)
return
}
if tc.err != nil {
assert.Type(t, &linkedca.Provisioner{}, prov)
assert.Type(t, &admin.Error{}, err)
adminError, _ := err.(*admin.Error)
assert.Equals(t, tc.err.Type, adminError.Type)
assert.Equals(t, tc.err.Status, adminError.Status)
assert.Equals(t, tc.err.StatusCode(), adminError.StatusCode())
assert.Equals(t, tc.err.Message, adminError.Message)
assert.Equals(t, tc.err.Detail, adminError.Detail)
return
}
if got != tc.want {
t.Errorf("Handler.provisionerHasEABEnabled() = %v, want %v", got, tc.want)
}
})
}
}
func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) { func TestCreateExternalAccountKeyRequest_Validate(t *testing.T) {
type fields struct { type fields struct {
Reference string Reference string
@ -585,3 +353,206 @@ func TestHandler_GetExternalAccountKeys(t *testing.T) {
}) })
} }
} }
func Test_eakToLinked(t *testing.T) {
tests := []struct {
name string
k *acme.ExternalAccountKey
want *linkedca.EABKey
}{
{
name: "no-key",
k: nil,
want: nil,
},
{
name: "no-policy",
k: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: nil,
},
want: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: nil,
},
},
{
name: "with-policy",
k: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.local"},
IPRanges: []string{"10.0.0.0/24"},
},
Denied: acme.PolicyNames{
DNSNames: []string{"badhost.local"},
IPRanges: []string{"10.0.0.30"},
},
},
},
},
want: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
Ips: []string{"10.0.0.0/24"},
},
Deny: &linkedca.X509Names{
Dns: []string{"badhost.local"},
Ips: []string{"10.0.0.30"},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := eakToLinked(tt.k); !reflect.DeepEqual(got, tt.want) {
t.Errorf("eakToLinked() = %v, want %v", got, tt.want)
}
})
}
}
func Test_linkedEAKToCertificates(t *testing.T) {
tests := []struct {
name string
k *linkedca.EABKey
want *acme.ExternalAccountKey
}{
{
name: "no-key",
k: nil,
want: nil,
},
{
name: "no-policy",
k: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: nil,
},
want: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: nil,
},
},
{
name: "no-x509-policy",
k: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: &linkedca.Policy{},
},
want: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: &acme.Policy{},
},
},
{
name: "with-x509-policy",
k: &linkedca.EABKey{
Id: "keyID",
Provisioner: "provID",
HmacKey: []byte{1, 3, 3, 7},
Reference: "ref",
Account: "accID",
CreatedAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour)),
BoundAt: timestamppb.New(time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC)),
Policy: &linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
Ips: []string{"10.0.0.0/24"},
},
Deny: &linkedca.X509Names{
Dns: []string{"badhost.local"},
Ips: []string{"10.0.0.30"},
},
AllowWildcardNames: true,
},
},
},
want: &acme.ExternalAccountKey{
ID: "keyID",
ProvisionerID: "provID",
Reference: "ref",
AccountID: "accID",
HmacKey: []byte{1, 3, 3, 7},
CreatedAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC).Add(-1 * time.Hour),
BoundAt: time.Date(2022, 04, 12, 9, 30, 30, 0, time.UTC),
Policy: &acme.Policy{
X509: acme.X509Policy{
Allowed: acme.PolicyNames{
DNSNames: []string{"*.local"},
IPRanges: []string{"10.0.0.0/24"},
},
Denied: acme.PolicyNames{
DNSNames: []string{"badhost.local"},
IPRanges: []string{"10.0.0.30"},
},
AllowWildcardNames: true,
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := linkedEAKToCertificates(tt.k); !reflect.DeepEqual(got, tt.want) {
t.Errorf("linkedEAKToCertificates() = %v, want %v", got, tt.want)
}
})
}
}

@ -29,6 +29,10 @@ type adminAuthority interface {
LoadProvisionerByID(id string) (provisioner.Interface, error) LoadProvisionerByID(id string) (provisioner.Interface, error)
UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error UpdateProvisioner(ctx context.Context, nu *linkedca.Provisioner) error
RemoveProvisioner(ctx context.Context, id string) error RemoveProvisioner(ctx context.Context, id string) error
GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error)
CreateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
UpdateAuthorityPolicy(ctx context.Context, admin *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
RemoveAuthorityPolicy(ctx context.Context) error
} }
// CreateAdminRequest represents the body for a CreateAdmin request. // CreateAdminRequest represents the body for a CreateAdmin request.
@ -81,10 +85,10 @@ type DeleteResponse struct {
} }
// GetAdmin returns the requested admin, or an error. // GetAdmin returns the requested admin, or an error.
func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) { func GetAdmin(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
adm, ok := h.auth.LoadAdminByID(id) adm, ok := mustAuthority(r.Context()).LoadAdminByID(id)
if !ok { if !ok {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, render.Error(w, admin.NewError(admin.ErrorNotFoundType,
"admin %s not found", id)) "admin %s not found", id))
@ -94,7 +98,7 @@ func (h *Handler) GetAdmin(w http.ResponseWriter, r *http.Request) {
} }
// GetAdmins returns a segment of admins associated with the authority. // GetAdmins returns a segment of admins associated with the authority.
func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) { func GetAdmins(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r) cursor, limit, err := api.ParseCursor(r)
if err != nil { if err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
@ -102,7 +106,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
return return
} }
admins, nextCursor, err := h.auth.GetAdmins(cursor, limit) admins, nextCursor, err := mustAuthority(r.Context()).GetAdmins(cursor, limit)
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins")) render.Error(w, admin.WrapErrorISE(err, "error retrieving paginated admins"))
return return
@ -114,7 +118,7 @@ func (h *Handler) GetAdmins(w http.ResponseWriter, r *http.Request) {
} }
// CreateAdmin creates a new admin. // CreateAdmin creates a new admin.
func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) { func CreateAdmin(w http.ResponseWriter, r *http.Request) {
var body CreateAdminRequest var body CreateAdminRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
@ -126,7 +130,8 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
return return
} }
p, err := h.auth.LoadProvisionerByName(body.Provisioner) auth := mustAuthority(r.Context())
p, err := auth.LoadProvisionerByName(body.Provisioner)
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", body.Provisioner))
return return
@ -137,7 +142,7 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
Type: body.Type, Type: body.Type,
} }
// Store to authority collection. // Store to authority collection.
if err := h.auth.StoreAdmin(r.Context(), adm, p); err != nil { if err := auth.StoreAdmin(r.Context(), adm, p); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error storing admin")) render.Error(w, admin.WrapErrorISE(err, "error storing admin"))
return return
} }
@ -146,10 +151,10 @@ func (h *Handler) CreateAdmin(w http.ResponseWriter, r *http.Request) {
} }
// DeleteAdmin deletes admin. // DeleteAdmin deletes admin.
func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) { func DeleteAdmin(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
if err := h.auth.RemoveAdmin(r.Context(), id); err != nil { if err := mustAuthority(r.Context()).RemoveAdmin(r.Context(), id); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id)) render.Error(w, admin.WrapErrorISE(err, "error deleting admin %s", id))
return return
} }
@ -158,7 +163,7 @@ func (h *Handler) DeleteAdmin(w http.ResponseWriter, r *http.Request) {
} }
// UpdateAdmin updates an existing admin. // UpdateAdmin updates an existing admin.
func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) { func UpdateAdmin(w http.ResponseWriter, r *http.Request) {
var body UpdateAdminRequest var body UpdateAdminRequest
if err := read.JSON(r.Body, &body); err != nil { if err := read.JSON(r.Body, &body); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body")) render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error reading request body"))
@ -171,8 +176,8 @@ func (h *Handler) UpdateAdmin(w http.ResponseWriter, r *http.Request) {
} }
id := chi.URLParam(r, "id") id := chi.URLParam(r, "id")
auth := mustAuthority(r.Context())
adm, err := h.auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type}) adm, err := auth.UpdateAdmin(r.Context(), id, &linkedca.Admin{Type: body.Type})
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id)) render.Error(w, admin.WrapErrorISE(err, "error updating admin %s", id))
return return

@ -14,11 +14,13 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
"google.golang.org/protobuf/types/known/timestamppb"
) )
type mockAdminAuthority struct { type mockAdminAuthority struct {
@ -37,6 +39,11 @@ type mockAdminAuthority struct {
MockLoadProvisionerByID func(id string) (provisioner.Interface, error) MockLoadProvisionerByID func(id string) (provisioner.Interface, error)
MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error MockUpdateProvisioner func(ctx context.Context, nu *linkedca.Provisioner) error
MockRemoveProvisioner func(ctx context.Context, id string) error MockRemoveProvisioner func(ctx context.Context, id string) error
MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error)
MockCreateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
MockUpdateAuthorityPolicy func(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error)
MockRemoveAuthorityPolicy func(ctx context.Context) error
} }
func (m *mockAdminAuthority) IsAdminAPIEnabled() bool { func (m *mockAdminAuthority) IsAdminAPIEnabled() bool {
@ -130,6 +137,34 @@ func (m *mockAdminAuthority) RemoveProvisioner(ctx context.Context, id string) e
return m.MockErr return m.MockErr
} }
func (m *mockAdminAuthority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
if m.MockGetAuthorityPolicy != nil {
return m.MockGetAuthorityPolicy(ctx)
}
return m.MockRet1.(*linkedca.Policy), m.MockErr
}
func (m *mockAdminAuthority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
if m.MockCreateAuthorityPolicy != nil {
return m.MockCreateAuthorityPolicy(ctx, adm, policy)
}
return m.MockRet1.(*linkedca.Policy), m.MockErr
}
func (m *mockAdminAuthority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, policy *linkedca.Policy) (*linkedca.Policy, error) {
if m.MockUpdateAuthorityPolicy != nil {
return m.MockUpdateAuthorityPolicy(ctx, adm, policy)
}
return m.MockRet1.(*linkedca.Policy), m.MockErr
}
func (m *mockAdminAuthority) RemoveAuthorityPolicy(ctx context.Context) error {
if m.MockRemoveAuthorityPolicy != nil {
return m.MockRemoveAuthorityPolicy(ctx)
}
return m.MockErr
}
func TestCreateAdminRequest_Validate(t *testing.T) { func TestCreateAdminRequest_Validate(t *testing.T) {
type fields struct { type fields struct {
Subject string Subject string
@ -317,14 +352,11 @@ func TestHandler_GetAdmin(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetAdmin(w, req) GetAdmin(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -456,13 +488,10 @@ func TestHandler_GetAdmins(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := tc.req.WithContext(tc.ctx) req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetAdmins(w, req) GetAdmins(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -640,13 +669,11 @@ func TestHandler_CreateAdmin(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.CreateAdmin(w, req) CreateAdmin(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -732,13 +759,11 @@ func TestHandler_DeleteAdmin(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup req := httptest.NewRequest("DELETE", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.DeleteAdmin(w, req) DeleteAdmin(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -877,13 +902,11 @@ func TestHandler_UpdateAdmin(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("GET", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.UpdateAdmin(w, req) UpdateAdmin(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)

@ -1,56 +1,117 @@
package api package api
import ( import (
"context"
"net/http"
"github.com/smallstep/certificates/acme" "github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api" "github.com/smallstep/certificates/api"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
) )
// Handler is the Admin API request handler. // Handler is the Admin API request handler.
type Handler struct { type Handler struct {
adminDB admin.DB acmeResponder ACMEAdminResponder
auth adminAuthority policyResponder PolicyAdminResponder
acmeDB acme.DB }
acmeResponder acmeAdminResponderInterface
// Route traffic and implement the Router interface.
//
// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder)
func (h *Handler) Route(r api.Router) {
Route(r, h.acmeResponder, h.policyResponder)
} }
// NewHandler returns a new Authority Config Handler. // NewHandler returns a new Authority Config Handler.
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder acmeAdminResponderInterface) api.RouterHandler { //
// Deprecated: use Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder)
func NewHandler(auth adminAuthority, adminDB admin.DB, acmeDB acme.DB, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) api.RouterHandler {
return &Handler{ return &Handler{
auth: auth, acmeResponder: acmeResponder,
adminDB: adminDB, policyResponder: policyResponder,
acmeDB: acmeDB,
acmeResponder: acmeResponder,
} }
} }
var mustAuthority = func(ctx context.Context) adminAuthority {
return authority.MustFromContext(ctx)
}
// Route traffic and implement the Router interface. // Route traffic and implement the Router interface.
func (h *Handler) Route(r api.Router) { func Route(r api.Router, acmeResponder ACMEAdminResponder, policyResponder PolicyAdminResponder) {
authnz := func(next nextHTTP) nextHTTP { authnz := func(next http.HandlerFunc) http.HandlerFunc {
return h.extractAuthorizeTokenAdmin(h.requireAPIEnabled(next)) return extractAuthorizeTokenAdmin(requireAPIEnabled(next))
}
enabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
return checkAction(next, true)
}
disabledInStandalone := func(next http.HandlerFunc) http.HandlerFunc {
return checkAction(next, false)
} }
requireEABEnabled := func(next nextHTTP) nextHTTP { acmeEABMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return h.requireEABEnabled(next) return authnz(loadProvisionerByName(requireEABEnabled(next)))
}
authorityPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return authnz(enabledInStandalone(next))
}
provisionerPolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return authnz(disabledInStandalone(loadProvisionerByName(next)))
}
acmePolicyMiddleware := func(next http.HandlerFunc) http.HandlerFunc {
return authnz(disabledInStandalone(loadProvisionerByName(requireEABEnabled(loadExternalAccountKey(next)))))
} }
// Provisioners // Provisioners
r.MethodFunc("GET", "/provisioners/{name}", authnz(h.GetProvisioner)) r.MethodFunc("GET", "/provisioners/{name}", authnz(GetProvisioner))
r.MethodFunc("GET", "/provisioners", authnz(h.GetProvisioners)) r.MethodFunc("GET", "/provisioners", authnz(GetProvisioners))
r.MethodFunc("POST", "/provisioners", authnz(h.CreateProvisioner)) r.MethodFunc("POST", "/provisioners", authnz(CreateProvisioner))
r.MethodFunc("PUT", "/provisioners/{name}", authnz(h.UpdateProvisioner)) r.MethodFunc("PUT", "/provisioners/{name}", authnz(UpdateProvisioner))
r.MethodFunc("DELETE", "/provisioners/{name}", authnz(h.DeleteProvisioner)) r.MethodFunc("DELETE", "/provisioners/{name}", authnz(DeleteProvisioner))
// Admins // Admins
r.MethodFunc("GET", "/admins/{id}", authnz(h.GetAdmin)) r.MethodFunc("GET", "/admins/{id}", authnz(GetAdmin))
r.MethodFunc("GET", "/admins", authnz(h.GetAdmins)) r.MethodFunc("GET", "/admins", authnz(GetAdmins))
r.MethodFunc("POST", "/admins", authnz(h.CreateAdmin)) r.MethodFunc("POST", "/admins", authnz(CreateAdmin))
r.MethodFunc("PATCH", "/admins/{id}", authnz(h.UpdateAdmin)) r.MethodFunc("PATCH", "/admins/{id}", authnz(UpdateAdmin))
r.MethodFunc("DELETE", "/admins/{id}", authnz(h.DeleteAdmin)) r.MethodFunc("DELETE", "/admins/{id}", authnz(DeleteAdmin))
// ACME External Account Binding Keys // ACME responder
r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys))) if acmeResponder != nil {
r.MethodFunc("GET", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.GetExternalAccountKeys))) // ACME External Account Binding Keys
r.MethodFunc("POST", "/acme/eab/{provisionerName}", authnz(requireEABEnabled(h.acmeResponder.CreateExternalAccountKey))) r.MethodFunc("GET", "/acme/eab/{provisionerName}/{reference}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys))
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", authnz(requireEABEnabled(h.acmeResponder.DeleteExternalAccountKey))) r.MethodFunc("GET", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.GetExternalAccountKeys))
r.MethodFunc("POST", "/acme/eab/{provisionerName}", acmeEABMiddleware(acmeResponder.CreateExternalAccountKey))
r.MethodFunc("DELETE", "/acme/eab/{provisionerName}/{id}", acmeEABMiddleware(acmeResponder.DeleteExternalAccountKey))
}
// Policy responder
if policyResponder != nil {
// Policy - Authority
r.MethodFunc("GET", "/policy", authorityPolicyMiddleware(policyResponder.GetAuthorityPolicy))
r.MethodFunc("POST", "/policy", authorityPolicyMiddleware(policyResponder.CreateAuthorityPolicy))
r.MethodFunc("PUT", "/policy", authorityPolicyMiddleware(policyResponder.UpdateAuthorityPolicy))
r.MethodFunc("DELETE", "/policy", authorityPolicyMiddleware(policyResponder.DeleteAuthorityPolicy))
// Policy - Provisioner
r.MethodFunc("GET", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.GetProvisionerPolicy))
r.MethodFunc("POST", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.CreateProvisionerPolicy))
r.MethodFunc("PUT", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.UpdateProvisionerPolicy))
r.MethodFunc("DELETE", "/provisioners/{provisionerName}/policy", provisionerPolicyMiddleware(policyResponder.DeleteProvisionerPolicy))
// Policy - ACME Account
r.MethodFunc("GET", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy))
r.MethodFunc("GET", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.GetACMEAccountPolicy))
r.MethodFunc("POST", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy))
r.MethodFunc("POST", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.CreateACMEAccountPolicy))
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy))
r.MethodFunc("PUT", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.UpdateACMEAccountPolicy))
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/reference/{reference}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy))
r.MethodFunc("DELETE", "/acme/policy/{provisionerName}/key/{keyID}", acmePolicyMiddleware(policyResponder.DeleteACMEAccountPolicy))
}
} }

@ -1,22 +1,26 @@
package api package api
import ( import (
"context" "errors"
"net/http" "net/http"
"github.com/go-chi/chi"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/render" "github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/admin/db/nosql"
"github.com/smallstep/certificates/authority/provisioner"
) )
type nextHTTP = func(http.ResponseWriter, *http.Request)
// requireAPIEnabled is a middleware that ensures the Administration API // requireAPIEnabled is a middleware that ensures the Administration API
// is enabled before servicing requests. // is enabled before servicing requests.
func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP { func requireAPIEnabled(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if !h.auth.IsAdminAPIEnabled() { if !mustAuthority(r.Context()).IsAdminAPIEnabled() {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType, render.Error(w, admin.NewError(admin.ErrorNotImplementedType, "administration API not enabled"))
"administration API not enabled"))
return return
} }
next(w, r) next(w, r)
@ -24,8 +28,9 @@ func (h *Handler) requireAPIEnabled(next nextHTTP) nextHTTP {
} }
// extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token. // extractAuthorizeTokenAdmin is a middleware that extracts and caches the bearer token.
func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP { func extractAuthorizeTokenAdmin(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
tok := r.Header.Get("Authorization") tok := r.Header.Get("Authorization")
if tok == "" { if tok == "" {
render.Error(w, admin.NewError(admin.ErrorUnauthorizedType, render.Error(w, admin.NewError(admin.ErrorUnauthorizedType,
@ -33,22 +38,111 @@ func (h *Handler) extractAuthorizeTokenAdmin(next nextHTTP) nextHTTP {
return return
} }
adm, err := h.auth.AuthorizeAdminToken(r, tok) ctx := r.Context()
adm, err := mustAuthority(ctx).AuthorizeAdminToken(r, tok)
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
ctx := context.WithValue(r.Context(), adminContextKey, adm) ctx = linkedca.NewContextWithAdmin(ctx, adm)
next(w, r.WithContext(ctx)) next(w, r.WithContext(ctx))
} }
} }
// ContextKey is the key type for storing and searching for ACME request // loadProvisionerByName is a middleware that searches for a provisioner
// essentials in the context of a request. // by name and stores it in the context.
type ContextKey string func loadProvisionerByName(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var (
p provisioner.Interface
err error
)
const ( ctx := r.Context()
// adminContextKey account key auth := mustAuthority(ctx)
adminContextKey = ContextKey("admin") adminDB := admin.MustFromContext(ctx)
) name := chi.URLParam(r, "provisionerName")
// TODO(hs): distinguish 404 vs. 500
if p, err = auth.LoadProvisionerByName(name); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return
}
prov, err := adminDB.GetProvisioner(ctx, p.GetID())
if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error retrieving provisioner %s", name))
return
}
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
next(w, r.WithContext(ctx))
}
}
// checkAction checks if an action is supported in standalone or not
func checkAction(next http.HandlerFunc, supportedInStandalone bool) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// actions allowed in standalone mode are always supported
if supportedInStandalone {
next(w, r)
return
}
// when an action is not supported in standalone mode and when
// using a nosql.DB backend, actions are not supported
if _, ok := admin.MustFromContext(r.Context()).(*nosql.DB); ok {
render.Error(w, admin.NewError(admin.ErrorNotImplementedType,
"operation not supported in standalone mode"))
return
}
// continue to next http handler
next(w, r)
}
}
// loadExternalAccountKey is a middleware that searches for an ACME
// External Account Key by reference or keyID and stores it in the context.
func loadExternalAccountKey(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
prov := linkedca.MustProvisionerFromContext(ctx)
acmeDB := acme.MustDatabaseFromContext(ctx)
reference := chi.URLParam(r, "reference")
keyID := chi.URLParam(r, "keyID")
var (
eak *acme.ExternalAccountKey
err error
)
if keyID != "" {
eak, err = acmeDB.GetExternalAccountKey(ctx, prov.GetId(), keyID)
} else {
eak, err = acmeDB.GetExternalAccountKeyByReference(ctx, prov.GetId(), reference)
}
if err != nil {
if errors.Is(err, acme.ErrNotFound) {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error retrieving ACME External Account Key"))
return
}
if eak == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found"))
return
}
linkedEAK := eakToLinked(eak)
ctx = linkedca.NewContextWithExternalAccountKey(ctx, linkedEAK)
next(w, r.WithContext(ctx))
}
}

@ -4,25 +4,32 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"time" "time"
"github.com/go-chi/chi"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"go.step.sm/linkedca" "github.com/smallstep/certificates/authority/admin/db/nosql"
"google.golang.org/protobuf/types/known/timestamppb" "github.com/smallstep/certificates/authority/provisioner"
) )
func TestHandler_requireAPIEnabled(t *testing.T) { func TestHandler_requireAPIEnabled(t *testing.T) {
type test struct { type test struct {
ctx context.Context ctx context.Context
auth adminAuthority auth adminAuthority
next nextHTTP next http.HandlerFunc
err *admin.Error err *admin.Error
statusCode int statusCode int
} }
@ -64,13 +71,11 @@ func TestHandler_requireAPIEnabled(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.requireAPIEnabled(tc.next)(w, req) requireAPIEnabled(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -102,7 +107,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
ctx context.Context ctx context.Context
auth adminAuthority auth adminAuthority
req *http.Request req *http.Request
next nextHTTP next http.HandlerFunc
err *admin.Error err *admin.Error
statusCode int statusCode int
} }
@ -152,7 +157,7 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
req.Header["Authorization"] = []string{"token"} req.Header["Authorization"] = []string{"token"}
createdAt := time.Now() createdAt := time.Now()
var deletedAt time.Time var deletedAt time.Time
admin := &linkedca.Admin{ adm := &linkedca.Admin{
Id: "adminID", Id: "adminID",
AuthorityId: "authorityID", AuthorityId: "authorityID",
Subject: "admin", Subject: "admin",
@ -164,20 +169,15 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
auth := &mockAdminAuthority{ auth := &mockAdminAuthority{
MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) { MockAuthorizeAdminToken: func(r *http.Request, token string) (*linkedca.Admin, error) {
assert.Equals(t, "token", token) assert.Equals(t, "token", token)
return admin, nil return adm, nil
}, },
} }
next := func(w http.ResponseWriter, r *http.Request) { next := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
a := ctx.Value(adminContextKey) // verifying that the context now has a linkedca.Admin adm := linkedca.MustAdminFromContext(ctx) // verifying that the context now has a linkedca.Admin
adm, ok := a.(*linkedca.Admin)
if !ok {
t.Errorf("expected *linkedca.Admin; got %T", a)
return
}
opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})} opts := []cmp.Option{cmpopts.IgnoreUnexported(linkedca.Admin{}, timestamppb.Timestamp{})}
if !cmp.Equal(admin, adm, opts...) { if !cmp.Equal(adm, adm, opts...) {
t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(admin, adm, opts...)) t.Errorf("linkedca.Admin diff =\n%s", cmp.Diff(adm, adm, opts...))
} }
w.Write(nil) // mock response with status 200 w.Write(nil) // mock response with status 200
} }
@ -194,13 +194,459 @@ func TestHandler_extractAuthorizeTokenAdmin(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth, req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder()
extractAuthorizeTokenAdmin(tc.next)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
body, err := io.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 {
err := admin.Error{}
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
assert.Equals(t, tc.err.Type, err.Type)
assert.Equals(t, tc.err.Message, err.Message)
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
assert.Equals(t, tc.err.Detail, err.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
return
} }
})
}
}
req := tc.req.WithContext(tc.ctx) func TestHandler_loadProvisionerByName(t *testing.T) {
type test struct {
adminDB admin.DB
auth adminAuthority
ctx context.Context
next http.HandlerFunc
err *admin.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"fail/auth.LoadProvisionerByName": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return nil, errors.New("force")
},
}
err := admin.WrapErrorISE(errors.New("force"), "error loading provisioner provName")
err.Message = "error loading provisioner provName: force"
return test{
ctx: ctx,
auth: auth,
adminDB: &admin.MockDB{},
statusCode: 500,
err: err,
}
},
"fail/db.GetProvisioner": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return nil, errors.New("force")
},
}
err := admin.WrapErrorISE(errors.New("force"), "error retrieving provisioner provName")
err.Message = "error retrieving provisioner provName: force"
return test{
ctx: ctx,
auth: auth,
adminDB: db,
statusCode: 500,
err: err,
}
},
"ok": func(t *testing.T) test {
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("provisionerName", "provName")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
auth := &mockAdminAuthority{
MockLoadProvisionerByName: func(name string) (provisioner.Interface, error) {
assert.Equals(t, "provName", name)
return &provisioner.MockProvisioner{
MgetID: func() string {
return "provID"
},
}, nil
},
}
db := &admin.MockDB{
MockGetProvisioner: func(ctx context.Context, id string) (*linkedca.Provisioner, error) {
assert.Equals(t, "provID", id)
return &linkedca.Provisioner{
Id: "provID",
Name: "provName",
}, nil
},
}
return test{
ctx: ctx,
auth: auth,
adminDB: db,
statusCode: 200,
next: func(w http.ResponseWriter, r *http.Request) {
prov := linkedca.MustProvisionerFromContext(r.Context())
assert.NotNil(t, prov)
assert.Equals(t, "provID", prov.GetId())
assert.Equals(t, "provName", prov.GetName())
w.Write(nil) // mock response with status 200
},
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
mockMustAuthority(t, tc.auth)
ctx := admin.NewContext(tc.ctx, tc.adminDB)
req := httptest.NewRequest("GET", "/foo", nil) // chi routing is prepared in test setup
req = req.WithContext(ctx)
w := httptest.NewRecorder()
loadProvisionerByName(tc.next)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
body, err := io.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 {
err := admin.Error{}
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
assert.Equals(t, tc.err.Type, err.Type)
assert.Equals(t, tc.err.Message, err.Message)
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
assert.Equals(t, tc.err.Detail, err.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
return
}
})
}
}
func TestHandler_checkAction(t *testing.T) {
type test struct {
adminDB admin.DB
next http.HandlerFunc
supportedInStandalone bool
err *admin.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"standalone-nosql-supported": func(t *testing.T) test {
return test{
supportedInStandalone: true,
adminDB: &nosql.DB{},
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(nil) // mock response with status 200
},
statusCode: 200,
}
},
"standalone-nosql-not-supported": func(t *testing.T) test {
err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported in standalone mode")
err.Message = "operation not supported in standalone mode"
return test{
supportedInStandalone: false,
adminDB: &nosql.DB{},
statusCode: 501,
err: err,
}
},
"standalone-no-nosql-not-supported": func(t *testing.T) test {
err := admin.NewError(admin.ErrorNotImplementedType, "operation not supported")
err.Message = "operation not supported"
return test{
supportedInStandalone: false,
adminDB: &admin.MockDB{},
next: func(w http.ResponseWriter, r *http.Request) {
w.Write(nil) // mock response with status 200
},
statusCode: 200,
err: err,
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
ctx := admin.NewContext(context.Background(), tc.adminDB)
req := httptest.NewRequest("GET", "/foo", nil).WithContext(ctx)
w := httptest.NewRecorder()
checkAction(tc.next, tc.supportedInStandalone)(w, req)
res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode)
body, err := io.ReadAll(res.Body)
res.Body.Close()
assert.FatalError(t, err)
if res.StatusCode >= 400 {
err := admin.Error{}
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &err))
assert.Equals(t, tc.err.Type, err.Type)
assert.Equals(t, tc.err.Message, err.Message)
assert.Equals(t, tc.err.StatusCode(), res.StatusCode)
assert.Equals(t, tc.err.Detail, err.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
return
}
})
}
}
func TestHandler_loadExternalAccountKey(t *testing.T) {
type test struct {
ctx context.Context
acmeDB acme.DB
next http.HandlerFunc
err *admin.Error
statusCode int
}
var tests = map[string]func(t *testing.T) test{
"fail/keyID-not-found-error": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("keyID", "key")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "key", keyID)
return nil, acme.ErrNotFound
},
},
err: err,
statusCode: 404,
}
},
"fail/keyID-error": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("keyID", "key")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key")
err.Message = "error retrieving ACME External Account Key: force"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "key", keyID)
return nil, errors.New("force")
},
},
err: err,
statusCode: 500,
}
},
"fail/reference-not-found-error": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("reference", "ref")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "ref", reference)
return nil, acme.ErrNotFound
},
},
err: err,
statusCode: 404,
}
},
"fail/reference-error": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("reference", "ref")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.WrapErrorISE(errors.New("force"), "error retrieving ACME External Account Key")
err.Message = "error retrieving ACME External Account Key: force"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "ref", reference)
return nil, errors.New("force")
},
},
err: err,
statusCode: 500,
}
},
"fail/no-key": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("reference", "ref")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "ref", reference)
return nil, nil
},
},
err: err,
statusCode: 404,
}
},
"ok/keyID": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("keyID", "eakID")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
createdAt := time.Now().Add(-1 * time.Hour)
var boundAt time.Time
eak := &acme.ExternalAccountKey{
ID: "eakID",
ProvisionerID: "provID",
CreatedAt: createdAt,
BoundAt: boundAt,
}
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKey: func(ctx context.Context, provisionerID, keyID string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "eakID", keyID)
return eak, nil
},
},
next: func(w http.ResponseWriter, r *http.Request) {
contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context())
assert.NotNil(t, eak)
exp := &linkedca.EABKey{
Id: "eakID",
Provisioner: "provID",
CreatedAt: timestamppb.New(createdAt),
BoundAt: timestamppb.New(boundAt),
}
assert.Equals(t, exp, contextEAK)
w.Write(nil) // mock response with status 200
},
err: nil,
statusCode: 200,
}
},
"ok/reference": func(t *testing.T) test {
prov := &linkedca.Provisioner{
Id: "provID",
}
chiCtx := chi.NewRouteContext()
chiCtx.URLParams.Add("reference", "ref")
ctx := context.WithValue(context.Background(), chi.RouteCtxKey, chiCtx)
ctx = linkedca.NewContextWithProvisioner(ctx, prov)
err := admin.NewError(admin.ErrorNotFoundType, "ACME External Account Key not found")
err.Message = "ACME External Account Key not found"
createdAt := time.Now().Add(-1 * time.Hour)
var boundAt time.Time
eak := &acme.ExternalAccountKey{
ID: "eakID",
ProvisionerID: "provID",
Reference: "ref",
CreatedAt: createdAt,
BoundAt: boundAt,
}
return test{
ctx: ctx,
acmeDB: &acme.MockDB{
MockGetExternalAccountKeyByReference: func(ctx context.Context, provisionerID, reference string) (*acme.ExternalAccountKey, error) {
assert.Equals(t, "provID", provisionerID)
assert.Equals(t, "ref", reference)
return eak, nil
},
},
next: func(w http.ResponseWriter, r *http.Request) {
contextEAK := linkedca.MustExternalAccountKeyFromContext(r.Context())
assert.NotNil(t, eak)
exp := &linkedca.EABKey{
Id: "eakID",
Provisioner: "provID",
Reference: "ref",
CreatedAt: timestamppb.New(createdAt),
BoundAt: timestamppb.New(boundAt),
}
assert.Equals(t, exp, contextEAK)
w.Write(nil) // mock response with status 200
},
err: nil,
statusCode: 200,
}
},
}
for name, prep := range tests {
tc := prep(t)
t.Run(name, func(t *testing.T) {
ctx := acme.NewDatabaseContext(tc.ctx, tc.acmeDB)
req := httptest.NewRequest("GET", "/foo", nil)
req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.extractAuthorizeTokenAdmin(tc.next)(w, req) loadExternalAccountKey(tc.next)(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)

@ -0,0 +1,496 @@
package api
import (
"context"
"errors"
"net/http"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/acme"
"github.com/smallstep/certificates/api/read"
"github.com/smallstep/certificates/api/render"
"github.com/smallstep/certificates/authority"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/policy"
)
// PolicyAdminResponder is the interface responsible for writing ACME admin
// responses.
type PolicyAdminResponder interface {
GetAuthorityPolicy(w http.ResponseWriter, r *http.Request)
CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request)
DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request)
GetProvisionerPolicy(w http.ResponseWriter, r *http.Request)
CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request)
UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request)
DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request)
GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request)
}
// policyAdminResponder implements PolicyAdminResponder.
type policyAdminResponder struct{}
// NewACMEAdminResponder returns a new PolicyAdminResponder.
func NewPolicyAdminResponder() PolicyAdminResponder {
return &policyAdminResponder{}
}
// GetAuthorityPolicy handles the GET /admin/authority/policy request
func (par *policyAdminResponder) GetAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
auth := mustAuthority(ctx)
authorityPolicy, err := auth.GetAuthorityPolicy(r.Context())
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
return
}
if authorityPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
return
}
render.ProtoJSONStatus(w, authorityPolicy, http.StatusOK)
}
// CreateAuthorityPolicy handles the POST /admin/authority/policy request
func (par *policyAdminResponder) CreateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
auth := mustAuthority(ctx)
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
return
}
if authorityPolicy != nil {
adminErr := admin.NewError(admin.ErrorConflictType, "authority already has a policy")
render.Error(w, adminErr)
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
return
}
adm := linkedca.MustAdminFromContext(ctx)
var createdPolicy *linkedca.Policy
if createdPolicy, err = auth.CreateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error storing authority policy"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error storing authority policy"))
return
}
render.ProtoJSONStatus(w, createdPolicy, http.StatusCreated)
}
// UpdateAuthorityPolicy handles the PUT /admin/authority/policy request
func (par *policyAdminResponder) UpdateAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
auth := mustAuthority(ctx)
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(err, "error retrieving authority policy"))
return
}
if authorityPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating authority policy"))
return
}
adm := linkedca.MustAdminFromContext(ctx)
var updatedPolicy *linkedca.Policy
if updatedPolicy, err = auth.UpdateAuthorityPolicy(ctx, adm, newPolicy); err != nil {
if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating authority policy"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error updating authority policy"))
return
}
render.ProtoJSONStatus(w, updatedPolicy, http.StatusOK)
}
// DeleteAuthorityPolicy handles the DELETE /admin/authority/policy request
func (par *policyAdminResponder) DeleteAuthorityPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
auth := mustAuthority(ctx)
authorityPolicy, err := auth.GetAuthorityPolicy(ctx)
if ae, ok := err.(*admin.Error); ok && !ae.IsType(admin.ErrorNotFoundType) {
render.Error(w, admin.WrapErrorISE(ae, "error retrieving authority policy"))
return
}
if authorityPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "authority policy does not exist"))
return
}
if err := auth.RemoveAuthorityPolicy(ctx); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting authority policy"))
return
}
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
}
// GetProvisionerPolicy handles the GET /admin/provisioners/{name}/policy request
func (par *policyAdminResponder) GetProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(ctx)
provisionerPolicy := prov.GetPolicy()
if provisionerPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
return
}
render.ProtoJSONStatus(w, provisionerPolicy, http.StatusOK)
}
// CreateProvisionerPolicy handles the POST /admin/provisioners/{name}/policy request
func (par *policyAdminResponder) CreateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(ctx)
provisionerPolicy := prov.GetPolicy()
if provisionerPolicy != nil {
adminErr := admin.NewError(admin.ErrorConflictType, "provisioner %s already has a policy", prov.Name)
render.Error(w, adminErr)
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
return
}
prov.Policy = newPolicy
auth := mustAuthority(ctx)
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error creating provisioner policy"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error creating provisioner policy"))
return
}
render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
}
// UpdateProvisionerPolicy handles the PUT /admin/provisioners/{name}/policy request
func (par *policyAdminResponder) UpdateProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(ctx)
provisionerPolicy := prov.GetPolicy()
if provisionerPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating provisioner policy"))
return
}
prov.Policy = newPolicy
auth := mustAuthority(ctx)
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
if isBadRequest(err) {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error updating provisioner policy"))
return
}
render.Error(w, admin.WrapErrorISE(err, "error updating provisioner policy"))
return
}
render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
}
// DeleteProvisionerPolicy handles the DELETE /admin/provisioners/{name}/policy request
func (par *policyAdminResponder) DeleteProvisionerPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(ctx)
if prov.Policy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "provisioner policy does not exist"))
return
}
// remove the policy
prov.Policy = nil
auth := mustAuthority(ctx)
if err := auth.UpdateProvisioner(ctx, prov); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting provisioner policy"))
return
}
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
}
func (par *policyAdminResponder) GetACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy()
if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
return
}
render.ProtoJSONStatus(w, eakPolicy, http.StatusOK)
}
func (par *policyAdminResponder) CreateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(ctx)
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy()
if eakPolicy != nil {
adminErr := admin.NewError(admin.ErrorConflictType, "ACME EAK %s already has a policy", eak.Id)
render.Error(w, adminErr)
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
return
}
eak.Policy = newPolicy
acmeEAK := linkedEAKToCertificates(eak)
acmeDB := acme.MustDatabaseFromContext(ctx)
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error creating ACME EAK policy"))
return
}
render.ProtoJSONStatus(w, newPolicy, http.StatusCreated)
}
func (par *policyAdminResponder) UpdateACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(ctx)
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy()
if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
return
}
var newPolicy = new(linkedca.Policy)
if err := read.ProtoJSON(r.Body, newPolicy); err != nil {
render.Error(w, err)
return
}
newPolicy.Deduplicate()
if err := validatePolicy(newPolicy); err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, "error validating ACME EAK policy"))
return
}
eak.Policy = newPolicy
acmeEAK := linkedEAKToCertificates(eak)
acmeDB := acme.MustDatabaseFromContext(ctx)
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error updating ACME EAK policy"))
return
}
render.ProtoJSONStatus(w, newPolicy, http.StatusOK)
}
func (par *policyAdminResponder) DeleteACMEAccountPolicy(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if err := blockLinkedCA(ctx); err != nil {
render.Error(w, err)
return
}
prov := linkedca.MustProvisionerFromContext(ctx)
eak := linkedca.MustExternalAccountKeyFromContext(ctx)
eakPolicy := eak.GetPolicy()
if eakPolicy == nil {
render.Error(w, admin.NewError(admin.ErrorNotFoundType, "ACME EAK policy does not exist"))
return
}
// remove the policy
eak.Policy = nil
acmeEAK := linkedEAKToCertificates(eak)
acmeDB := acme.MustDatabaseFromContext(ctx)
if err := acmeDB.UpdateExternalAccountKey(ctx, prov.GetId(), acmeEAK); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error deleting ACME EAK policy"))
return
}
render.JSONStatus(w, DeleteResponse{Status: "ok"}, http.StatusOK)
}
// blockLinkedCA blocks all API operations on linked deployments
func blockLinkedCA(ctx context.Context) error {
// temporary blocking linked deployments
adminDB := admin.MustFromContext(ctx)
if a, ok := adminDB.(interface{ IsLinkedCA() bool }); ok && a.IsLinkedCA() {
return admin.NewError(admin.ErrorNotImplementedType, "policy operations not yet supported in linked deployments")
}
return nil
}
// isBadRequest checks if an error should result in a bad request error
// returned to the client.
func isBadRequest(err error) bool {
var pe *authority.PolicyError
isPolicyError := errors.As(err, &pe)
return isPolicyError && (pe.Typ == authority.AdminLockOut || pe.Typ == authority.EvaluationFailure || pe.Typ == authority.ConfigurationFailure)
}
func validatePolicy(p *linkedca.Policy) error {
// convert the policy; return early if nil
options := policy.LinkedToCertificates(p)
if options == nil {
return nil
}
var err error
// Initialize a temporary x509 allow/deny policy engine
if _, err = policy.NewX509PolicyEngine(options.GetX509Options()); err != nil {
return err
}
// Initialize a temporary SSH allow/deny policy engine for host certificates
if _, err = policy.NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil {
return err
}
// Initialize a temporary SSH allow/deny policy engine for user certificates
if _, err = policy.NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil {
return err
}
return nil
}

File diff suppressed because it is too large Load Diff

@ -23,29 +23,31 @@ type GetProvisionersResponse struct {
} }
// GetProvisioner returns the requested provisioner, or an error. // GetProvisioner returns the requested provisioner, or an error.
func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) { func GetProvisioner(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
id := r.URL.Query().Get("id")
name := chi.URLParam(r, "name")
var ( var (
p provisioner.Interface p provisioner.Interface
err error err error
) )
ctx := r.Context()
id := r.URL.Query().Get("id")
name := chi.URLParam(r, "name")
auth := mustAuthority(ctx)
db := admin.MustFromContext(ctx)
if len(id) > 0 { if len(id) > 0 {
if p, err = h.auth.LoadProvisionerByID(id); err != nil { if p, err = auth.LoadProvisionerByID(id); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return return
} }
} else { } else {
if p, err = h.auth.LoadProvisionerByName(name); err != nil { if p, err = auth.LoadProvisionerByName(name); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return return
} }
} }
prov, err := h.adminDB.GetProvisioner(ctx, p.GetID()) prov, err := db.GetProvisioner(ctx, p.GetID())
if err != nil { if err != nil {
render.Error(w, err) render.Error(w, err)
return return
@ -54,7 +56,7 @@ func (h *Handler) GetProvisioner(w http.ResponseWriter, r *http.Request) {
} }
// GetProvisioners returns the given segment of provisioners associated with the authority. // GetProvisioners returns the given segment of provisioners associated with the authority.
func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) { func GetProvisioners(w http.ResponseWriter, r *http.Request) {
cursor, limit, err := api.ParseCursor(r) cursor, limit, err := api.ParseCursor(r)
if err != nil { if err != nil {
render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err, render.Error(w, admin.WrapError(admin.ErrorBadRequestType, err,
@ -62,7 +64,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
return return
} }
p, next, err := h.auth.GetProvisioners(cursor, limit) p, next, err := mustAuthority(r.Context()).GetProvisioners(cursor, limit)
if err != nil { if err != nil {
render.Error(w, errs.InternalServerErr(err)) render.Error(w, errs.InternalServerErr(err))
return return
@ -74,7 +76,7 @@ func (h *Handler) GetProvisioners(w http.ResponseWriter, r *http.Request) {
} }
// CreateProvisioner creates a new prov. // CreateProvisioner creates a new prov.
func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) { func CreateProvisioner(w http.ResponseWriter, r *http.Request) {
var prov = new(linkedca.Provisioner) var prov = new(linkedca.Provisioner)
if err := read.ProtoJSON(r.Body, prov); err != nil { if err := read.ProtoJSON(r.Body, prov); err != nil {
render.Error(w, err) render.Error(w, err)
@ -87,7 +89,7 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := h.auth.StoreProvisioner(r.Context(), prov); err != nil { if err := mustAuthority(r.Context()).StoreProvisioner(r.Context(), prov); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name)) render.Error(w, admin.WrapErrorISE(err, "error storing provisioner %s", prov.Name))
return return
} }
@ -95,27 +97,29 @@ func (h *Handler) CreateProvisioner(w http.ResponseWriter, r *http.Request) {
} }
// DeleteProvisioner deletes a provisioner. // DeleteProvisioner deletes a provisioner.
func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) { func DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
id := r.URL.Query().Get("id")
name := chi.URLParam(r, "name")
var ( var (
p provisioner.Interface p provisioner.Interface
err error err error
) )
id := r.URL.Query().Get("id")
name := chi.URLParam(r, "name")
auth := mustAuthority(r.Context())
if len(id) > 0 { if len(id) > 0 {
if p, err = h.auth.LoadProvisionerByID(id); err != nil { if p, err = auth.LoadProvisionerByID(id); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", id))
return return
} }
} else { } else {
if p, err = h.auth.LoadProvisionerByName(name); err != nil { if p, err = auth.LoadProvisionerByName(name); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner %s", name))
return return
} }
} }
if err := h.auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil { if err := auth.RemoveProvisioner(r.Context(), p.GetID()); err != nil {
render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName())) render.Error(w, admin.WrapErrorISE(err, "error removing provisioner %s", p.GetName()))
return return
} }
@ -124,23 +128,27 @@ func (h *Handler) DeleteProvisioner(w http.ResponseWriter, r *http.Request) {
} }
// UpdateProvisioner updates an existing prov. // UpdateProvisioner updates an existing prov.
func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) { func UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
var nu = new(linkedca.Provisioner) var nu = new(linkedca.Provisioner)
if err := read.ProtoJSON(r.Body, nu); err != nil { if err := read.ProtoJSON(r.Body, nu); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }
ctx := r.Context()
name := chi.URLParam(r, "name") name := chi.URLParam(r, "name")
_old, err := h.auth.LoadProvisionerByName(name) auth := mustAuthority(ctx)
db := admin.MustFromContext(ctx)
p, err := auth.LoadProvisionerByName(name)
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name)) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from cached configuration '%s'", name))
return return
} }
old, err := h.adminDB.GetProvisioner(r.Context(), _old.GetID()) old, err := db.GetProvisioner(r.Context(), p.GetID())
if err != nil { if err != nil {
render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", _old.GetID())) render.Error(w, admin.WrapErrorISE(err, "error loading provisioner from db '%s'", p.GetID()))
return return
} }
@ -171,7 +179,7 @@ func (h *Handler) UpdateProvisioner(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := h.auth.UpdateProvisioner(r.Context(), nu); err != nil { if err := auth.UpdateProvisioner(r.Context(), nu); err != nil {
render.Error(w, err) render.Error(w, err)
return return
} }

@ -8,18 +8,21 @@ import (
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
"go.step.sm/linkedca"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"go.step.sm/linkedca"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/timestamppb"
) )
func TestHandler_GetProvisioner(t *testing.T) { func TestHandler_GetProvisioner(t *testing.T) {
@ -47,6 +50,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
ctx: ctx, ctx: ctx,
req: req, req: req,
auth: auth, auth: auth,
adminDB: &admin.MockDB{},
statusCode: 500, statusCode: 500,
err: &admin.Error{ err: &admin.Error{
Type: admin.ErrorServerInternalType.String(), Type: admin.ErrorServerInternalType.String(),
@ -71,6 +75,7 @@ func TestHandler_GetProvisioner(t *testing.T) {
ctx: ctx, ctx: ctx,
req: req, req: req,
auth: auth, auth: auth,
adminDB: &admin.MockDB{},
statusCode: 500, statusCode: 500,
err: &admin.Error{ err: &admin.Error{
Type: admin.ErrorServerInternalType.String(), Type: admin.ErrorServerInternalType.String(),
@ -153,13 +158,11 @@ func TestHandler_GetProvisioner(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth, ctx := admin.NewContext(tc.ctx, tc.adminDB)
adminDB: tc.adminDB, req := tc.req.WithContext(ctx)
}
req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetProvisioner(w, req) GetProvisioner(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -277,12 +280,10 @@ func TestHandler_GetProvisioners(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := tc.req.WithContext(tc.ctx) req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.GetProvisioners(w, req) GetProvisioners(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -335,12 +336,12 @@ func TestHandler_CreateProvisioner(t *testing.T) {
return test{ return test{
ctx: context.Background(), ctx: context.Background(),
body: body, body: body,
statusCode: 500, statusCode: 400,
err: &admin.Error{ // TODO(hs): this probably needs a better error err: &admin.Error{
Type: "", Type: "badRequest",
Status: 500, Status: 400,
Detail: "", Detail: "bad request",
Message: "", Message: "proto: syntax error (line 1:2): invalid value !",
}, },
} }
}, },
@ -402,13 +403,11 @@ func TestHandler_CreateProvisioner(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.CreateProvisioner(w, req) CreateProvisioner(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -423,9 +422,15 @@ func TestHandler_CreateProvisioner(t *testing.T) {
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Type, adminErr.Type)
assert.Equals(t, tc.err.Message, adminErr.Message)
assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, tc.err.Detail, adminErr.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
if strings.HasPrefix(tc.err.Message, "proto:") {
assert.True(t, strings.Contains(adminErr.Message, "syntax error"))
} else {
assert.Equals(t, tc.err.Message, adminErr.Message)
}
return return
} }
@ -562,12 +567,10 @@ func TestHandler_DeleteProvisioner(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth,
}
req := tc.req.WithContext(tc.ctx) req := tc.req.WithContext(tc.ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.DeleteProvisioner(w, req) DeleteProvisioner(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -616,12 +619,13 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
return test{ return test{
ctx: context.Background(), ctx: context.Background(),
body: body, body: body,
statusCode: 500, adminDB: &admin.MockDB{},
err: &admin.Error{ // TODO(hs): this probably needs a better error statusCode: 400,
Type: "", err: &admin.Error{
Status: 500, Type: "badRequest",
Detail: "", Status: 400,
Message: "", Detail: "bad request",
Message: "proto: syntax error (line 1:2): invalid value !",
}, },
} }
}, },
@ -645,6 +649,7 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
return test{ return test{
ctx: ctx, ctx: ctx,
body: body, body: body,
adminDB: &admin.MockDB{},
auth: auth, auth: auth,
statusCode: 500, statusCode: 500,
err: &admin.Error{ err: &admin.Error{
@ -1052,14 +1057,12 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
for name, prep := range tests { for name, prep := range tests {
tc := prep(t) tc := prep(t)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
h := &Handler{ mockMustAuthority(t, tc.auth)
auth: tc.auth, ctx := admin.NewContext(tc.ctx, tc.adminDB)
adminDB: tc.adminDB,
}
req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body))) req := httptest.NewRequest("POST", "/foo", io.NopCloser(bytes.NewBuffer(tc.body)))
req = req.WithContext(tc.ctx) req = req.WithContext(ctx)
w := httptest.NewRecorder() w := httptest.NewRecorder()
h.UpdateProvisioner(w, req) UpdateProvisioner(w, req)
res := w.Result() res := w.Result()
assert.Equals(t, tc.statusCode, res.StatusCode) assert.Equals(t, tc.statusCode, res.StatusCode)
@ -1074,9 +1077,15 @@ func TestHandler_UpdateProvisioner(t *testing.T) {
assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr)) assert.FatalError(t, json.Unmarshal(bytes.TrimSpace(body), &adminErr))
assert.Equals(t, tc.err.Type, adminErr.Type) assert.Equals(t, tc.err.Type, adminErr.Type)
assert.Equals(t, tc.err.Message, adminErr.Message)
assert.Equals(t, tc.err.Detail, adminErr.Detail) assert.Equals(t, tc.err.Detail, adminErr.Detail)
assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"]) assert.Equals(t, []string{"application/json"}, res.Header["Content-Type"])
if strings.HasPrefix(tc.err.Message, "proto:") {
assert.True(t, strings.Contains(adminErr.Message, "syntax error"))
} else {
assert.Equals(t, tc.err.Message, adminErr.Message)
}
return return
} }

@ -69,6 +69,34 @@ type DB interface {
GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error)
UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error UpdateAdmin(ctx context.Context, admin *linkedca.Admin) error
DeleteAdmin(ctx context.Context, id string) error DeleteAdmin(ctx context.Context, id string) error
CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error
GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error)
UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error
DeleteAuthorityPolicy(ctx context.Context) error
}
type dbKey struct{}
// NewContext adds the given admin database to the context.
func NewContext(ctx context.Context, db DB) context.Context {
return context.WithValue(ctx, dbKey{}, db)
}
// FromContext returns the current admin database from the given context.
func FromContext(ctx context.Context) (db DB, ok bool) {
db, ok = ctx.Value(dbKey{}).(DB)
return
}
// MustFromContext returns the current admin database from the given context. It
// will panic if it's not in the context.
func MustFromContext(ctx context.Context) DB {
if db, ok := FromContext(ctx); !ok {
panic("admin database is not in the context")
} else {
return db
}
} }
// MockDB is an implementation of the DB interface that should only be used as // MockDB is an implementation of the DB interface that should only be used as
@ -86,6 +114,11 @@ type MockDB struct {
MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error MockUpdateAdmin func(ctx context.Context, adm *linkedca.Admin) error
MockDeleteAdmin func(ctx context.Context, id string) error MockDeleteAdmin func(ctx context.Context, id string) error
MockCreateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
MockGetAuthorityPolicy func(ctx context.Context) (*linkedca.Policy, error)
MockUpdateAuthorityPolicy func(ctx context.Context, policy *linkedca.Policy) error
MockDeleteAuthorityPolicy func(ctx context.Context) error
MockError error MockError error
MockRet1 interface{} MockRet1 interface{}
} }
@ -179,3 +212,35 @@ func (m *MockDB) DeleteAdmin(ctx context.Context, id string) error {
} }
return m.MockError return m.MockError
} }
// CreateAuthorityPolicy mock
func (m *MockDB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
if m.MockCreateAuthorityPolicy != nil {
return m.MockCreateAuthorityPolicy(ctx, policy)
}
return m.MockError
}
// GetAuthorityPolicy mock
func (m *MockDB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
if m.MockGetAuthorityPolicy != nil {
return m.MockGetAuthorityPolicy(ctx)
}
return m.MockRet1.(*linkedca.Policy), m.MockError
}
// UpdateAuthorityPolicy mock
func (m *MockDB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
if m.MockUpdateAuthorityPolicy != nil {
return m.MockUpdateAuthorityPolicy(ctx, policy)
}
return m.MockError
}
// DeleteAuthorityPolicy mock
func (m *MockDB) DeleteAuthorityPolicy(ctx context.Context) error {
if m.MockDeleteAuthorityPolicy != nil {
return m.MockDeleteAuthorityPolicy(ctx)
}
return m.MockError
}

@ -11,8 +11,9 @@ import (
) )
var ( var (
adminsTable = []byte("admins") adminsTable = []byte("admins")
provisionersTable = []byte("provisioners") provisionersTable = []byte("provisioners")
authorityPoliciesTable = []byte("authority_policies")
) )
// DB is a struct that implements the AdminDB interface. // DB is a struct that implements the AdminDB interface.
@ -23,7 +24,7 @@ type DB struct {
// New configures and returns a new Authority DB backend implemented using a nosql DB. // New configures and returns a new Authority DB backend implemented using a nosql DB.
func New(db nosqlDB.DB, authorityID string) (*DB, error) { func New(db nosqlDB.DB, authorityID string) (*DB, error) {
tables := [][]byte{adminsTable, provisionersTable} tables := [][]byte{adminsTable, provisionersTable, authorityPoliciesTable}
for _, b := range tables { for _, b := range tables {
if err := db.CreateTable(b); err != nil { if err := db.CreateTable(b); err != nil {
return nil, errors.Wrapf(err, "error creating table %s", return nil, errors.Wrapf(err, "error creating table %s",

@ -0,0 +1,339 @@
package nosql
import (
"context"
"encoding/json"
"fmt"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/authority/admin"
"github.com/smallstep/nosql"
)
type dbX509Policy struct {
Allow *dbX509Names `json:"allow,omitempty"`
Deny *dbX509Names `json:"deny,omitempty"`
AllowWildcardNames bool `json:"allow_wildcard_names,omitempty"`
}
type dbX509Names struct {
CommonNames []string `json:"cn,omitempty"`
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
EmailAddresses []string `json:"email,omitempty"`
URIDomains []string `json:"uri,omitempty"`
}
type dbSSHPolicy struct {
// User contains SSH user certificate options.
User *dbSSHUserPolicy `json:"user,omitempty"`
// Host contains SSH host certificate options.
Host *dbSSHHostPolicy `json:"host,omitempty"`
}
type dbSSHHostPolicy struct {
Allow *dbSSHHostNames `json:"allow,omitempty"`
Deny *dbSSHHostNames `json:"deny,omitempty"`
}
type dbSSHHostNames struct {
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
Principals []string `json:"principal,omitempty"`
}
type dbSSHUserPolicy struct {
Allow *dbSSHUserNames `json:"allow,omitempty"`
Deny *dbSSHUserNames `json:"deny,omitempty"`
}
type dbSSHUserNames struct {
EmailAddresses []string `json:"email,omitempty"`
Principals []string `json:"principal,omitempty"`
}
type dbPolicy struct {
X509 *dbX509Policy `json:"x509,omitempty"`
SSH *dbSSHPolicy `json:"ssh,omitempty"`
}
type dbAuthorityPolicy struct {
ID string `json:"id"`
AuthorityID string `json:"authorityID"`
Policy *dbPolicy `json:"policy,omitempty"`
}
func (dbap *dbAuthorityPolicy) convert() *linkedca.Policy {
if dbap == nil {
return nil
}
return dbToLinked(dbap.Policy)
}
func (db *DB) getDBAuthorityPolicyBytes(ctx context.Context, authorityID string) ([]byte, error) {
data, err := db.db.Get(authorityPoliciesTable, []byte(authorityID))
if nosql.IsErrNotFound(err) {
return nil, admin.NewError(admin.ErrorNotFoundType, "authority policy not found")
} else if err != nil {
return nil, fmt.Errorf("error loading authority policy: %w", err)
}
return data, nil
}
func (db *DB) unmarshalDBAuthorityPolicy(data []byte) (*dbAuthorityPolicy, error) {
if len(data) == 0 {
return nil, nil
}
var dba = new(dbAuthorityPolicy)
if err := json.Unmarshal(data, dba); err != nil {
return nil, fmt.Errorf("error unmarshaling policy bytes into dbAuthorityPolicy: %w", err)
}
return dba, nil
}
func (db *DB) getDBAuthorityPolicy(ctx context.Context, authorityID string) (*dbAuthorityPolicy, error) {
data, err := db.getDBAuthorityPolicyBytes(ctx, authorityID)
if err != nil {
return nil, err
}
dbap, err := db.unmarshalDBAuthorityPolicy(data)
if err != nil {
return nil, err
}
if dbap == nil {
return nil, nil
}
if dbap.AuthorityID != authorityID {
return nil, admin.NewError(admin.ErrorAuthorityMismatchType,
"authority policy is not owned by authority %s", authorityID)
}
return dbap, nil
}
func (db *DB) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
dbap := &dbAuthorityPolicy{
ID: db.authorityID,
AuthorityID: db.authorityID,
Policy: linkedToDB(policy),
}
if err := db.save(ctx, dbap.ID, dbap, nil, "authority_policy", authorityPoliciesTable); err != nil {
return admin.WrapErrorISE(err, "error creating authority policy")
}
return nil
}
func (db *DB) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
dbap, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
if err != nil {
return nil, err
}
return dbap.convert(), nil
}
func (db *DB) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
old, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
if err != nil {
return err
}
dbap := &dbAuthorityPolicy{
ID: db.authorityID,
AuthorityID: db.authorityID,
Policy: linkedToDB(policy),
}
if err := db.save(ctx, dbap.ID, dbap, old, "authority_policy", authorityPoliciesTable); err != nil {
return admin.WrapErrorISE(err, "error updating authority policy")
}
return nil
}
func (db *DB) DeleteAuthorityPolicy(ctx context.Context) error {
old, err := db.getDBAuthorityPolicy(ctx, db.authorityID)
if err != nil {
return err
}
if err := db.save(ctx, old.ID, nil, old, "authority_policy", authorityPoliciesTable); err != nil {
return admin.WrapErrorISE(err, "error deleting authority policy")
}
return nil
}
func dbToLinked(p *dbPolicy) *linkedca.Policy {
if p == nil {
return nil
}
r := &linkedca.Policy{}
if x509 := p.X509; x509 != nil {
r.X509 = &linkedca.X509Policy{}
if allow := x509.Allow; allow != nil {
r.X509.Allow = &linkedca.X509Names{}
r.X509.Allow.Dns = allow.DNSDomains
r.X509.Allow.Emails = allow.EmailAddresses
r.X509.Allow.Ips = allow.IPRanges
r.X509.Allow.Uris = allow.URIDomains
r.X509.Allow.CommonNames = allow.CommonNames
}
if deny := x509.Deny; deny != nil {
r.X509.Deny = &linkedca.X509Names{}
r.X509.Deny.Dns = deny.DNSDomains
r.X509.Deny.Emails = deny.EmailAddresses
r.X509.Deny.Ips = deny.IPRanges
r.X509.Deny.Uris = deny.URIDomains
r.X509.Deny.CommonNames = deny.CommonNames
}
r.X509.AllowWildcardNames = x509.AllowWildcardNames
}
if ssh := p.SSH; ssh != nil {
r.Ssh = &linkedca.SSHPolicy{}
if host := ssh.Host; host != nil {
r.Ssh.Host = &linkedca.SSHHostPolicy{}
if allow := host.Allow; allow != nil {
r.Ssh.Host.Allow = &linkedca.SSHHostNames{}
r.Ssh.Host.Allow.Dns = allow.DNSDomains
r.Ssh.Host.Allow.Ips = allow.IPRanges
r.Ssh.Host.Allow.Principals = allow.Principals
}
if deny := host.Deny; deny != nil {
r.Ssh.Host.Deny = &linkedca.SSHHostNames{}
r.Ssh.Host.Deny.Dns = deny.DNSDomains
r.Ssh.Host.Deny.Ips = deny.IPRanges
r.Ssh.Host.Deny.Principals = deny.Principals
}
}
if user := ssh.User; user != nil {
r.Ssh.User = &linkedca.SSHUserPolicy{}
if allow := user.Allow; allow != nil {
r.Ssh.User.Allow = &linkedca.SSHUserNames{}
r.Ssh.User.Allow.Emails = allow.EmailAddresses
r.Ssh.User.Allow.Principals = allow.Principals
}
if deny := user.Deny; deny != nil {
r.Ssh.User.Deny = &linkedca.SSHUserNames{}
r.Ssh.User.Deny.Emails = deny.EmailAddresses
r.Ssh.User.Deny.Principals = deny.Principals
}
}
}
return r
}
func linkedToDB(p *linkedca.Policy) *dbPolicy {
if p == nil {
return nil
}
// return early if x509 nor SSH is set
if p.GetX509() == nil && p.GetSsh() == nil {
return nil
}
r := &dbPolicy{}
// fill x509 policy configuration
if x509 := p.GetX509(); x509 != nil {
r.X509 = &dbX509Policy{}
if allow := x509.GetAllow(); allow != nil {
r.X509.Allow = &dbX509Names{}
if allow.Dns != nil {
r.X509.Allow.DNSDomains = allow.Dns
}
if allow.Ips != nil {
r.X509.Allow.IPRanges = allow.Ips
}
if allow.Emails != nil {
r.X509.Allow.EmailAddresses = allow.Emails
}
if allow.Uris != nil {
r.X509.Allow.URIDomains = allow.Uris
}
if allow.CommonNames != nil {
r.X509.Allow.CommonNames = allow.CommonNames
}
}
if deny := x509.GetDeny(); deny != nil {
r.X509.Deny = &dbX509Names{}
if deny.Dns != nil {
r.X509.Deny.DNSDomains = deny.Dns
}
if deny.Ips != nil {
r.X509.Deny.IPRanges = deny.Ips
}
if deny.Emails != nil {
r.X509.Deny.EmailAddresses = deny.Emails
}
if deny.Uris != nil {
r.X509.Deny.URIDomains = deny.Uris
}
if deny.CommonNames != nil {
r.X509.Deny.CommonNames = deny.CommonNames
}
}
r.X509.AllowWildcardNames = x509.GetAllowWildcardNames()
}
// fill ssh policy configuration
if ssh := p.GetSsh(); ssh != nil {
r.SSH = &dbSSHPolicy{}
if host := ssh.GetHost(); host != nil {
r.SSH.Host = &dbSSHHostPolicy{}
if allow := host.GetAllow(); allow != nil {
r.SSH.Host.Allow = &dbSSHHostNames{}
if allow.Dns != nil {
r.SSH.Host.Allow.DNSDomains = allow.Dns
}
if allow.Ips != nil {
r.SSH.Host.Allow.IPRanges = allow.Ips
}
if allow.Principals != nil {
r.SSH.Host.Allow.Principals = allow.Principals
}
}
if deny := host.GetDeny(); deny != nil {
r.SSH.Host.Deny = &dbSSHHostNames{}
if deny.Dns != nil {
r.SSH.Host.Deny.DNSDomains = deny.Dns
}
if deny.Ips != nil {
r.SSH.Host.Deny.IPRanges = deny.Ips
}
if deny.Principals != nil {
r.SSH.Host.Deny.Principals = deny.Principals
}
}
}
if user := ssh.GetUser(); user != nil {
r.SSH.User = &dbSSHUserPolicy{}
if allow := user.GetAllow(); allow != nil {
r.SSH.User.Allow = &dbSSHUserNames{}
if allow.Emails != nil {
r.SSH.User.Allow.EmailAddresses = allow.Emails
}
if allow.Principals != nil {
r.SSH.User.Allow.Principals = allow.Principals
}
}
if deny := user.GetDeny(); deny != nil {
r.SSH.User.Deny = &dbSSHUserNames{}
if deny.Emails != nil {
r.SSH.User.Deny.EmailAddresses = deny.Emails
}
if deny.Principals != nil {
r.SSH.User.Deny.Principals = deny.Principals
}
}
}
}
return r
}

File diff suppressed because it is too large Load Diff

@ -24,10 +24,12 @@ const (
ErrorBadRequestType ErrorBadRequestType
// ErrorNotImplementedType not implemented. // ErrorNotImplementedType not implemented.
ErrorNotImplementedType ErrorNotImplementedType
// ErrorUnauthorizedType internal server error. // ErrorUnauthorizedType unauthorized.
ErrorUnauthorizedType ErrorUnauthorizedType
// ErrorServerInternalType internal server error. // ErrorServerInternalType internal server error.
ErrorServerInternalType ErrorServerInternalType
// ErrorConflictType conflict.
ErrorConflictType
) )
// String returns the string representation of the admin problem type, // String returns the string representation of the admin problem type,
@ -48,6 +50,8 @@ func (ap ProblemType) String() string {
return "unauthorized" return "unauthorized"
case ErrorServerInternalType: case ErrorServerInternalType:
return "internalServerError" return "internalServerError"
case ErrorConflictType:
return "conflict"
default: default:
return fmt.Sprintf("unsupported error type '%d'", int(ap)) return fmt.Sprintf("unsupported error type '%d'", int(ap))
} }
@ -64,7 +68,7 @@ var (
errorServerInternalMetadata = errorMetadata{ errorServerInternalMetadata = errorMetadata{
typ: ErrorServerInternalType.String(), typ: ErrorServerInternalType.String(),
details: "the server experienced an internal error", details: "the server experienced an internal error",
status: 500, status: http.StatusInternalServerError,
} }
errorMap = map[ProblemType]errorMetadata{ errorMap = map[ProblemType]errorMetadata{
ErrorNotFoundType: { ErrorNotFoundType: {
@ -98,6 +102,11 @@ var (
status: http.StatusUnauthorized, status: http.StatusUnauthorized,
}, },
ErrorServerInternalType: errorServerInternalMetadata, ErrorServerInternalType: errorServerInternalMetadata,
ErrorConflictType: {
typ: ErrorConflictType.String(),
details: "conflict",
status: http.StatusConflict,
},
} }
) )

@ -59,12 +59,12 @@ func newSubProv(subject, prov string) subProv {
return subProv{subject, prov} return subProv{subject, prov}
} }
// LoadBySubProv a admin by the subject and provisioner name. // LoadBySubProv loads an admin by subject and provisioner name.
func (c *Collection) LoadBySubProv(sub, provName string) (*linkedca.Admin, bool) { func (c *Collection) LoadBySubProv(sub, provName string) (*linkedca.Admin, bool) {
return loadAdmin(c.bySubProv, newSubProv(sub, provName)) return loadAdmin(c.bySubProv, newSubProv(sub, provName))
} }
// LoadByProvisioner a admin by the subject and provisioner name. // LoadByProvisioner loads admins by provisioner name.
func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool) { func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool) {
val, ok := c.byProv.Load(provName) val, ok := c.byProv.Load(provName)
if !ok { if !ok {
@ -78,7 +78,7 @@ func (c *Collection) LoadByProvisioner(provName string) ([]*linkedca.Admin, bool
} }
// Store adds an admin to the collection and enforces the uniqueness of // Store adds an admin to the collection and enforces the uniqueness of
// admin IDs and amdin subject <-> provisioner name combos. // admin IDs and admin subject <-> provisioner name combos.
func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error { func (c *Collection) Store(adm *linkedca.Admin, prov provisioner.Interface) error {
// Input validation. // Input validation.
if adm.ProvisionerId != prov.GetID() { if adm.ProvisionerId != prov.GetID() {

@ -49,7 +49,7 @@ func (a *Authority) StoreAdmin(ctx context.Context, adm *linkedca.Admin, prov pr
return admin.WrapErrorISE(err, "error creating admin") return admin.WrapErrorISE(err, "error creating admin")
} }
if err := a.admins.Store(adm, prov); err != nil { if err := a.admins.Store(adm, prov); err != nil {
if err := a.reloadAdminResources(ctx); err != nil { if err := a.ReloadAdminResources(ctx); err != nil {
return admin.WrapErrorISE(err, "error reloading admin resources on failed admin store") return admin.WrapErrorISE(err, "error reloading admin resources on failed admin store")
} }
return admin.WrapErrorISE(err, "error storing admin in authority cache") return admin.WrapErrorISE(err, "error storing admin in authority cache")
@ -66,7 +66,7 @@ func (a *Authority) UpdateAdmin(ctx context.Context, id string, nu *linkedca.Adm
return nil, admin.WrapErrorISE(err, "error updating cached admin %s", id) return nil, admin.WrapErrorISE(err, "error updating cached admin %s", id)
} }
if err := a.adminDB.UpdateAdmin(ctx, adm); err != nil { if err := a.adminDB.UpdateAdmin(ctx, adm); err != nil {
if err := a.reloadAdminResources(ctx); err != nil { if err := a.ReloadAdminResources(ctx); err != nil {
return nil, admin.WrapErrorISE(err, "error reloading admin resources on failed admin update") return nil, admin.WrapErrorISE(err, "error reloading admin resources on failed admin update")
} }
return nil, admin.WrapErrorISE(err, "error updating admin %s", id) return nil, admin.WrapErrorISE(err, "error updating admin %s", id)
@ -88,7 +88,7 @@ func (a *Authority) removeAdmin(ctx context.Context, id string) error {
return admin.WrapErrorISE(err, "error removing admin %s from authority cache", id) return admin.WrapErrorISE(err, "error removing admin %s from authority cache", id)
} }
if err := a.adminDB.DeleteAdmin(ctx, id); err != nil { if err := a.adminDB.DeleteAdmin(ctx, id); err != nil {
if err := a.reloadAdminResources(ctx); err != nil { if err := a.ReloadAdminResources(ctx); err != nil {
return admin.WrapErrorISE(err, "error reloading admin resources on failed admin remove") return admin.WrapErrorISE(err, "error reloading admin resources on failed admin remove")
} }
return admin.WrapErrorISE(err, "error deleting admin %s", id) return admin.WrapErrorISE(err, "error deleting admin %s", id)

@ -13,10 +13,16 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"go.step.sm/crypto/pemutil"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/authority/admin" "github.com/smallstep/certificates/authority/admin"
adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql" adminDBNosql "github.com/smallstep/certificates/authority/admin/db/nosql"
"github.com/smallstep/certificates/authority/administrator" "github.com/smallstep/certificates/authority/administrator"
"github.com/smallstep/certificates/authority/config" "github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/cas" "github.com/smallstep/certificates/cas"
casapi "github.com/smallstep/certificates/cas/apiv1" casapi "github.com/smallstep/certificates/cas/apiv1"
@ -27,9 +33,6 @@ import (
"github.com/smallstep/certificates/scep" "github.com/smallstep/certificates/scep"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"github.com/smallstep/nosql" "github.com/smallstep/nosql"
"go.step.sm/crypto/pemutil"
"go.step.sm/linkedca"
"golang.org/x/crypto/ssh"
) )
// Authority implements the Certificate Authority internal interface. // Authority implements the Certificate Authority internal interface.
@ -81,9 +84,16 @@ type Authority struct {
authorizeRenewFunc provisioner.AuthorizeRenewFunc authorizeRenewFunc provisioner.AuthorizeRenewFunc
authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc authorizeSSHRenewFunc provisioner.AuthorizeSSHRenewFunc
// Policy engines
policyEngine *policy.Engine
adminMutex sync.RWMutex adminMutex sync.RWMutex
// Do Not initialize the authority
skipInit bool
} }
// Info contains information about the authority.
type Info struct { type Info struct {
StartTime time.Time StartTime time.Time
RootX509Certs []*x509.Certificate RootX509Certs []*x509.Certificate
@ -111,9 +121,11 @@ func New(cfg *config.Config, opts ...Option) (*Authority, error) {
} }
} }
// Initialize authority from options or configuration. if !a.skipInit {
if err := a.init(); err != nil { // Initialize authority from options or configuration.
return nil, err if err := a.init(); err != nil {
return nil, err
}
} }
return a, nil return a, nil
@ -149,16 +161,41 @@ func NewEmbedded(opts ...Option) (*Authority, error) {
// Initialize config required fields. // Initialize config required fields.
a.config.Init() a.config.Init()
// Initialize authority from options or configuration. if !a.skipInit {
if err := a.init(); err != nil { // Initialize authority from options or configuration.
return nil, err if err := a.init(); err != nil {
return nil, err
}
} }
return a, nil return a, nil
} }
// reloadAdminResources reloads admins and provisioners from the DB. type authorityKey struct{}
func (a *Authority) reloadAdminResources(ctx context.Context) error {
// NewContext adds the given authority to the context.
func NewContext(ctx context.Context, a *Authority) context.Context {
return context.WithValue(ctx, authorityKey{}, a)
}
// FromContext returns the current authority from the given context.
func FromContext(ctx context.Context) (a *Authority, ok bool) {
a, ok = ctx.Value(authorityKey{}).(*Authority)
return
}
// MustFromContext returns the current authority from the given context. It will
// panic if the authority is not in the context.
func MustFromContext(ctx context.Context) *Authority {
if a, ok := FromContext(ctx); !ok {
panic("authority is not in the context")
} else {
return a
}
}
// ReloadAdminResources reloads admins and provisioners from the DB.
func (a *Authority) ReloadAdminResources(ctx context.Context) error {
var ( var (
provList provisioner.List provList provisioner.List
adminList []*linkedca.Admin adminList []*linkedca.Admin
@ -213,6 +250,7 @@ func (a *Authority) reloadAdminResources(ctx context.Context) error {
a.provisioners = provClxn a.provisioners = provClxn
a.config.AuthorityConfig.Admins = adminList a.config.AuthorityConfig.Admins = adminList
a.admins = adminClxn a.admins = adminClxn
return nil return nil
} }
@ -224,6 +262,7 @@ func (a *Authority) init() error {
} }
var err error var err error
ctx := NewContext(context.Background(), a)
// Set password if they are not set. // Set password if they are not set.
var configPassword []byte var configPassword []byte
@ -259,10 +298,25 @@ func (a *Authority) init() error {
if a.config.KMS != nil { if a.config.KMS != nil {
options = *a.config.KMS options = *a.config.KMS
} }
a.keyManager, err = kms.New(context.Background(), options) a.keyManager, err = kms.New(ctx, options)
if err != nil {
return err
}
}
// Initialize linkedca client if necessary. On a linked RA, the issuer
// configuration might come from majordomo.
var linkedcaClient *linkedCaClient
if a.config.AuthorityConfig.EnableAdmin && a.linkedCAToken != "" && a.adminDB == nil {
linkedcaClient, err = newLinkedCAClient(a.linkedCAToken)
if err != nil { if err != nil {
return err return err
} }
// If authorityId is configured make sure it matches the one in the token
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, linkedcaClient.authorityID) {
return errors.New("error initializing linkedca: token authority and configured authority do not match")
}
linkedcaClient.Run()
} }
// Initialize the X.509 CA Service if it has not been set in the options. // Initialize the X.509 CA Service if it has not been set in the options.
@ -272,6 +326,22 @@ func (a *Authority) init() error {
options = *a.config.AuthorityConfig.Options options = *a.config.AuthorityConfig.Options
} }
// Configure linked RA
if linkedcaClient != nil && options.CertificateAuthority == "" {
conf, err := linkedcaClient.GetConfiguration(ctx)
if err != nil {
return err
}
if conf.RaConfig != nil {
options.CertificateAuthority = conf.RaConfig.CaUrl
options.CertificateAuthorityFingerprint = conf.RaConfig.Fingerprint
options.CertificateIssuer = &casapi.CertificateIssuer{
Type: conf.RaConfig.Provisioner.Type.String(),
Provisioner: conf.RaConfig.Provisioner.Name,
}
}
}
// Set the issuer password if passed in the flags. // Set the issuer password if passed in the flags.
if options.CertificateIssuer != nil && a.issuerPassword != nil { if options.CertificateIssuer != nil && a.issuerPassword != nil {
options.CertificateIssuer.Password = string(a.issuerPassword) options.CertificateIssuer.Password = string(a.issuerPassword)
@ -292,7 +362,7 @@ func (a *Authority) init() error {
} }
} }
a.x509CAService, err = cas.New(context.Background(), options) a.x509CAService, err = cas.New(ctx, options)
if err != nil { if err != nil {
return err return err
} }
@ -479,7 +549,7 @@ func (a *Authority) init() error {
} }
} }
a.scepService, err = scep.NewService(context.Background(), options) a.scepService, err = scep.NewService(ctx, options)
if err != nil { if err != nil {
return err return err
} }
@ -491,40 +561,29 @@ func (a *Authority) init() error {
// Initialize step-ca Admin Database if it's not already initialized using // Initialize step-ca Admin Database if it's not already initialized using
// WithAdminDB. // WithAdminDB.
if a.adminDB == nil { if a.adminDB == nil {
if a.linkedCAToken == "" { if linkedcaClient != nil {
// Check if AuthConfig already exists a.adminDB = linkedcaClient
a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
if err != nil {
return err
}
} else { } else {
// Use the linkedca client as the admindb. a.adminDB, err = adminDBNosql.New(a.db.(nosql.DB), admin.DefaultAuthorityID)
client, err := newLinkedCAClient(a.linkedCAToken)
if err != nil { if err != nil {
return err return err
} }
// If authorityId is configured make sure it matches the one in the token
if id := a.config.AuthorityConfig.AuthorityID; id != "" && !strings.EqualFold(id, client.authorityID) {
return errors.New("error initializing linkedca: token authority and configured authority do not match")
}
client.Run()
a.adminDB = client
} }
} }
provs, err := a.adminDB.GetProvisioners(context.Background()) provs, err := a.adminDB.GetProvisioners(ctx)
if err != nil { if err != nil {
return admin.WrapErrorISE(err, "error loading provisioners to initialize authority") return admin.WrapErrorISE(err, "error loading provisioners to initialize authority")
} }
if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") { if len(provs) == 0 && !strings.EqualFold(a.config.AuthorityConfig.DeploymentType, "linked") {
// Create First Provisioner // Create First Provisioner
prov, err := CreateFirstProvisioner(context.Background(), a.adminDB, string(a.password)) prov, err := CreateFirstProvisioner(ctx, a.adminDB, string(a.password))
if err != nil { if err != nil {
return admin.WrapErrorISE(err, "error creating first provisioner") return admin.WrapErrorISE(err, "error creating first provisioner")
} }
// Create first admin // Create first admin
if err := a.adminDB.CreateAdmin(context.Background(), &linkedca.Admin{ if err := a.adminDB.CreateAdmin(ctx, &linkedca.Admin{
ProvisionerId: prov.Id, ProvisionerId: prov.Id,
Subject: "step", Subject: "step",
Type: linkedca.Admin_SUPER_ADMIN, Type: linkedca.Admin_SUPER_ADMIN,
@ -535,7 +594,12 @@ func (a *Authority) init() error {
} }
// Load Provisioners and Admins // Load Provisioners and Admins
if err := a.reloadAdminResources(context.Background()); err != nil { if err := a.ReloadAdminResources(ctx); err != nil {
return err
}
// Load x509 and SSH Policy Engines
if err := a.reloadPolicyEngines(ctx); err != nil {
return err return err
} }
@ -570,6 +634,15 @@ func (a *Authority) init() error {
return nil return nil
} }
// GetID returns the define authority id or a zero uuid.
func (a *Authority) GetID() string {
const zeroUUID = "00000000-0000-0000-0000-000000000000"
if id := a.config.AuthorityConfig.AuthorityID; id != "" {
return id
}
return zeroUUID
}
// GetDatabase returns the authority database. If the configuration does not // GetDatabase returns the authority database. If the configuration does not
// define a database, GetDatabase will return a db.SimpleDB instance. // define a database, GetDatabase will return a db.SimpleDB instance.
func (a *Authority) GetDatabase() db.AuthDB { func (a *Authority) GetDatabase() db.AuthDB {
@ -581,6 +654,12 @@ func (a *Authority) GetAdminDatabase() admin.DB {
return a.adminDB return a.adminDB
} }
// GetConfig returns the config.
func (a *Authority) GetConfig() *config.Config {
return a.config
}
// GetInfo returns information about the authority.
func (a *Authority) GetInfo() Info { func (a *Authority) GetInfo() Info {
ai := Info{ ai := Info{
StartTime: a.startTime, StartTime: a.startTime,

@ -14,6 +14,7 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/assert" "github.com/smallstep/assert"
"github.com/smallstep/certificates/authority/config"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
@ -421,3 +422,31 @@ func TestAuthority_GetSCEPService(t *testing.T) {
}) })
} }
} }
func TestAuthority_GetID(t *testing.T) {
type fields struct {
authorityID string
}
tests := []struct {
name string
fields fields
want string
}{
{"ok", fields{""}, "00000000-0000-0000-0000-000000000000"},
{"ok with id", fields{"10b9a431-ed3b-4a5f-abee-ec35119b65e7"}, "10b9a431-ed3b-4a5f-abee-ec35119b65e7"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Authority{
config: &config.Config{
AuthorityConfig: &config.AuthConfig{
AuthorityID: tt.fields.authorityID,
},
},
}
if got := a.GetID(); got != tt.want {
t.Errorf("Authority.GetID() = %v, want %v", got, tt.want)
}
})
}
}

@ -5,6 +5,7 @@ import (
"crypto/sha256" "crypto/sha256"
"crypto/x509" "crypto/x509"
"encoding/hex" "encoding/hex"
"fmt"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
@ -41,14 +42,12 @@ func SkipTokenReuseFromContext(ctx context.Context) bool {
return m return m
} }
// authorizeToken parses the token and returns the provisioner used to generate // getProvisionerFromToken extracts a provisioner from the given token without
// the token. This method enforces the One-Time use policy (tokens can only be // doing any token validation.
// used once). func (a *Authority) getProvisionerFromToken(token string) (provisioner.Interface, *Claims, error) {
func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) {
// Validate payload
tok, err := jose.ParseSigned(token) tok, err := jose.ParseSigned(token)
if err != nil { if err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken: error parsing token") return nil, nil, fmt.Errorf("error parsing token: %w", err)
} }
// Get claims w/out verification. We need to look up the provisioner // Get claims w/out verification. We need to look up the provisioner
@ -56,7 +55,25 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// before we can look up the provisioner. // before we can look up the provisioner.
var claims Claims var claims Claims
if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, errs.Wrap(http.StatusUnauthorized, err, "authority.authorizeToken") return nil, nil, fmt.Errorf("error unmarshaling token: %w", err)
}
// This method will also validate the audiences for JWK provisioners.
p, ok := a.provisioners.LoadByToken(tok, &claims.Claims)
if !ok {
return nil, nil, fmt.Errorf("provisioner not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
}
return p, &claims, nil
}
// authorizeToken parses the token and returns the provisioner used to generate
// the token. This method enforces the One-Time use policy (tokens can only be
// used once).
func (a *Authority) authorizeToken(ctx context.Context, token string) (provisioner.Interface, error) {
p, claims, err := a.getProvisionerFromToken(token)
if err != nil {
return nil, errs.UnauthorizedErr(err)
} }
// TODO: use new persistence layer abstraction. // TODO: use new persistence layer abstraction.
@ -64,17 +81,10 @@ func (a *Authority) authorizeToken(ctx context.Context, token string) (provision
// This check is meant as a stopgap solution to the current lack of a persistence layer. // This check is meant as a stopgap solution to the current lack of a persistence layer.
if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck { if a.config.AuthorityConfig != nil && !a.config.AuthorityConfig.DisableIssuedAtCheck {
if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) { if claims.IssuedAt != nil && claims.IssuedAt.Time().Before(a.startTime) {
return nil, errs.Unauthorized("authority.authorizeToken: token issued before the bootstrap of certificate authority") return nil, errs.Unauthorized("token issued before the bootstrap of certificate authority")
} }
} }
// This method will also validate the audiences for JWK provisioners.
p, ok := a.provisioners.LoadByToken(tok, &claims.Claims)
if !ok {
return nil, errs.Unauthorized("authority.authorizeToken: provisioner "+
"not found or invalid audience (%s)", strings.Join(claims.Audience, ", "))
}
// Store the token to protect against reuse unless it's skipped. // Store the token to protect against reuse unless it's skipped.
// If we cannot get a token id from the provisioner, just hash the token. // If we cannot get a token id from the provisioner, just hash the token.
if !SkipTokenReuseFromContext(ctx) { if !SkipTokenReuseFromContext(ctx) {
@ -130,22 +140,24 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
// According to "rfc7519 JSON Web Token" acceptable skew should be no // According to "rfc7519 JSON Web Token" acceptable skew should be no
// more than a few minutes. // more than a few minutes.
if err := claims.ValidateWithLeeway(jose.Expected{ if err := claims.ValidateWithLeeway(jose.Expected{
Issuer: prov.GetName(), Time: time.Now().UTC(),
Time: time.Now().UTC(),
}, time.Minute); err != nil { }, time.Minute); err != nil {
return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "x5c.authorizeToken; invalid x5c claims") return nil, admin.WrapError(admin.ErrorUnauthorizedType, err, "x5c.authorizeToken; invalid x5c claims")
} }
// validate audience: path matches the current path // validate audience: path matches the current path
if r.URL.Path != claims.Audience[0] { if !matchesAudience(claims.Audience, a.config.Audience(r.URL.Path)) {
return nil, admin.NewError(admin.ErrorUnauthorizedType, return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token has invalid audience claim (aud)")
"x5c.authorizeToken; x5c token has invalid audience "+ }
"claim (aud); expected %s, but got %s", r.URL.Path, claims.Audience)
// validate issuer: old versions used the provisioner name, new version uses
// 'step-admin-client/1.0'
if claims.Issuer != "step-admin-client/1.0" && claims.Issuer != prov.GetName() {
return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token has invalid issuer claim (iss)")
} }
if claims.Subject == "" { if claims.Subject == "" {
return nil, admin.NewError(admin.ErrorUnauthorizedType, return nil, admin.NewError(admin.ErrorUnauthorizedType, "x5c.authorizeToken; x5c token subject cannot be empty")
"x5c.authorizeToken; x5c token subject cannot be empty")
} }
var ( var (
@ -156,7 +168,7 @@ func (a *Authority) AuthorizeAdminToken(r *http.Request, token string) (*linkedc
adminSANs := append([]string{leaf.Subject.CommonName}, leaf.DNSNames...) adminSANs := append([]string{leaf.Subject.CommonName}, leaf.DNSNames...)
adminSANs = append(adminSANs, leaf.EmailAddresses...) adminSANs = append(adminSANs, leaf.EmailAddresses...)
for _, san := range adminSANs { for _, san := range adminSANs {
if adm, ok = a.LoadAdminBySubProv(san, claims.Issuer); ok { if adm, ok = a.LoadAdminBySubProv(san, prov.GetName()); ok {
adminFound = true adminFound = true
break break
} }
@ -186,11 +198,10 @@ func (a *Authority) UseToken(token string, prov provisioner.Interface) error {
} }
ok, err := a.db.UseToken(reuseKey, token) ok, err := a.db.UseToken(reuseKey, token)
if err != nil { if err != nil {
return errs.Wrap(http.StatusInternalServerError, err, return errs.Wrap(http.StatusInternalServerError, err, "failed when attempting to store token")
"authority.authorizeToken: failed when attempting to store token")
} }
if !ok { if !ok {
return errs.Unauthorized("authority.authorizeToken: token already used") return errs.Unauthorized("token already used")
} }
} }
return nil return nil
@ -249,10 +260,10 @@ func (a *Authority) authorizeSign(ctx context.Context, token string) ([]provisio
// AuthorizeSign authorizes a signature request by validating and authenticating // AuthorizeSign authorizes a signature request by validating and authenticating
// a token that must be sent w/ the request. // a token that must be sent w/ the request.
// //
// NOTE: This method is deprecated and should not be used. We make it available // Deprecated: Use Authorize(context.Context, string) ([]provisioner.SignOption, error).
// in the short term os as not to break existing clients.
func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) { func (a *Authority) AuthorizeSign(token string) ([]provisioner.SignOption, error) {
ctx := provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod) ctx := NewContext(context.Background(), a)
ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod)
return a.Authorize(ctx, token) return a.Authorize(ctx, token)
} }
@ -285,9 +296,16 @@ func (a *Authority) authorizeRenew(cert *x509.Certificate) error {
if isRevoked { if isRevoked {
return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...) return errs.Unauthorized("authority.authorizeRenew: certificate has been revoked", opts...)
} }
p, ok := a.provisioners.LoadByCertificate(cert) p, err := a.LoadProvisionerByCertificate(cert)
if !ok { if err != nil {
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...) var ok bool
// For backward compatibility this method will also succeed if the
// certificate does not have a provisioner extension. LoadByCertificate
// returns the noop provisioner if this happens, and it allows
// certificate renewals.
if p, ok = a.provisioners.LoadByCertificate(cert); !ok {
return errs.Unauthorized("authority.authorizeRenew: provisioner not found", opts...)
}
} }
if err := p.AuthorizeRenew(context.Background(), cert); err != nil { if err := p.AuthorizeRenew(context.Background(), cert); err != nil {
return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...) return errs.Wrap(http.StatusInternalServerError, err, "authority.authorizeRenew", opts...)
@ -386,8 +404,8 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token")) return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token"))
} }
p, ok := a.provisioners.LoadByCertificate(leaf) p, err := a.LoadProvisionerByCertificate(leaf)
if !ok { if err != nil {
return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate") return nil, errs.Unauthorized("error validating renew token: cannot get provisioner from certificate")
} }
if err := a.UseToken(ott, p); err != nil { if err := a.UseToken(ott, p); err != nil {
@ -395,7 +413,6 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
} }
if err := claims.ValidateWithLeeway(jose.Expected{ if err := claims.ValidateWithLeeway(jose.Expected{
Issuer: p.GetName(),
Subject: leaf.Subject.CommonName, Subject: leaf.Subject.CommonName,
Time: time.Now().UTC(), Time: time.Now().UTC(),
}, time.Minute); err != nil { }, time.Minute); err != nil {
@ -420,6 +437,12 @@ func (a *Authority) AuthorizeRenewToken(ctx context.Context, ott string) (*x509.
return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)")) return nil, errs.InternalServerErr(err, errs.WithMessage("error validating renew token: invalid audience claim (aud)"))
} }
// validate issuer: old versions used the provisioner name, new version uses
// 'step-ca-client/1.0'
if claims.Issuer != "step-ca-client/1.0" && claims.Issuer != p.GetName() {
return nil, admin.NewError(admin.ErrorUnauthorizedType, "error validating renew token: invalid issuer claim (iss)")
}
return leaf, nil return leaf, nil
} }

@ -114,7 +114,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: "foo", token: "foo",
err: errors.New("authority.authorizeToken: error parsing token"), err: errors.New("error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -133,7 +133,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: raw, token: raw,
err: errors.New("authority.authorizeToken: token issued before the bootstrap of certificate authority"), err: errors.New("token issued before the bootstrap of certificate authority"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -155,7 +155,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: raw, token: raw,
err: errors.New("authority.authorizeToken: provisioner not found or invalid audience (https://example.com/revoke)"), err: errors.New("provisioner not found or invalid audience (https://example.com/revoke)"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -192,7 +192,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: _a, auth: _a,
token: raw, token: raw,
err: errors.New("authority.authorizeToken: token already used"), err: errors.New("token already used"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -227,7 +227,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: _a, auth: _a,
token: raw, token: raw,
err: errors.New("authority.authorizeToken: token already used"), err: errors.New("token already used"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -275,7 +275,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: _a, auth: _a,
token: raw, token: raw,
err: errors.New("authority.authorizeToken: failed when attempting to store token: force"), err: errors.New("failed when attempting to store token: force"),
code: http.StatusInternalServerError, code: http.StatusInternalServerError,
} }
}, },
@ -300,7 +300,7 @@ func TestAuthority_authorizeToken(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: _a, auth: _a,
token: raw, token: raw,
err: errors.New("authority.authorizeToken: token already used"), err: errors.New("token already used"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -353,7 +353,7 @@ func TestAuthority_authorizeRevoke(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: "foo", token: "foo",
err: errors.New("authority.authorizeRevoke: authority.authorizeToken: error parsing token"), err: errors.New("authority.authorizeRevoke: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -437,7 +437,7 @@ func TestAuthority_authorizeSign(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: "foo", token: "foo",
err: errors.New("authority.authorizeSign: authority.authorizeToken: error parsing token"), err: errors.New("authority.authorizeSign: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -491,7 +491,7 @@ func TestAuthority_authorizeSign(t *testing.T) {
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Len(t, 7, got) assert.Equals(t, 9, len(got)) // number of provisioner.SignOptions returned
} }
} }
}) })
@ -524,7 +524,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
token: "foo", token: "foo",
ctx: context.Background(), ctx: context.Background(),
err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"), err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -533,7 +533,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
token: "foo", token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod), ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SignMethod),
err: errors.New("authority.Authorize: authority.authorizeSign: authority.authorizeToken: error parsing token"), err: errors.New("authority.Authorize: authority.authorizeSign: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -559,7 +559,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
token: "foo", token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod), ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.RevokeMethod),
err: errors.New("authority.Authorize: authority.authorizeRevoke: authority.authorizeToken: error parsing token"), err: errors.New("authority.Authorize: authority.authorizeRevoke: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -585,7 +585,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
token: "foo", token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod), ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHSignMethod),
err: errors.New("authority.Authorize: authority.authorizeSSHSign: authority.authorizeToken: error parsing token"), err: errors.New("authority.Authorize: authority.authorizeSSHSign: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -615,7 +615,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
token: "foo", token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod), ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRenewMethod),
err: errors.New("authority.Authorize: authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"), err: errors.New("authority.Authorize: authority.authorizeSSHRenew: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -659,7 +659,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
token: "foo", token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod), ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRevokeMethod),
err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"), err: errors.New("authority.Authorize: authority.authorizeSSHRevoke: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -685,7 +685,7 @@ func TestAuthority_Authorize(t *testing.T) {
auth: a, auth: a,
token: "foo", token: "foo",
ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod), ctx: provisioner.NewContextWithMethod(context.Background(), provisioner.SSHRekeyMethod),
err: errors.New("authority.Authorize: authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"), err: errors.New("authority.Authorize: authority.authorizeSSHRekey: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -847,6 +847,29 @@ func TestAuthority_authorizeRenew(t *testing.T) {
cert: fooCrt, cert: fooCrt,
} }
}, },
"ok/from db": func(t *testing.T) *authorizeTest {
a := testAuthority(t)
a.db = &db.MockAuthDB{
MIsRevoked: func(key string) (bool, error) {
return false, nil
},
MGetCertificateData: func(serialNumber string) (*db.CertificateData, error) {
p, ok := a.provisioners.LoadByName("step-cli")
if !ok {
t.Fatal("provisioner step-cli not found")
}
return &db.CertificateData{
Provisioner: &db.ProvisionerData{
ID: p.GetID(),
},
}, nil
},
}
return &authorizeTest{
auth: a,
cert: fooCrt,
}
},
} }
for name, genTestCase := range tests { for name, genTestCase := range tests {
@ -965,7 +988,7 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: "foo", token: "foo",
err: errors.New("authority.authorizeSSHSign: authority.authorizeToken: error parsing token"), err: errors.New("authority.authorizeSSHSign: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -1011,7 +1034,7 @@ func TestAuthority_authorizeSSHSign(t *testing.T) {
} }
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Len(t, 7, got) assert.Len(t, 9, got) // number of provisioner.SignOptions returned
} }
} }
}) })
@ -1059,7 +1082,7 @@ func TestAuthority_authorizeSSHRenew(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: "foo", token: "foo",
err: errors.New("authority.authorizeSSHRenew: authority.authorizeToken: error parsing token"), err: errors.New("authority.authorizeSSHRenew: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -1167,7 +1190,7 @@ func TestAuthority_authorizeSSHRevoke(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: "foo", token: "foo",
err: errors.New("authority.authorizeSSHRevoke: authority.authorizeToken: error parsing token"), err: errors.New("authority.authorizeSSHRevoke: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -1259,7 +1282,7 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
return &authorizeTest{ return &authorizeTest{
auth: a, auth: a,
token: "foo", token: "foo",
err: errors.New("authority.authorizeSSHRekey: authority.authorizeToken: error parsing token"), err: errors.New("authority.authorizeSSHRekey: error parsing token"),
code: http.StatusUnauthorized, code: http.StatusUnauthorized,
} }
}, },
@ -1322,7 +1345,7 @@ func TestAuthority_authorizeSSHRekey(t *testing.T) {
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
assert.Equals(t, tc.cert.Serial, cert.Serial) assert.Equals(t, tc.cert.Serial, cert.Serial)
assert.Len(t, 3, signOpts) assert.Len(t, 4, signOpts)
} }
} }
}) })
@ -1381,7 +1404,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
t1, c1 := generateX5cToken(a1, signer, jose.Claims{ t1, c1 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"}, Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com", Subject: "test.example.com",
Issuer: "step-cli", Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now), NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1400,7 +1423,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
t2, c2 := generateX5cToken(a1, signer, jose.Claims{ t2, c2 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"}, Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com", Subject: "test.example.com",
Issuer: "step-cli", Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now), NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: jose.NewNumericDate(now), IssuedAt: jose.NewNumericDate(now),
@ -1417,12 +1440,31 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
}) })
return nil return nil
})) }))
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{ t3, c3 := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"}, Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com", Subject: "test.example.com",
Issuer: "step-cli", Issuer: "step-cli",
NotBefore: jose.NewNumericDate(now), NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour)
b, err := asn1.Marshal(stepProvisionerASN1{int(provisioner.TypeJWK), []byte("step-cli"), nil, nil})
if err != nil {
return err
}
cert.ExtraExtensions = append(cert.ExtraExtensions, pkix.Extension{
Id: asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 37476, 9000, 64, 1},
Value: b,
})
return nil
}))
badSigner, _ := generateX5cToken(a1, otherSigner, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com",
Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
cert.NotBefore = now cert.NotBefore = now
cert.NotAfter = now.Add(time.Hour) cert.NotAfter = now.Add(time.Hour)
@ -1439,7 +1481,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{ badProvisioner, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"}, Audience: []string{"https://example.com/1.0/renew"},
Subject: "test.example.com", Subject: "test.example.com",
Issuer: "step-cli", Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now), NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1477,7 +1519,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badSubject, _ := generateX5cToken(a1, signer, jose.Claims{ badSubject, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/renew"}, Audience: []string{"https://example.com/1.0/renew"},
Subject: "bad-subject", Subject: "bad-subject",
Issuer: "step-cli", Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now), NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1496,7 +1538,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{ badNotBefore, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"}, Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com", Subject: "test.example.com",
Issuer: "step-cli", Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)), NotBefore: jose.NewNumericDate(now.Add(5 * time.Minute)),
Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)), Expiry: jose.NewNumericDate(now.Add(10 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1515,7 +1557,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{ badExpiry, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"}, Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com", Subject: "test.example.com",
Issuer: "step-cli", Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)), NotBefore: jose.NewNumericDate(now.Add(-5 * time.Minute)),
Expiry: jose.NewNumericDate(now.Add(-time.Minute)), Expiry: jose.NewNumericDate(now.Add(-time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1534,7 +1576,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{ badIssuedAt, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"}, Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com", Subject: "test.example.com",
Issuer: "step-cli", Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now), NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)), IssuedAt: jose.NewNumericDate(now.Add(5 * time.Minute)),
@ -1554,7 +1596,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
badAudience, _ := generateX5cToken(a1, signer, jose.Claims{ badAudience, _ := generateX5cToken(a1, signer, jose.Claims{
Audience: []string{"https://example.com/1.0/sign"}, Audience: []string{"https://example.com/1.0/sign"},
Subject: "test.example.com", Subject: "test.example.com",
Issuer: "step-cli", Issuer: "step-ca-client/1.0",
NotBefore: jose.NewNumericDate(now), NotBefore: jose.NewNumericDate(now),
Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)), Expiry: jose.NewNumericDate(now.Add(5 * time.Minute)),
}, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error { }, provisioner.CertificateEnforcerFunc(func(cert *x509.Certificate) error {
@ -1584,6 +1626,7 @@ func TestAuthority_AuthorizeRenewToken(t *testing.T) {
}{ }{
{"ok", a1, args{ctx, t1}, c1, false}, {"ok", a1, args{ctx, t1}, c1, false},
{"ok expired cert", a1, args{ctx, t2}, c2, false}, {"ok expired cert", a1, args{ctx, t2}, c2, false},
{"ok provisioner issuer", a1, args{ctx, t3}, c3, false},
{"fail token", a1, args{ctx, "not.a.token"}, nil, true}, {"fail token", a1, args{ctx, "not.a.token"}, nil, true},
{"fail token reuse", a1, args{ctx, t1}, nil, true}, {"fail token reuse", a1, args{ctx, t1}, nil, true},
{"fail token signature", a1, args{ctx, badSigner}, nil, true}, {"fail token signature", a1, args{ctx, badSigner}, nil, true},

@ -8,12 +8,15 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/authority/policy"
"github.com/smallstep/certificates/authority/provisioner" "github.com/smallstep/certificates/authority/provisioner"
cas "github.com/smallstep/certificates/cas/apiv1" cas "github.com/smallstep/certificates/cas/apiv1"
"github.com/smallstep/certificates/db" "github.com/smallstep/certificates/db"
kms "github.com/smallstep/certificates/kms/apiv1" kms "github.com/smallstep/certificates/kms/apiv1"
"github.com/smallstep/certificates/templates" "github.com/smallstep/certificates/templates"
"go.step.sm/linkedca"
) )
const ( const (
@ -26,27 +29,27 @@ var (
DefaultBackdate = time.Minute DefaultBackdate = time.Minute
// DefaultDisableRenewal disables renewals per provisioner. // DefaultDisableRenewal disables renewals per provisioner.
DefaultDisableRenewal = false DefaultDisableRenewal = false
// DefaultAllowRenewAfterExpiry allows renewals even if the certificate is // DefaultAllowRenewalAfterExpiry allows renewals even if the certificate is
// expired. // expired.
DefaultAllowRenewAfterExpiry = false DefaultAllowRenewalAfterExpiry = false
// DefaultEnableSSHCA enable SSH CA features per provisioner or globally // DefaultEnableSSHCA enable SSH CA features per provisioner or globally
// for all provisioners. // for all provisioners.
DefaultEnableSSHCA = false DefaultEnableSSHCA = false
// GlobalProvisionerClaims default claims for the Authority. Can be overridden // GlobalProvisionerClaims default claims for the Authority. Can be overridden
// by provisioner specific claims. // by provisioner specific claims.
GlobalProvisionerClaims = provisioner.Claims{ GlobalProvisionerClaims = provisioner.Claims{
MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, // TLS certs
MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour},
MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs MinUserSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // User SSH certs
MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour}, MaxUserSSHDur: &provisioner.Duration{Duration: 24 * time.Hour},
DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour}, DefaultUserSSHDur: &provisioner.Duration{Duration: 16 * time.Hour},
MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs MinHostSSHDur: &provisioner.Duration{Duration: 5 * time.Minute}, // Host SSH certs
MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, MaxHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour}, DefaultHostSSHDur: &provisioner.Duration{Duration: 30 * 24 * time.Hour},
EnableSSHCA: &DefaultEnableSSHCA, EnableSSHCA: &DefaultEnableSSHCA,
DisableRenewal: &DefaultDisableRenewal, DisableRenewal: &DefaultDisableRenewal,
AllowRenewAfterExpiry: &DefaultAllowRenewAfterExpiry, AllowRenewalAfterExpiry: &DefaultAllowRenewalAfterExpiry,
} }
) )
@ -68,6 +71,7 @@ type Config struct {
TLS *TLSOptions `json:"tls,omitempty"` TLS *TLSOptions `json:"tls,omitempty"`
Password string `json:"password,omitempty"` Password string `json:"password,omitempty"`
Templates *templates.Templates `json:"templates,omitempty"` Templates *templates.Templates `json:"templates,omitempty"`
CommonName string `json:"commonName,omitempty"`
CRL *CRLConfig `json:"crl,omitempty"` CRL *CRLConfig `json:"crl,omitempty"`
} }
@ -95,6 +99,7 @@ type AuthConfig struct {
Admins []*linkedca.Admin `json:"-"` Admins []*linkedca.Admin `json:"-"`
Template *ASN1DN `json:"template,omitempty"` Template *ASN1DN `json:"template,omitempty"`
Claims *provisioner.Claims `json:"claims,omitempty"` Claims *provisioner.Claims `json:"claims,omitempty"`
Policy *policy.Options `json:"policy,omitempty"`
DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"` DisableIssuedAtCheck bool `json:"disableIssuedAtCheck,omitempty"`
Backdate *provisioner.Duration `json:"backdate,omitempty"` Backdate *provisioner.Duration `json:"backdate,omitempty"`
EnableAdmin bool `json:"enableAdmin,omitempty"` EnableAdmin bool `json:"enableAdmin,omitempty"`
@ -180,6 +185,9 @@ func (c *Config) Init() {
if c.AuthorityConfig == nil { if c.AuthorityConfig == nil {
c.AuthorityConfig = &AuthConfig{} c.AuthorityConfig = &AuthConfig{}
} }
if c.CommonName == "" {
c.CommonName = "Step Online CA"
}
c.AuthorityConfig.init() c.AuthorityConfig.init()
} }
@ -311,6 +319,18 @@ func (c *Config) GetAudiences() provisioner.Audiences {
return audiences return audiences
} }
// Audience returns the list of audiences for a given path.
func (c *Config) Audience(path string) []string {
audiences := make([]string, len(c.DNSNames)+1)
for i, name := range c.DNSNames {
hostname := toHostname(name)
audiences[i] = "https://" + hostname + path
}
// For backward compatibility
audiences[len(c.DNSNames)] = path
return audiences
}
func toHostname(name string) string { func toHostname(name string) string {
// ensure an IPv6 address is represented with square brackets when used as hostname // ensure an IPv6 address is represented with square brackets when used as hostname
if ip := net.ParseIP(name); ip != nil && ip.To4() == nil { if ip := net.ParseIP(name); ip != nil && ip.To4() == nil {

@ -2,6 +2,7 @@ package config
import ( import (
"fmt" "fmt"
"reflect"
"testing" "testing"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -317,3 +318,38 @@ func Test_toHostname(t *testing.T) {
}) })
} }
} }
func TestConfig_Audience(t *testing.T) {
type fields struct {
DNSNames []string
}
type args struct {
path string
}
tests := []struct {
name string
fields fields
args args
want []string
}{
{"ok", fields{[]string{
"ca", "ca.example.com", "127.0.0.1", "::1",
}}, args{"/path"}, []string{
"https://ca/path",
"https://ca.example.com/path",
"https://127.0.0.1/path",
"https://[::1]/path",
"/path",
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Config{
DNSNames: tt.fields.DNSNames,
}
if got := c.Audience(tt.args.path); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Config.Audience() = %v, want %v", got, tt.want)
}
})
}
}

@ -15,15 +15,19 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/db" "golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/keyutil" "go.step.sm/crypto/keyutil"
"go.step.sm/crypto/tlsutil" "go.step.sm/crypto/tlsutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"go.step.sm/linkedca" "go.step.sm/linkedca"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc" "github.com/smallstep/certificates/authority/admin"
"google.golang.org/grpc/credentials" "github.com/smallstep/certificates/authority/provisioner"
"github.com/smallstep/certificates/db"
) )
const uuidPattern = "^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$" const uuidPattern = "^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$"
@ -34,6 +38,9 @@ type linkedCaClient struct {
authorityID string authorityID string
} }
// interface guard
var _ admin.DB = (*linkedCaClient)(nil)
type linkedCAClaims struct { type linkedCAClaims struct {
jose.Claims jose.Claims
SANs []string `json:"sans"` SANs []string `json:"sans"`
@ -115,6 +122,13 @@ func newLinkedCAClient(token string) (*linkedCaClient, error) {
}, nil }, nil
} }
// IsLinkedCA is a sentinel function that can be used to
// check if a linkedCaClient is the underlying type of an
// admin.DB interface.
func (c *linkedCaClient) IsLinkedCA() bool {
return true
}
func (c *linkedCaClient) Run() { func (c *linkedCaClient) Run() {
c.renewer.Run() c.renewer.Run()
} }
@ -151,13 +165,21 @@ func (c *linkedCaClient) GetProvisioner(ctx context.Context, id string) (*linked
} }
func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) { func (c *linkedCaClient) GetProvisioners(ctx context.Context) ([]*linkedca.Provisioner, error) {
resp, err := c.GetConfiguration(ctx)
if err != nil {
return nil, err
}
return resp.Provisioners, nil
}
func (c *linkedCaClient) GetConfiguration(ctx context.Context) (*linkedca.ConfigurationResponse, error) {
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{ resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{
AuthorityId: c.authorityID, AuthorityId: c.authorityID,
}) })
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error getting provisioners") return nil, errors.Wrap(err, "error getting configuration")
} }
return resp.Provisioners, nil return resp, nil
} }
func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error { func (c *linkedCaClient) UpdateProvisioner(ctx context.Context, prov *linkedca.Provisioner) error {
@ -204,11 +226,9 @@ func (c *linkedCaClient) GetAdmin(ctx context.Context, id string) (*linkedca.Adm
} }
func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) { func (c *linkedCaClient) GetAdmins(ctx context.Context) ([]*linkedca.Admin, error) {
resp, err := c.client.GetConfiguration(ctx, &linkedca.ConfigurationRequest{ resp, err := c.GetConfiguration(ctx)
AuthorityId: c.authorityID,
})
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error getting admins") return nil, err
} }
return resp.Admins, nil return resp.Admins, nil
} }
@ -228,12 +248,35 @@ func (c *linkedCaClient) DeleteAdmin(ctx context.Context, id string) error {
return errors.Wrap(err, "error deleting admin") return errors.Wrap(err, "error deleting admin")
} }
func (c *linkedCaClient) StoreCertificateChain(fullchain ...*x509.Certificate) error { func (c *linkedCaClient) GetCertificateData(serial string) (*db.CertificateData, error) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
resp, err := c.client.GetCertificate(ctx, &linkedca.GetCertificateRequest{
Serial: serial,
})
if err != nil {
return nil, err
}
var pd *db.ProvisionerData
if p := resp.Provisioner; p != nil {
pd = &db.ProvisionerData{
ID: p.Id, Name: p.Name, Type: p.Type.String(),
}
}
return &db.CertificateData{
Provisioner: pd,
}, nil
}
func (c *linkedCaClient) StoreCertificateChain(p provisioner.Interface, fullchain ...*x509.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel() defer cancel()
_, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{ _, err := c.client.PostCertificate(ctx, &linkedca.CertificateRequest{
PemCertificate: serializeCertificateChain(fullchain[0]), PemCertificate: serializeCertificateChain(fullchain[0]),
PemCertificateChain: serializeCertificateChain(fullchain[1:]...), PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
Provisioner: createProvisionerIdentity(p),
}) })
return errors.Wrap(err, "error posting certificate") return errors.Wrap(err, "error posting certificate")
} }
@ -246,18 +289,30 @@ func (c *linkedCaClient) StoreRenewedCertificate(parent *x509.Certificate, fullc
PemCertificateChain: serializeCertificateChain(fullchain[1:]...), PemCertificateChain: serializeCertificateChain(fullchain[1:]...),
PemParentCertificate: serializeCertificateChain(parent), PemParentCertificate: serializeCertificateChain(parent),
}) })
return errors.Wrap(err, "error posting certificate") return errors.Wrap(err, "error posting renewed certificate")
} }
func (c *linkedCaClient) StoreSSHCertificate(crt *ssh.Certificate) error { func (c *linkedCaClient) StoreSSHCertificate(p provisioner.Interface, crt *ssh.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel() defer cancel()
_, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{ _, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{
Certificate: string(ssh.MarshalAuthorizedKey(crt)), Certificate: string(ssh.MarshalAuthorizedKey(crt)),
Provisioner: createProvisionerIdentity(p),
}) })
return errors.Wrap(err, "error posting ssh certificate") return errors.Wrap(err, "error posting ssh certificate")
} }
func (c *linkedCaClient) StoreRenewedSSHCertificate(p provisioner.Interface, parent, crt *ssh.Certificate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
_, err := c.client.PostSSHCertificate(ctx, &linkedca.SSHCertificateRequest{
Certificate: string(ssh.MarshalAuthorizedKey(crt)),
ParentCertificate: string(ssh.MarshalAuthorizedKey(parent)),
Provisioner: createProvisionerIdentity(p),
})
return errors.Wrap(err, "error posting renewed ssh certificate")
}
func (c *linkedCaClient) Revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error { func (c *linkedCaClient) Revoke(crt *x509.Certificate, rci *db.RevokedCertificateInfo) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel() defer cancel()
@ -310,6 +365,33 @@ func (c *linkedCaClient) IsSSHRevoked(serial string) (bool, error) {
return resp.Status != linkedca.RevocationStatus_ACTIVE, nil return resp.Status != linkedca.RevocationStatus_ACTIVE, nil
} }
func (c *linkedCaClient) CreateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
return errors.New("not implemented yet")
}
func (c *linkedCaClient) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
return nil, errors.New("not implemented yet")
}
func (c *linkedCaClient) UpdateAuthorityPolicy(ctx context.Context, policy *linkedca.Policy) error {
return errors.New("not implemented yet")
}
func (c *linkedCaClient) DeleteAuthorityPolicy(ctx context.Context) error {
return errors.New("not implemented yet")
}
func createProvisionerIdentity(p provisioner.Interface) *linkedca.ProvisionerIdentity {
if p == nil {
return nil
}
return &linkedca.ProvisionerIdentity{
Id: p.GetID(),
Type: linkedca.Provisioner_Type(p.GetType()),
Name: p.GetName(),
}
}
func serializeCertificate(crt *x509.Certificate) string { func serializeCertificate(crt *x509.Certificate) string {
if crt == nil { if crt == nil {
return "" return ""

@ -266,6 +266,16 @@ func WithAdminDB(d admin.DB) Option {
} }
} }
// WithProvisioners is an option to set the provisioner collection.
//
// Deprecated: provisioner collections will likely change
func WithProvisioners(ps *provisioner.Collection) Option {
return func(a *Authority) error {
a.provisioners = ps
return nil
}
}
// WithLinkedCAToken is an option to set the authentication token used to enable // WithLinkedCAToken is an option to set the authentication token used to enable
// linked ca. // linked ca.
func WithLinkedCAToken(token string) Option { func WithLinkedCAToken(token string) Option {
@ -284,6 +294,15 @@ func WithX509Enforcers(ces ...provisioner.CertificateEnforcer) Option {
} }
} }
// WithSkipInit is an option that allows the constructor to skip initializtion
// of the authority.
func WithSkipInit() Option {
return func(a *Authority) error {
a.skipInit = true
return nil
}
}
func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) { func readCertificateBundle(pemCerts []byte) ([]*x509.Certificate, error) {
var block *pem.Block var block *pem.Block
var certs []*x509.Certificate var certs []*x509.Certificate

@ -0,0 +1,265 @@
package authority
import (
"context"
"errors"
"fmt"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/authority/admin"
authPolicy "github.com/smallstep/certificates/authority/policy"
policy "github.com/smallstep/certificates/policy"
)
type policyErrorType int
const (
AdminLockOut policyErrorType = iota + 1
StoreFailure
ReloadFailure
ConfigurationFailure
EvaluationFailure
InternalFailure
)
type PolicyError struct {
Typ policyErrorType
Err error
}
func (p *PolicyError) Error() string {
return p.Err.Error()
}
func (a *Authority) GetAuthorityPolicy(ctx context.Context) (*linkedca.Policy, error) {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
p, err := a.adminDB.GetAuthorityPolicy(ctx)
if err != nil {
return nil, &PolicyError{
Typ: InternalFailure,
Err: err,
}
}
return p, nil
}
func (a *Authority) CreateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil {
return nil, err
}
if err := a.adminDB.CreateAuthorityPolicy(ctx, p); err != nil {
return nil, &PolicyError{
Typ: StoreFailure,
Err: err,
}
}
if err := a.reloadPolicyEngines(ctx); err != nil {
return nil, &PolicyError{
Typ: ReloadFailure,
Err: fmt.Errorf("error reloading policy engines when creating authority policy: %w", err),
}
}
return p, nil
}
func (a *Authority) UpdateAuthorityPolicy(ctx context.Context, adm *linkedca.Admin, p *linkedca.Policy) (*linkedca.Policy, error) {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
if err := a.checkAuthorityPolicy(ctx, adm, p); err != nil {
return nil, err
}
if err := a.adminDB.UpdateAuthorityPolicy(ctx, p); err != nil {
return nil, &PolicyError{
Typ: StoreFailure,
Err: err,
}
}
if err := a.reloadPolicyEngines(ctx); err != nil {
return nil, &PolicyError{
Typ: ReloadFailure,
Err: fmt.Errorf("error reloading policy engines when updating authority policy: %w", err),
}
}
return p, nil
}
func (a *Authority) RemoveAuthorityPolicy(ctx context.Context) error {
a.adminMutex.Lock()
defer a.adminMutex.Unlock()
if err := a.adminDB.DeleteAuthorityPolicy(ctx); err != nil {
return &PolicyError{
Typ: StoreFailure,
Err: err,
}
}
if err := a.reloadPolicyEngines(ctx); err != nil {
return &PolicyError{
Typ: ReloadFailure,
Err: fmt.Errorf("error reloading policy engines when deleting authority policy: %w", err),
}
}
return nil
}
func (a *Authority) checkAuthorityPolicy(ctx context.Context, currentAdmin *linkedca.Admin, p *linkedca.Policy) error {
// no policy and thus nothing to evaluate; return early
if p == nil {
return nil
}
// get all current admins from the database
allAdmins, err := a.adminDB.GetAdmins(ctx)
if err != nil {
return &PolicyError{
Typ: InternalFailure,
Err: fmt.Errorf("error retrieving admins: %w", err),
}
}
return a.checkPolicy(ctx, currentAdmin, allAdmins, p)
}
func (a *Authority) checkProvisionerPolicy(ctx context.Context, provName string, p *linkedca.Policy) error {
// no policy and thus nothing to evaluate; return early
if p == nil {
return nil
}
// get all admins for the provisioner; ignoring case in which they're not found
allProvisionerAdmins, _ := a.admins.LoadByProvisioner(provName)
// check the policy; pass in nil as the current admin, as all admins for the
// provisioner will be checked by looping through allProvisionerAdmins. Also,
// the current admin may be a super admin not belonging to the provisioner, so
// can't be blocked, but is not required to be in the policy, either.
return a.checkPolicy(ctx, nil, allProvisionerAdmins, p)
}
// checkPolicy checks if a new or updated policy configuration results in the user
// locking themselves or other admins out of the CA.
func (a *Authority) checkPolicy(ctx context.Context, currentAdmin *linkedca.Admin, otherAdmins []*linkedca.Admin, p *linkedca.Policy) error {
// convert the policy; return early if nil
policyOptions := authPolicy.LinkedToCertificates(p)
if policyOptions == nil {
return nil
}
engine, err := authPolicy.NewX509PolicyEngine(policyOptions.GetX509Options())
if err != nil {
return &PolicyError{
Typ: ConfigurationFailure,
Err: err,
}
}
// when an empty X.509 policy is provided, the resulting engine is nil
// and there's no policy to evaluate.
if engine == nil {
return nil
}
// TODO(hs): Provide option to force the policy, even when the admin subject would be locked out?
// check if the admin user that instructed the authority policy to be
// created or updated, would still be allowed when the provided policy
// would be applied. This case is skipped when current admin is nil, which
// is the case when a provisioner policy is checked.
if currentAdmin != nil {
sans := []string{currentAdmin.GetSubject()}
if err := isAllowed(engine, sans); err != nil {
return err
}
}
// loop through admins to verify that none of them would be
// locked out when the new policy were to be applied. Returns
// an error with a message that includes the admin subject that
// would be locked out.
for _, adm := range otherAdmins {
sans := []string{adm.GetSubject()}
if err := isAllowed(engine, sans); err != nil {
return err
}
}
// TODO(hs): mask the error message for non-super admins?
return nil
}
// reloadPolicyEngines reloads x509 and SSH policy engines using
// configuration stored in the DB or from the configuration file.
func (a *Authority) reloadPolicyEngines(ctx context.Context) error {
var (
err error
policyOptions *authPolicy.Options
)
if a.config.AuthorityConfig.EnableAdmin {
// temporarily disable policy loading when LinkedCA is in use
if _, ok := a.adminDB.(*linkedCaClient); ok {
return nil
}
linkedPolicy, err := a.adminDB.GetAuthorityPolicy(ctx)
if err != nil {
var ae *admin.Error
if isAdminError := errors.As(err, &ae); (isAdminError && ae.Type != admin.ErrorNotFoundType.String()) || !isAdminError {
return fmt.Errorf("error getting policy to (re)load policy engines: %w", err)
}
}
policyOptions = authPolicy.LinkedToCertificates(linkedPolicy)
} else {
policyOptions = a.config.AuthorityConfig.Policy
}
engine, err := authPolicy.New(policyOptions)
if err != nil {
return err
}
// only update the policy engine when no error was returned
a.policyEngine = engine
return nil
}
func isAllowed(engine authPolicy.X509Policy, sans []string) error {
if err := engine.AreSANsAllowed(sans); err != nil {
var policyErr *policy.NamePolicyError
isNamePolicyError := errors.As(err, &policyErr)
if isNamePolicyError && policyErr.Reason == policy.NotAllowed {
return &PolicyError{
Typ: AdminLockOut,
Err: fmt.Errorf("the provided policy would lock out %s from the CA. Please update your policy to include %s as an allowed name", sans, sans),
}
}
return &PolicyError{
Typ: EvaluationFailure,
Err: err,
}
}
return nil
}

@ -0,0 +1,114 @@
package policy
import (
"crypto/x509"
"errors"
"fmt"
"golang.org/x/crypto/ssh"
)
// Engine is a container for multiple policies.
type Engine struct {
x509Policy X509Policy
sshUserPolicy UserPolicy
sshHostPolicy HostPolicy
}
// New returns a new Engine using Options.
func New(options *Options) (*Engine, error) {
// if no options provided, return early
if options == nil {
return nil, nil
}
var (
x509Policy X509Policy
sshHostPolicy HostPolicy
sshUserPolicy UserPolicy
err error
)
// initialize the x509 allow/deny policy engine
if x509Policy, err = NewX509PolicyEngine(options.GetX509Options()); err != nil {
return nil, err
}
// initialize the SSH allow/deny policy engine for host certificates
if sshHostPolicy, err = NewSSHHostPolicyEngine(options.GetSSHOptions()); err != nil {
return nil, err
}
// initialize the SSH allow/deny policy engine for user certificates
if sshUserPolicy, err = NewSSHUserPolicyEngine(options.GetSSHOptions()); err != nil {
return nil, err
}
return &Engine{
x509Policy: x509Policy,
sshHostPolicy: sshHostPolicy,
sshUserPolicy: sshUserPolicy,
}, nil
}
// IsX509CertificateAllowed evaluates an X.509 certificate against
// the X.509 policy (if available) and returns an error if one of the
// names in the certificate is not allowed.
func (e *Engine) IsX509CertificateAllowed(cert *x509.Certificate) error {
// return early if there's no policy to evaluate
if e == nil || e.x509Policy == nil {
return nil
}
// return result of X.509 policy evaluation
return e.x509Policy.IsX509CertificateAllowed(cert)
}
// AreSANsAllowed evaluates the slice of SANs against the X.509 policy
// (if available) and returns an error if one of the SANs is not allowed.
func (e *Engine) AreSANsAllowed(sans []string) error {
// return early if there's no policy to evaluate
if e == nil || e.x509Policy == nil {
return nil
}
// return result of X.509 policy evaluation
return e.x509Policy.AreSANsAllowed(sans)
}
// IsSSHCertificateAllowed evaluates an SSH certificate against the
// user or host policy (if configured) and returns an error if one of the
// principals in the certificate is not allowed.
func (e *Engine) IsSSHCertificateAllowed(cert *ssh.Certificate) error {
// return early if there's no policy to evaluate
if e == nil || (e.sshHostPolicy == nil && e.sshUserPolicy == nil) {
return nil
}
switch cert.CertType {
case ssh.HostCert:
// when no host policy engine is configured, but a user policy engine is
// configured, the host certificate is denied.
if e.sshHostPolicy == nil && e.sshUserPolicy != nil {
return errors.New("authority not allowed to sign ssh host certificates")
}
// return result of SSH host policy evaluation
return e.sshHostPolicy.IsSSHCertificateAllowed(cert)
case ssh.UserCert:
// when no user policy engine is configured, but a host policy engine is
// configured, the user certificate is denied.
if e.sshUserPolicy == nil && e.sshHostPolicy != nil {
return errors.New("authority not allowed to sign ssh user certificates")
}
// return result of SSH user policy evaluation
return e.sshUserPolicy.IsSSHCertificateAllowed(cert)
default:
return fmt.Errorf("unexpected ssh certificate type %q", cert.CertType)
}
}

@ -0,0 +1,194 @@
package policy
// Options is a container for authority level x509 and SSH
// policy configuration.
type Options struct {
X509 *X509PolicyOptions `json:"x509,omitempty"`
SSH *SSHPolicyOptions `json:"ssh,omitempty"`
}
// GetX509Options returns the x509 authority level policy
// configuration
func (o *Options) GetX509Options() *X509PolicyOptions {
if o == nil {
return nil
}
return o.X509
}
// GetSSHOptions returns the SSH authority level policy
// configuration
func (o *Options) GetSSHOptions() *SSHPolicyOptions {
if o == nil {
return nil
}
return o.SSH
}
// X509PolicyOptionsInterface is an interface for providers
// of x509 allowed and denied names.
type X509PolicyOptionsInterface interface {
GetAllowedNameOptions() *X509NameOptions
GetDeniedNameOptions() *X509NameOptions
AreWildcardNamesAllowed() bool
}
// X509PolicyOptions is a container for x509 allowed and denied
// names.
type X509PolicyOptions struct {
// AllowedNames contains the x509 allowed names
AllowedNames *X509NameOptions `json:"allow,omitempty"`
// DeniedNames contains the x509 denied names
DeniedNames *X509NameOptions `json:"deny,omitempty"`
// AllowWildcardNames indicates if literal wildcard names
// like *.example.com are allowed. Defaults to false.
AllowWildcardNames bool `json:"allowWildcardNames,omitempty"`
}
// X509NameOptions models the X509 name policy configuration.
type X509NameOptions struct {
CommonNames []string `json:"cn,omitempty"`
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
EmailAddresses []string `json:"email,omitempty"`
URIDomains []string `json:"uri,omitempty"`
}
// HasNames checks if the AllowedNameOptions has one or more
// names configured.
func (o *X509NameOptions) HasNames() bool {
return len(o.CommonNames) > 0 ||
len(o.DNSDomains) > 0 ||
len(o.IPRanges) > 0 ||
len(o.EmailAddresses) > 0 ||
len(o.URIDomains) > 0
}
// GetAllowedNameOptions returns x509 allowed name policy configuration
func (o *X509PolicyOptions) GetAllowedNameOptions() *X509NameOptions {
if o == nil {
return nil
}
return o.AllowedNames
}
// GetDeniedNameOptions returns the x509 denied name policy configuration
func (o *X509PolicyOptions) GetDeniedNameOptions() *X509NameOptions {
if o == nil {
return nil
}
return o.DeniedNames
}
// AreWildcardNamesAllowed returns whether the authority allows
// literal wildcard names to be signed.
func (o *X509PolicyOptions) AreWildcardNamesAllowed() bool {
if o == nil {
return true
}
return o.AllowWildcardNames
}
// SSHPolicyOptionsInterface is an interface for providers of
// SSH user and host name policy configuration.
type SSHPolicyOptionsInterface interface {
GetAllowedUserNameOptions() *SSHNameOptions
GetDeniedUserNameOptions() *SSHNameOptions
GetAllowedHostNameOptions() *SSHNameOptions
GetDeniedHostNameOptions() *SSHNameOptions
}
// SSHPolicyOptions is a container for SSH user and host policy
// configuration
type SSHPolicyOptions struct {
// User contains SSH user certificate options.
User *SSHUserCertificateOptions `json:"user,omitempty"`
// Host contains SSH host certificate options.
Host *SSHHostCertificateOptions `json:"host,omitempty"`
}
// GetAllowedUserNameOptions returns the SSH allowed user name policy
// configuration.
func (o *SSHPolicyOptions) GetAllowedUserNameOptions() *SSHNameOptions {
if o == nil || o.User == nil {
return nil
}
return o.User.AllowedNames
}
// GetDeniedUserNameOptions returns the SSH denied user name policy
// configuration.
func (o *SSHPolicyOptions) GetDeniedUserNameOptions() *SSHNameOptions {
if o == nil || o.User == nil {
return nil
}
return o.User.DeniedNames
}
// GetAllowedHostNameOptions returns the SSH allowed host name policy
// configuration.
func (o *SSHPolicyOptions) GetAllowedHostNameOptions() *SSHNameOptions {
if o == nil || o.Host == nil {
return nil
}
return o.Host.AllowedNames
}
// GetDeniedHostNameOptions returns the SSH denied host name policy
// configuration.
func (o *SSHPolicyOptions) GetDeniedHostNameOptions() *SSHNameOptions {
if o == nil || o.Host == nil {
return nil
}
return o.Host.DeniedNames
}
// SSHUserCertificateOptions is a collection of SSH user certificate options.
type SSHUserCertificateOptions struct {
// AllowedNames contains the names the provisioner is authorized to sign
AllowedNames *SSHNameOptions `json:"allow,omitempty"`
// DeniedNames contains the names the provisioner is not authorized to sign
DeniedNames *SSHNameOptions `json:"deny,omitempty"`
}
// SSHHostCertificateOptions is a collection of SSH host certificate options.
// It's an alias of SSHUserCertificateOptions, as the options are the same
// for both types of certificates.
type SSHHostCertificateOptions SSHUserCertificateOptions
// SSHNameOptions models the SSH name policy configuration.
type SSHNameOptions struct {
DNSDomains []string `json:"dns,omitempty"`
IPRanges []string `json:"ip,omitempty"`
EmailAddresses []string `json:"email,omitempty"`
Principals []string `json:"principal,omitempty"`
}
// GetAllowedNameOptions returns the AllowedSSHNameOptions, which models the
// names that a provisioner is authorized to sign SSH certificates for.
func (o *SSHUserCertificateOptions) GetAllowedNameOptions() *SSHNameOptions {
if o == nil {
return nil
}
return o.AllowedNames
}
// GetDeniedNameOptions returns the DeniedSSHNameOptions, which models the
// names that a provisioner is NOT authorized to sign SSH certificates for.
func (o *SSHUserCertificateOptions) GetDeniedNameOptions() *SSHNameOptions {
if o == nil {
return nil
}
return o.DeniedNames
}
// HasNames checks if the SSHNameOptions has one or more
// names configured.
func (o *SSHNameOptions) HasNames() bool {
return len(o.DNSDomains) > 0 ||
len(o.IPRanges) > 0 ||
len(o.EmailAddresses) > 0 ||
len(o.Principals) > 0
}

@ -0,0 +1,45 @@
package policy
import (
"testing"
)
func TestX509PolicyOptions_IsWildcardLiteralAllowed(t *testing.T) {
tests := []struct {
name string
options *X509PolicyOptions
want bool
}{
{
name: "nil-options",
options: nil,
want: true,
},
{
name: "not-set",
options: &X509PolicyOptions{},
want: false,
},
{
name: "set-true",
options: &X509PolicyOptions{
AllowWildcardNames: true,
},
want: true,
},
{
name: "set-false",
options: &X509PolicyOptions{
AllowWildcardNames: false,
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.options.AreWildcardNamesAllowed(); got != tt.want {
t.Errorf("X509PolicyOptions.IsWildcardLiteralAllowed() = %v, want %v", got, tt.want)
}
})
}
}

@ -0,0 +1,256 @@
package policy
import (
"fmt"
"go.step.sm/linkedca"
"github.com/smallstep/certificates/policy"
)
// X509Policy is an alias for policy.X509NamePolicyEngine
type X509Policy policy.X509NamePolicyEngine
// UserPolicy is an alias for policy.SSHNamePolicyEngine
type UserPolicy policy.SSHNamePolicyEngine
// HostPolicy is an alias for policy.SSHNamePolicyEngine
type HostPolicy policy.SSHNamePolicyEngine
// NewX509PolicyEngine creates a new x509 name policy engine
func NewX509PolicyEngine(policyOptions X509PolicyOptionsInterface) (X509Policy, error) {
// return early if no policy engine options to configure
if policyOptions == nil {
return nil, nil
}
options := []policy.NamePolicyOption{}
allowed := policyOptions.GetAllowedNameOptions()
if allowed != nil && allowed.HasNames() {
options = append(options,
policy.WithPermittedCommonNames(allowed.CommonNames...),
policy.WithPermittedDNSDomains(allowed.DNSDomains...),
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges...),
policy.WithPermittedEmailAddresses(allowed.EmailAddresses...),
policy.WithPermittedURIDomains(allowed.URIDomains...),
)
}
denied := policyOptions.GetDeniedNameOptions()
if denied != nil && denied.HasNames() {
options = append(options,
policy.WithExcludedCommonNames(denied.CommonNames...),
policy.WithExcludedDNSDomains(denied.DNSDomains...),
policy.WithExcludedIPsOrCIDRs(denied.IPRanges...),
policy.WithExcludedEmailAddresses(denied.EmailAddresses...),
policy.WithExcludedURIDomains(denied.URIDomains...),
)
}
// ensure no policy engine is returned when no name options were provided
if len(options) == 0 {
return nil, nil
}
// check if configuration specifies that wildcard names are allowed
if policyOptions.AreWildcardNamesAllowed() {
options = append(options, policy.WithAllowLiteralWildcardNames())
}
// enable subject common name verification by default
options = append(options, policy.WithSubjectCommonNameVerification())
return policy.New(options...)
}
type sshPolicyEngineType string
const (
UserPolicyEngineType sshPolicyEngineType = "user"
HostPolicyEngineType sshPolicyEngineType = "host"
)
// newSSHUserPolicyEngine creates a new SSH user certificate policy engine
func NewSSHUserPolicyEngine(policyOptions SSHPolicyOptionsInterface) (UserPolicy, error) {
policyEngine, err := newSSHPolicyEngine(policyOptions, UserPolicyEngineType)
if err != nil {
return nil, err
}
return policyEngine, nil
}
// newSSHHostPolicyEngine create a new SSH host certificate policy engine
func NewSSHHostPolicyEngine(policyOptions SSHPolicyOptionsInterface) (HostPolicy, error) {
policyEngine, err := newSSHPolicyEngine(policyOptions, HostPolicyEngineType)
if err != nil {
return nil, err
}
return policyEngine, nil
}
// newSSHPolicyEngine creates a new SSH name policy engine
func newSSHPolicyEngine(policyOptions SSHPolicyOptionsInterface, typ sshPolicyEngineType) (policy.SSHNamePolicyEngine, error) {
// return early if no policy engine options to configure
if policyOptions == nil {
return nil, nil
}
var (
allowed *SSHNameOptions
denied *SSHNameOptions
)
switch typ {
case UserPolicyEngineType:
allowed = policyOptions.GetAllowedUserNameOptions()
denied = policyOptions.GetDeniedUserNameOptions()
case HostPolicyEngineType:
allowed = policyOptions.GetAllowedHostNameOptions()
denied = policyOptions.GetDeniedHostNameOptions()
default:
return nil, fmt.Errorf("unknown SSH policy engine type %s provided", typ)
}
options := []policy.NamePolicyOption{}
if allowed != nil && allowed.HasNames() {
options = append(options,
policy.WithPermittedDNSDomains(allowed.DNSDomains...),
policy.WithPermittedIPsOrCIDRs(allowed.IPRanges...),
policy.WithPermittedEmailAddresses(allowed.EmailAddresses...),
policy.WithPermittedPrincipals(allowed.Principals...),
)
}
if denied != nil && denied.HasNames() {
options = append(options,
policy.WithExcludedDNSDomains(denied.DNSDomains...),
policy.WithExcludedIPsOrCIDRs(denied.IPRanges...),
policy.WithExcludedEmailAddresses(denied.EmailAddresses...),
policy.WithExcludedPrincipals(denied.Principals...),
)
}
// ensure no policy engine is returned when no name options were provided
if len(options) == 0 {
return nil, nil
}
return policy.New(options...)
}
func LinkedToCertificates(p *linkedca.Policy) *Options {
// return early
if p == nil {
return nil
}
// return early if x509 nor SSH is set
if p.GetX509() == nil && p.GetSsh() == nil {
return nil
}
opts := &Options{}
// fill x509 policy configuration
if x509 := p.GetX509(); x509 != nil {
opts.X509 = &X509PolicyOptions{}
if allow := x509.GetAllow(); allow != nil {
opts.X509.AllowedNames = &X509NameOptions{}
if allow.Dns != nil {
opts.X509.AllowedNames.DNSDomains = allow.Dns
}
if allow.Ips != nil {
opts.X509.AllowedNames.IPRanges = allow.Ips
}
if allow.Emails != nil {
opts.X509.AllowedNames.EmailAddresses = allow.Emails
}
if allow.Uris != nil {
opts.X509.AllowedNames.URIDomains = allow.Uris
}
if allow.CommonNames != nil {
opts.X509.AllowedNames.CommonNames = allow.CommonNames
}
}
if deny := x509.GetDeny(); deny != nil {
opts.X509.DeniedNames = &X509NameOptions{}
if deny.Dns != nil {
opts.X509.DeniedNames.DNSDomains = deny.Dns
}
if deny.Ips != nil {
opts.X509.DeniedNames.IPRanges = deny.Ips
}
if deny.Emails != nil {
opts.X509.DeniedNames.EmailAddresses = deny.Emails
}
if deny.Uris != nil {
opts.X509.DeniedNames.URIDomains = deny.Uris
}
if deny.CommonNames != nil {
opts.X509.DeniedNames.CommonNames = deny.CommonNames
}
}
opts.X509.AllowWildcardNames = x509.GetAllowWildcardNames()
}
// fill ssh policy configuration
if ssh := p.GetSsh(); ssh != nil {
opts.SSH = &SSHPolicyOptions{}
if host := ssh.GetHost(); host != nil {
opts.SSH.Host = &SSHHostCertificateOptions{}
if allow := host.GetAllow(); allow != nil {
opts.SSH.Host.AllowedNames = &SSHNameOptions{}
if allow.Dns != nil {
opts.SSH.Host.AllowedNames.DNSDomains = allow.Dns
}
if allow.Ips != nil {
opts.SSH.Host.AllowedNames.IPRanges = allow.Ips
}
if allow.Principals != nil {
opts.SSH.Host.AllowedNames.Principals = allow.Principals
}
}
if deny := host.GetDeny(); deny != nil {
opts.SSH.Host.DeniedNames = &SSHNameOptions{}
if deny.Dns != nil {
opts.SSH.Host.DeniedNames.DNSDomains = deny.Dns
}
if deny.Ips != nil {
opts.SSH.Host.DeniedNames.IPRanges = deny.Ips
}
if deny.Principals != nil {
opts.SSH.Host.DeniedNames.Principals = deny.Principals
}
}
}
if user := ssh.GetUser(); user != nil {
opts.SSH.User = &SSHUserCertificateOptions{}
if allow := user.GetAllow(); allow != nil {
opts.SSH.User.AllowedNames = &SSHNameOptions{}
if allow.Emails != nil {
opts.SSH.User.AllowedNames.EmailAddresses = allow.Emails
}
if allow.Principals != nil {
opts.SSH.User.AllowedNames.Principals = allow.Principals
}
}
if deny := user.GetDeny(); deny != nil {
opts.SSH.User.DeniedNames = &SSHNameOptions{}
if deny.Emails != nil {
opts.SSH.User.DeniedNames.EmailAddresses = deny.Emails
}
if deny.Principals != nil {
opts.SSH.User.DeniedNames.Principals = deny.Principals
}
}
}
}
return opts
}

@ -0,0 +1,155 @@
package policy
import (
"testing"
"github.com/google/go-cmp/cmp"
"go.step.sm/linkedca"
)
func TestPolicyToCertificates(t *testing.T) {
type args struct {
policy *linkedca.Policy
}
tests := []struct {
name string
args args
want *Options
}{
{
name: "nil",
args: args{
policy: nil,
},
want: nil,
},
{
name: "no-policy",
args: args{
&linkedca.Policy{},
},
want: nil,
},
{
name: "partial-policy",
args: args{
&linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"*.local"},
},
AllowWildcardNames: false,
},
},
},
want: &Options{
X509: &X509PolicyOptions{
AllowedNames: &X509NameOptions{
DNSDomains: []string{"*.local"},
},
AllowWildcardNames: false,
},
},
},
{
name: "full-policy",
args: args{
&linkedca.Policy{
X509: &linkedca.X509Policy{
Allow: &linkedca.X509Names{
Dns: []string{"step"},
Ips: []string{"127.0.0.1/24"},
Emails: []string{"*.example.com"},
Uris: []string{"https://*.local"},
CommonNames: []string{"some name"},
},
Deny: &linkedca.X509Names{
Dns: []string{"bad"},
Ips: []string{"127.0.0.30"},
Emails: []string{"badhost.example.com"},
Uris: []string{"https://badhost.local"},
CommonNames: []string{"another name"},
},
AllowWildcardNames: true,
},
Ssh: &linkedca.SSHPolicy{
Host: &linkedca.SSHHostPolicy{
Allow: &linkedca.SSHHostNames{
Dns: []string{"*.localhost"},
Ips: []string{"127.0.0.1/24"},
Principals: []string{"user"},
},
Deny: &linkedca.SSHHostNames{
Dns: []string{"badhost.localhost"},
Ips: []string{"127.0.0.40"},
Principals: []string{"root"},
},
},
User: &linkedca.SSHUserPolicy{
Allow: &linkedca.SSHUserNames{
Emails: []string{"@work"},
Principals: []string{"user"},
},
Deny: &linkedca.SSHUserNames{
Emails: []string{"root@work"},
Principals: []string{"root"},
},
},
},
},
},
want: &Options{
X509: &X509PolicyOptions{
AllowedNames: &X509NameOptions{
DNSDomains: []string{"step"},
IPRanges: []string{"127.0.0.1/24"},
EmailAddresses: []string{"*.example.com"},
URIDomains: []string{"https://*.local"},
CommonNames: []string{"some name"},
},
DeniedNames: &X509NameOptions{
DNSDomains: []string{"bad"},
IPRanges: []string{"127.0.0.30"},
EmailAddresses: []string{"badhost.example.com"},
URIDomains: []string{"https://badhost.local"},
CommonNames: []string{"another name"},
},
AllowWildcardNames: true,
},
SSH: &SSHPolicyOptions{
Host: &SSHHostCertificateOptions{
AllowedNames: &SSHNameOptions{
DNSDomains: []string{"*.localhost"},
IPRanges: []string{"127.0.0.1/24"},
Principals: []string{"user"},
},
DeniedNames: &SSHNameOptions{
DNSDomains: []string{"badhost.localhost"},
IPRanges: []string{"127.0.0.40"},
Principals: []string{"root"},
},
},
User: &SSHUserCertificateOptions{
AllowedNames: &SSHNameOptions{
EmailAddresses: []string{"@work"},
Principals: []string{"user"},
},
DeniedNames: &SSHNameOptions{
EmailAddresses: []string{"root@work"},
Principals: []string{"root"},
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := LinkedToCertificates(tt.args.policy)
if !cmp.Equal(tt.want, got) {
t.Errorf("policyToCertificates() diff=\n%s", cmp.Diff(tt.want, got))
}
})
}
}

File diff suppressed because it is too large Load Diff

@ -3,6 +3,8 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509" "crypto/x509"
"fmt"
"net"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -23,7 +25,8 @@ type ACME struct {
RequireEAB bool `json:"requireEAB,omitempty"` RequireEAB bool `json:"requireEAB,omitempty"`
Claims *Claims `json:"claims,omitempty"` Claims *Claims `json:"claims,omitempty"`
Options *Options `json:"options,omitempty"` Options *Options `json:"options,omitempty"`
ctl *Controller
ctl *Controller
} }
// GetID returns the provisioner unique identifier. // GetID returns the provisioner unique identifier.
@ -71,7 +74,7 @@ func (p *ACME) DefaultTLSCertDuration() time.Duration {
return p.ctl.Claimer.DefaultTLSCertDuration() return p.ctl.Claimer.DefaultTLSCertDuration()
} }
// Init initializes and validates the fields of a JWK type. // Init initializes and validates the fields of an ACME type.
func (p *ACME) Init(config Config) (err error) { func (p *ACME) Init(config Config) (err error) {
switch { switch {
case p.Type == "": case p.Type == "":
@ -80,15 +83,57 @@ func (p *ACME) Init(config Config) (err error) {
return errors.New("provisioner name cannot be empty") return errors.New("provisioner name cannot be empty")
} }
p.ctl, err = NewController(p, p.Claims, config) p.ctl, err = NewController(p, p.Claims, config, p.Options)
return return
} }
// ACMEIdentifierType encodes ACME Identifier types
type ACMEIdentifierType string
const (
// IP is the ACME ip identifier type
IP ACMEIdentifierType = "ip"
// DNS is the ACME dns identifier type
DNS ACMEIdentifierType = "dns"
)
// ACMEIdentifier encodes ACME Order Identifiers
type ACMEIdentifier struct {
Type ACMEIdentifierType
Value string
}
// AuthorizeOrderIdentifier verifies the provisioner is allowed to issue a
// certificate for an ACME Order Identifier.
func (p *ACME) AuthorizeOrderIdentifier(ctx context.Context, identifier ACMEIdentifier) error {
x509Policy := p.ctl.getPolicy().getX509()
// identifier is allowed if no policy is configured
if x509Policy == nil {
return nil
}
// assuming only valid identifiers (IP or DNS) are provided
var err error
switch identifier.Type {
case IP:
err = x509Policy.IsIPAllowed(net.ParseIP(identifier.Value))
case DNS:
err = x509Policy.IsDNSAllowed(identifier.Value)
default:
err = fmt.Errorf("invalid ACME identifier type '%s' provided", identifier.Type)
}
return err
}
// AuthorizeSign does not do any validation, because all validation is handled // AuthorizeSign does not do any validation, because all validation is handled
// in the ACME protocol. This method returns a list of modifiers / constraints // in the ACME protocol. This method returns a list of modifiers / constraints
// on the resulting certificate. // on the resulting certificate.
func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{ opts := []SignOption{
p,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeACME, p.Name, ""), newProvisionerExtensionOption(TypeACME, p.Name, ""),
newForceCNOption(p.ForceCN), newForceCNOption(p.ForceCN),
@ -96,7 +141,10 @@ func (p *ACME) AuthorizeSign(ctx context.Context, token string) ([]SignOption, e
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
}, nil newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
}
return opts, nil
} }
// AuthorizeRevoke is called just before the certificate is to be revoked by // AuthorizeRevoke is called just before the certificate is to be revoked by

@ -176,9 +176,10 @@ func TestACME_AuthorizeSign(t *testing.T) {
} }
} else { } else {
if assert.Nil(t, tc.err) && assert.NotNil(t, opts) { if assert.Nil(t, tc.err) && assert.NotNil(t, opts) {
assert.Len(t, 5, opts) assert.Equals(t, 7, len(opts)) // number of SignOptions returned
for _, o := range opts { for _, o := range opts {
switch v := o.(type) { switch v := o.(type) {
case *ACME:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, TypeACME) assert.Equals(t, v.Type, TypeACME)
assert.Equals(t, v.Name, tc.p.GetName()) assert.Equals(t, v.Name, tc.p.GetName())
@ -192,6 +193,8 @@ func TestACME_AuthorizeSign(t *testing.T) {
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default: default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }

@ -17,10 +17,12 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
) )
// awsIssuer is the string used as issuer in the generated tokens. // awsIssuer is the string used as issuer in the generated tokens.
@ -49,22 +51,27 @@ const awsMetadataTokenTTLHeader = "X-aws-ec2-metadata-token-ttl-seconds"
// signature. // signature.
// //
// The first certificate is used in: // The first certificate is used in:
// ap-northeast-2, ap-south-1, ap-southeast-1, ap-southeast-2 //
// eu-central-1, eu-north-1, eu-west-1, eu-west-2, eu-west-3 // ap-northeast-2, ap-south-1, ap-southeast-1, ap-southeast-2
// us-east-1, us-east-2, us-west-1, us-west-2 // eu-central-1, eu-north-1, eu-west-1, eu-west-2, eu-west-3
// ca-central-1, sa-east-1 // us-east-1, us-east-2, us-west-1, us-west-2
// ca-central-1, sa-east-1
// //
// The second certificate is used in: // The second certificate is used in:
// eu-south-1 //
// eu-south-1
// //
// The third certificate is used in: // The third certificate is used in:
// ap-east-1 //
// ap-east-1
// //
// The fourth certificate is used in: // The fourth certificate is used in:
// af-south-1 //
// af-south-1
// //
// The fifth certificate is used in: // The fifth certificate is used in:
// me-south-1 //
// me-south-1
const awsCertificate = `-----BEGIN CERTIFICATE----- const awsCertificate = `-----BEGIN CERTIFICATE-----
MIIDIjCCAougAwIBAgIJAKnL4UEDMN/FMA0GCSqGSIb3DQEBBQUAMGoxCzAJBgNV MIIDIjCCAougAwIBAgIJAKnL4UEDMN/FMA0GCSqGSIb3DQEBBQUAMGoxCzAJBgNV
BAYTAlVTMRMwEQYDVQQIEwpXYXNoaW5ndG9uMRAwDgYDVQQHEwdTZWF0dGxlMRgw BAYTAlVTMRMwEQYDVQQIEwpXYXNoaW5ndG9uMRAwDgYDVQQHEwdTZWF0dGxlMRgw
@ -421,7 +428,7 @@ func (p *AWS) Init(config Config) (err error) {
} }
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config) p.ctl, err = NewController(p, p.Claims, config, p.Options)
return return
} }
@ -467,6 +474,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
} }
return append(so, return append(so,
p,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID), newProvisionerExtensionOption(TypeAWS, p.Name, doc.AccountID, "InstanceID", doc.InstanceID),
@ -475,6 +483,7 @@ func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
commonNameValidator(payload.Claims.Subject), commonNameValidator(payload.Claims.Subject),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
), nil ), nil
} }
@ -542,7 +551,7 @@ func (p *AWS) readURL(url string) ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return nil, fmt.Errorf("Request for metadata returned non-successful status code %d", return nil, fmt.Errorf("request for metadata returned non-successful status code %d",
resp.StatusCode) resp.StatusCode)
} }
@ -575,7 +584,7 @@ func (p *AWS) readURLv2(url string) (*http.Response, error) {
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return nil, fmt.Errorf("Request for API token returned non-successful status code %d", resp.StatusCode) return nil, fmt.Errorf("request for API token returned non-successful status code %d", resp.StatusCode)
} }
token, err := io.ReadAll(resp.Body) token, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
@ -743,6 +752,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
signOptions = append(signOptions, templateOptions) signOptions = append(signOptions, templateOptions)
return append(signOptions, return append(signOptions,
p,
// Validate user SignSSHOptions. // Validate user SignSSHOptions.
sshCertOptionsValidator(defaults), sshCertOptionsValidator(defaults),
// Set the validity bounds if not set. // Set the validity bounds if not set.
@ -753,5 +763,7 @@ func (p *AWS) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
&sshCertValidityValidator{p.ctl.Claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
), nil ), nil
} }

@ -642,11 +642,11 @@ func TestAWS_AuthorizeSign(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{t1, "foo.local"}, 6, http.StatusOK, false}, {"ok", p1, args{t1, "foo.local"}, 8, http.StatusOK, false},
{"ok", p2, args{t2, "instance-id"}, 10, http.StatusOK, false}, {"ok", p2, args{t2, "instance-id"}, 12, http.StatusOK, false},
{"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 10, http.StatusOK, false}, {"ok", p2, args{t2Hostname, "ip-127-0-0-1.us-west-1.compute.internal"}, 12, http.StatusOK, false},
{"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 10, http.StatusOK, false}, {"ok", p2, args{t2PrivateIP, "127.0.0.1"}, 12, http.StatusOK, false},
{"ok", p1, args{t4, "instance-id"}, 6, http.StatusOK, false}, {"ok", p1, args{t4, "instance-id"}, 8, http.StatusOK, false},
{"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true}, {"fail account", p3, args{token: t3}, 0, http.StatusUnauthorized, true},
{"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true}, {"fail token", p1, args{token: "token"}, 0, http.StatusUnauthorized, true},
{"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true}, {"fail subject", p1, args{token: failSubject}, 0, http.StatusUnauthorized, true},
@ -673,9 +673,10 @@ func TestAWS_AuthorizeSign(t *testing.T) {
assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
default: default:
assert.Len(t, tt.wantLen, got) assert.Equals(t, tt.wantLen, len(got))
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {
case *AWS:
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, TypeAWS) assert.Equals(t, v.Type, TypeAWS)
@ -698,6 +699,8 @@ func TestAWS_AuthorizeSign(t *testing.T) {
assert.Equals(t, v, nil) assert.Equals(t, v, nil)
case dnsNamesValidator: case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"}) assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"})
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default: default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
@ -810,7 +813,6 @@ func TestAWS_AuthorizeSSHSign(t *testing.T) {
} else if assert.NotNil(t, got) { } else if assert.NotNil(t, got) {
cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer)) cert, err := signSSHCertificate(tt.args.key, tt.args.sshOpts, got, signer.Key.(crypto.Signer))
if (err != nil) != tt.wantSignErr { if (err != nil) != tt.wantSignErr {
t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr) t.Errorf("SignSSH error = %v, wantSignErr %v", err, tt.wantSignErr)
} else { } else {
if tt.wantSignErr { if tt.wantSignErr {

@ -13,10 +13,12 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
) )
// azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens. // azureOIDCBaseURL is the base discovery url for Microsoft Azure tokens.
@ -219,7 +221,7 @@ func (p *Azure) Init(config Config) (err error) {
return return
} }
p.ctl, err = NewController(p, p.Claims, config) p.ctl, err = NewController(p, p.Claims, config, p.Options)
return return
} }
@ -352,6 +354,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
} }
return append(so, return append(so,
p,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID), newProvisionerExtensionOption(TypeAzure, p.Name, p.TenantID),
@ -359,6 +362,7 @@ func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
), nil ), nil
} }
@ -414,6 +418,7 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
signOptions = append(signOptions, templateOptions) signOptions = append(signOptions, templateOptions)
return append(signOptions, return append(signOptions,
p,
// Validate user SignSSHOptions. // Validate user SignSSHOptions.
sshCertOptionsValidator(defaults), sshCertOptionsValidator(defaults),
// Set the validity bounds if not set. // Set the validity bounds if not set.
@ -424,6 +429,8 @@ func (p *Azure) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
&sshCertValidityValidator{p.ctl.Claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
), nil ), nil
} }

@ -474,11 +474,11 @@ func TestAzure_AuthorizeSign(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{t1}, 5, http.StatusOK, false}, {"ok", p1, args{t1}, 7, http.StatusOK, false},
{"ok", p2, args{t2}, 10, http.StatusOK, false}, {"ok", p2, args{t2}, 12, http.StatusOK, false},
{"ok", p1, args{t11}, 5, http.StatusOK, false}, {"ok", p1, args{t11}, 7, http.StatusOK, false},
{"ok", p5, args{t5}, 5, http.StatusOK, false}, {"ok", p5, args{t5}, 7, http.StatusOK, false},
{"ok", p7, args{t7}, 5, http.StatusOK, false}, {"ok", p7, args{t7}, 7, http.StatusOK, false},
{"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true}, {"fail tenant", p3, args{t3}, 0, http.StatusUnauthorized, true},
{"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true}, {"fail resource group", p4, args{t4}, 0, http.StatusUnauthorized, true},
{"fail subscription", p6, args{t6}, 0, http.StatusUnauthorized, true}, {"fail subscription", p6, args{t6}, 0, http.StatusUnauthorized, true},
@ -502,9 +502,10 @@ func TestAzure_AuthorizeSign(t *testing.T) {
assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
default: default:
assert.Len(t, tt.wantLen, got) assert.Equals(t, tt.wantLen, len(got))
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {
case *Azure:
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, TypeAzure) assert.Equals(t, v.Type, TypeAzure)
@ -527,6 +528,8 @@ func TestAzure_AuthorizeSign(t *testing.T) {
assert.Equals(t, v, nil) assert.Equals(t, v, nil)
case dnsNamesValidator: case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"virtualMachine"}) assert.Equals(t, []string(v), []string{"virtualMachine"})
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default: default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }

@ -24,8 +24,8 @@ type Claims struct {
EnableSSHCA *bool `json:"enableSSHCA,omitempty"` EnableSSHCA *bool `json:"enableSSHCA,omitempty"`
// Renewal properties // Renewal properties
DisableRenewal *bool `json:"disableRenewal,omitempty"` DisableRenewal *bool `json:"disableRenewal,omitempty"`
AllowRenewAfterExpiry *bool `json:"allowRenewAfterExpiry,omitempty"` AllowRenewalAfterExpiry *bool `json:"allowRenewalAfterExpiry,omitempty"`
} }
// Claimer is the type that controls claims. It provides an interface around the // Claimer is the type that controls claims. It provides an interface around the
@ -44,22 +44,22 @@ func NewClaimer(claims *Claims, global Claims) (*Claimer, error) {
// Claims returns the merge of the inner and global claims. // Claims returns the merge of the inner and global claims.
func (c *Claimer) Claims() Claims { func (c *Claimer) Claims() Claims {
disableRenewal := c.IsDisableRenewal() disableRenewal := c.IsDisableRenewal()
allowRenewAfterExpiry := c.AllowRenewAfterExpiry() allowRenewalAfterExpiry := c.AllowRenewalAfterExpiry()
enableSSHCA := c.IsSSHCAEnabled() enableSSHCA := c.IsSSHCAEnabled()
return Claims{ return Claims{
MinTLSDur: &Duration{c.MinTLSCertDuration()}, MinTLSDur: &Duration{c.MinTLSCertDuration()},
MaxTLSDur: &Duration{c.MaxTLSCertDuration()}, MaxTLSDur: &Duration{c.MaxTLSCertDuration()},
DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()}, DefaultTLSDur: &Duration{c.DefaultTLSCertDuration()},
MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()}, MinUserSSHDur: &Duration{c.MinUserSSHCertDuration()},
MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()}, MaxUserSSHDur: &Duration{c.MaxUserSSHCertDuration()},
DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()}, DefaultUserSSHDur: &Duration{c.DefaultUserSSHCertDuration()},
MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()}, MinHostSSHDur: &Duration{c.MinHostSSHCertDuration()},
MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()}, MaxHostSSHDur: &Duration{c.MaxHostSSHCertDuration()},
DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()}, DefaultHostSSHDur: &Duration{c.DefaultHostSSHCertDuration()},
EnableSSHCA: &enableSSHCA, EnableSSHCA: &enableSSHCA,
DisableRenewal: &disableRenewal, DisableRenewal: &disableRenewal,
AllowRenewAfterExpiry: &allowRenewAfterExpiry, AllowRenewalAfterExpiry: &allowRenewalAfterExpiry,
} }
} }
@ -109,14 +109,14 @@ func (c *Claimer) IsDisableRenewal() bool {
return *c.claims.DisableRenewal return *c.claims.DisableRenewal
} }
// AllowRenewAfterExpiry returns if the renewal flow is authorized if the // AllowRenewalAfterExpiry returns if the renewal flow is authorized if the
// certificate is expired. If the property is not set within the provisioner // certificate is expired. If the property is not set within the provisioner
// then the global value from the authority configuration will be used. // then the global value from the authority configuration will be used.
func (c *Claimer) AllowRenewAfterExpiry() bool { func (c *Claimer) AllowRenewalAfterExpiry() bool {
if c.claims == nil || c.claims.AllowRenewAfterExpiry == nil { if c.claims == nil || c.claims.AllowRenewalAfterExpiry == nil {
return *c.global.AllowRenewAfterExpiry return *c.global.AllowRenewalAfterExpiry
} }
return *c.claims.AllowRenewAfterExpiry return *c.claims.AllowRenewalAfterExpiry
} }
// DefaultSSHCertDuration returns the default SSH certificate duration for the // DefaultSSHCertDuration returns the default SSH certificate duration for the

@ -3,6 +3,7 @@ package provisioner
import ( import (
"context" "context"
"crypto/x509" "crypto/x509"
"net/http"
"regexp" "regexp"
"strings" "strings"
"time" "time"
@ -21,14 +22,19 @@ type Controller struct {
IdentityFunc GetIdentityFunc IdentityFunc GetIdentityFunc
AuthorizeRenewFunc AuthorizeRenewFunc AuthorizeRenewFunc AuthorizeRenewFunc
AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc AuthorizeSSHRenewFunc
policy *policyEngine
} }
// NewController initializes a new provisioner controller. // NewController initializes a new provisioner controller.
func NewController(p Interface, claims *Claims, config Config) (*Controller, error) { func NewController(p Interface, claims *Claims, config Config, options *Options) (*Controller, error) {
claimer, err := NewClaimer(claims, config.Claims) claimer, err := NewClaimer(claims, config.Claims)
if err != nil { if err != nil {
return nil, err return nil, err
} }
policy, err := newPolicyEngine(options)
if err != nil {
return nil, err
}
return &Controller{ return &Controller{
Interface: p, Interface: p,
Audiences: &config.Audiences, Audiences: &config.Audiences,
@ -36,6 +42,7 @@ func NewController(p Interface, claims *Claims, config Config) (*Controller, err
IdentityFunc: config.GetIdentityFunc, IdentityFunc: config.GetIdentityFunc,
AuthorizeRenewFunc: config.AuthorizeRenewFunc, AuthorizeRenewFunc: config.AuthorizeRenewFunc,
AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc, AuthorizeSSHRenewFunc: config.AuthorizeSSHRenewFunc,
policy: policy,
}, nil }, nil
} }
@ -124,8 +131,10 @@ func DefaultAuthorizeRenew(ctx context.Context, p *Controller, cert *x509.Certif
if now.Before(cert.NotBefore) { if now.Before(cert.NotBefore) {
return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano)) return errs.Unauthorized("certificate is not yet valid" + " " + now.UTC().Format(time.RFC3339Nano) + " vs " + cert.NotBefore.Format(time.RFC3339Nano))
} }
if now.After(cert.NotAfter) && !p.Claimer.AllowRenewAfterExpiry() { if now.After(cert.NotAfter) && !p.Claimer.AllowRenewalAfterExpiry() {
return errs.Unauthorized("certificate has expired") // return a custom 401 Unauthorized error with a clearer message for the client
// TODO(hs): these errors likely need to be refactored as a whole; HTTP status codes shouldn't be in this layer.
return errs.New(http.StatusUnauthorized, "The request lacked necessary authorization to be completed: certificate expired on %s", cert.NotAfter)
} }
return nil return nil
@ -144,7 +153,7 @@ func DefaultAuthorizeSSHRenew(ctx context.Context, p *Controller, cert *ssh.Cert
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
return errs.Unauthorized("certificate is not yet valid") return errs.Unauthorized("certificate is not yet valid")
} }
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewAfterExpiry() { if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(ssh.CertTimeInfinity) && (unixNow >= before || before < 0) && !p.Claimer.AllowRenewalAfterExpiry() {
return errs.Unauthorized("certificate has expired") return errs.Unauthorized("certificate has expired")
} }
@ -192,3 +201,10 @@ func SanitizeSSHUserPrincipal(email string) string {
} }
}, strings.ToLower(email)) }, strings.ToLower(email))
} }
func (c *Controller) getPolicy() *policyEngine {
if c == nil {
return nil
}
return c.policy
}

@ -9,6 +9,8 @@ import (
"time" "time"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/authority/policy"
) )
var trueValue = true var trueValue = true
@ -30,11 +32,40 @@ func mustDuration(t *testing.T, s string) *Duration {
return d return d
} }
func mustNewPolicyEngine(t *testing.T, options *Options) *policyEngine {
t.Helper()
c, err := newPolicyEngine(options)
if err != nil {
t.Fatal(err)
}
return c
}
func TestNewController(t *testing.T) { func TestNewController(t *testing.T) {
options := &Options{
X509: &X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"*.local"},
},
},
SSH: &SSHOptions{
Host: &policy.SSHHostCertificateOptions{
AllowedNames: &policy.SSHNameOptions{
DNSDomains: []string{"*.local"},
},
},
User: &policy.SSHUserCertificateOptions{
AllowedNames: &policy.SSHNameOptions{
EmailAddresses: []string{"@example.com"},
},
},
},
}
type args struct { type args struct {
p Interface p Interface
claims *Claims claims *Claims
config Config config Config
options *Options
} }
tests := []struct { tests := []struct {
name string name string
@ -45,7 +76,7 @@ func TestNewController(t *testing.T) {
{"ok", args{&JWK{}, nil, Config{ {"ok", args{&JWK{}, nil, Config{
Claims: globalProvisionerClaims, Claims: globalProvisionerClaims,
Audiences: testAudiences, Audiences: testAudiences,
}}, &Controller{ }, nil}, &Controller{
Interface: &JWK{}, Interface: &JWK{},
Audiences: &testAudiences, Audiences: &testAudiences,
Claimer: mustClaimer(t, nil, globalProvisionerClaims), Claimer: mustClaimer(t, nil, globalProvisionerClaims),
@ -55,24 +86,49 @@ func TestNewController(t *testing.T) {
}, Config{ }, Config{
Claims: globalProvisionerClaims, Claims: globalProvisionerClaims,
Audiences: testAudiences, Audiences: testAudiences,
}}, &Controller{ }, nil}, &Controller{
Interface: &JWK{}, Interface: &JWK{},
Audiences: &testAudiences, Audiences: &testAudiences,
Claimer: mustClaimer(t, &Claims{ Claimer: mustClaimer(t, &Claims{
DisableRenewal: &defaultDisableRenewal, DisableRenewal: &defaultDisableRenewal,
}, globalProvisionerClaims), }, globalProvisionerClaims),
}, false}, }, false},
{"ok with claims and options", args{&JWK{}, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}, options}, &Controller{
Interface: &JWK{},
Audiences: &testAudiences,
Claimer: mustClaimer(t, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, globalProvisionerClaims),
policy: mustNewPolicyEngine(t, options),
}, false},
{"fail claimer", args{&JWK{}, &Claims{ {"fail claimer", args{&JWK{}, &Claims{
MinTLSDur: mustDuration(t, "24h"), MinTLSDur: mustDuration(t, "24h"),
MaxTLSDur: mustDuration(t, "2h"), MaxTLSDur: mustDuration(t, "2h"),
}, Config{ }, Config{
Claims: globalProvisionerClaims, Claims: globalProvisionerClaims,
Audiences: testAudiences, Audiences: testAudiences,
}, nil}, nil, true},
{"fail options", args{&JWK{}, &Claims{
DisableRenewal: &defaultDisableRenewal,
}, Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}, &Options{
X509: &X509Options{
AllowedNames: &policy.X509NameOptions{
DNSDomains: []string{"**.local"},
},
},
}}, nil, true}, }}, nil, true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := NewController(tt.args.p, tt.args.claims, tt.args.config) got, err := NewController(tt.args.p, tt.args.claims, tt.args.config, tt.args.options)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("NewController() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -160,13 +216,13 @@ func TestController_AuthorizeRenew(t *testing.T) {
NotBefore: now, NotBefore: now,
NotAfter: now.Add(time.Hour), NotAfter: now.Add(time.Hour),
}}, false}, }}, false},
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error { {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *x509.Certificate) error {
return nil return nil
}}, args{ctx, &x509.Certificate{ }}, args{ctx, &x509.Certificate{
NotBefore: now, NotBefore: now,
NotAfter: now.Add(time.Hour), NotAfter: now.Add(time.Hour),
}}, false}, }}, false},
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{ {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &x509.Certificate{
NotBefore: now.Add(-time.Hour), NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute), NotAfter: now.Add(-time.Minute),
}}, false}, }}, false},
@ -231,13 +287,13 @@ func TestController_AuthorizeSSHRenew(t *testing.T) {
ValidAfter: uint64(now.Unix()), ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false}, }}, false},
{"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error { {"ok custom disabled", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), func(ctx context.Context, p *Controller, cert *ssh.Certificate) error {
return nil return nil
}}, args{ctx, &ssh.Certificate{ }}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Unix()), ValidAfter: uint64(now.Unix()),
ValidBefore: uint64(now.Add(time.Hour).Unix()), ValidBefore: uint64(now.Add(time.Hour).Unix()),
}}, false}, }}, false},
{"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{ {"ok renew after expiry", fields{&JWK{}, mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims), nil}, args{ctx, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()), ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()), ValidBefore: uint64(now.Add(-time.Minute).Unix()),
}}, false}, }}, false},
@ -296,7 +352,7 @@ func TestDefaultAuthorizeRenew(t *testing.T) {
}}, false}, }}, false},
{"ok renew after expiry", args{ctx, &Controller{ {"ok renew after expiry", args{ctx, &Controller{
Interface: &JWK{}, Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), Claimer: mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims),
}, &x509.Certificate{ }, &x509.Certificate{
NotBefore: now.Add(-time.Hour), NotBefore: now.Add(-time.Hour),
NotAfter: now.Add(-time.Minute), NotAfter: now.Add(-time.Minute),
@ -354,7 +410,7 @@ func TestDefaultAuthorizeSSHRenew(t *testing.T) {
}}, false}, }}, false},
{"ok renew after expiry", args{ctx, &Controller{ {"ok renew after expiry", args{ctx, &Controller{
Interface: &JWK{}, Interface: &JWK{},
Claimer: mustClaimer(t, &Claims{AllowRenewAfterExpiry: &trueValue}, globalProvisionerClaims), Claimer: mustClaimer(t, &Claims{AllowRenewalAfterExpiry: &trueValue}, globalProvisionerClaims),
}, &ssh.Certificate{ }, &ssh.Certificate{
ValidAfter: uint64(now.Add(-time.Hour).Unix()), ValidAfter: uint64(now.Add(-time.Hour).Unix()),
ValidBefore: uint64(now.Add(-time.Minute).Unix()), ValidBefore: uint64(now.Add(-time.Minute).Unix()),

@ -14,10 +14,12 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
) )
// gcpCertsURL is the url that serves Google OAuth2 public keys. // gcpCertsURL is the url that serves Google OAuth2 public keys.
@ -212,7 +214,7 @@ func (p *GCP) Init(config Config) (err error) {
} }
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config) p.ctl, err = NewController(p, p.Claims, config, p.Options)
return return
} }
@ -262,6 +264,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
} }
return append(so, return append(so,
p,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName), newProvisionerExtensionOption(TypeGCP, p.Name, claims.Subject, "InstanceID", ce.InstanceID, "InstanceName", ce.InstanceName),
@ -269,6 +272,7 @@ func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
), nil ), nil
} }
@ -421,6 +425,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
signOptions = append(signOptions, templateOptions) signOptions = append(signOptions, templateOptions)
return append(signOptions, return append(signOptions,
p,
// Validate user SignSSHOptions. // Validate user SignSSHOptions.
sshCertOptionsValidator(defaults), sshCertOptionsValidator(defaults),
// Set the validity bounds if not set. // Set the validity bounds if not set.
@ -431,5 +436,7 @@ func (p *GCP) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
&sshCertValidityValidator{p.ctl.Claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
), nil ), nil
} }

@ -516,9 +516,9 @@ func TestGCP_AuthorizeSign(t *testing.T) {
code int code int
wantErr bool wantErr bool
}{ }{
{"ok", p1, args{t1}, 5, http.StatusOK, false}, {"ok", p1, args{t1}, 7, http.StatusOK, false},
{"ok", p2, args{t2}, 10, http.StatusOK, false}, {"ok", p2, args{t2}, 12, http.StatusOK, false},
{"ok", p3, args{t3}, 5, http.StatusOK, false}, {"ok", p3, args{t3}, 7, http.StatusOK, false},
{"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true}, {"fail token", p1, args{"token"}, 0, http.StatusUnauthorized, true},
{"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true}, {"fail key", p1, args{failKey}, 0, http.StatusUnauthorized, true},
{"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true}, {"fail iss", p1, args{failIss}, 0, http.StatusUnauthorized, true},
@ -545,9 +545,10 @@ func TestGCP_AuthorizeSign(t *testing.T) {
assert.Fatal(t, ok, "error does not implement StatusCodedError interface") assert.Fatal(t, ok, "error does not implement StatusCodedError interface")
assert.Equals(t, sc.StatusCode(), tt.code) assert.Equals(t, sc.StatusCode(), tt.code)
default: default:
assert.Len(t, tt.wantLen, got) assert.Equals(t, tt.wantLen, len(got))
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {
case *GCP:
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, TypeGCP) assert.Equals(t, v.Type, TypeGCP)
@ -570,6 +571,8 @@ func TestGCP_AuthorizeSign(t *testing.T) {
assert.Equals(t, v, nil) assert.Equals(t, v, nil)
case dnsNamesValidator: case dnsNamesValidator:
assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"})
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default: default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }

@ -7,10 +7,12 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
) )
// jwtPayload extends jwt.Claims with step attributes. // jwtPayload extends jwt.Claims with step attributes.
@ -97,7 +99,7 @@ func (p *JWK) Init(config Config) (err error) {
return errors.New("provisioner key cannot be empty") return errors.New("provisioner key cannot be empty")
} }
p.ctl, err = NewController(p, p.Claims, config) p.ctl, err = NewController(p, p.Claims, config, p.Options)
return return
} }
@ -141,6 +143,7 @@ func (p *JWK) authorizeToken(token string, audiences []string) (*jwtPayload, err
// revoke the certificate with serial number in the `sub` property. // revoke the certificate with serial number in the `sub` property.
func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error { func (p *JWK) AuthorizeRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.ctl.Audiences.Revoke) _, err := p.authorizeToken(token, p.ctl.Audiences.Revoke)
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRevoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke") return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeRevoke")
} }
@ -170,6 +173,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
} }
return []SignOption{ return []SignOption{
p,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID), newProvisionerExtensionOption(TypeJWK, p.Name, p.Key.KeyID),
@ -179,6 +183,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
defaultSANsValidator(claims.SANs), defaultSANsValidator(claims.SANs),
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
}, nil }, nil
} }
@ -187,6 +192,7 @@ func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, er
// revocation status. Just confirms that the provisioner that created the // revocation status. Just confirms that the provisioner that created the
// certificate was configured to allow renewals. // certificate was configured to allow renewals.
func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *JWK) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
// TODO(hs): authorize the SANs using x509 name policy allow/deny rules (also for other provisioners with AuthorizeRewew and AuthorizeSSHRenew)
return p.ctl.AuthorizeRenew(ctx, cert) return p.ctl.AuthorizeRenew(ctx, cert)
} }
@ -251,6 +257,7 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
} }
return append(signOptions, return append(signOptions,
p,
// Set the validity bounds if not set. // Set the validity bounds if not set.
&sshDefaultDuration{p.ctl.Claimer}, &sshDefaultDuration{p.ctl.Claimer},
// Validate public key // Validate public key
@ -259,11 +266,14 @@ func (p *JWK) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption,
&sshCertValidityValidator{p.ctl.Claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()),
), nil ), nil
} }
// AuthorizeSSHRevoke returns nil if the token is valid, false otherwise. // AuthorizeSSHRevoke returns nil if the token is valid, false otherwise.
func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error { func (p *JWK) AuthorizeSSHRevoke(ctx context.Context, token string) error {
_, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke) _, err := p.authorizeToken(token, p.ctl.Audiences.SSHRevoke)
// TODO(hs): authorize the principals using SSH name policy allow/deny rules (also for other provisioners with AuthorizeSSHRevoke)
return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke") return errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSSHRevoke")
} }

@ -297,9 +297,10 @@ func TestJWK_AuthorizeSign(t *testing.T) {
} }
} else { } else {
if assert.NotNil(t, got) { if assert.NotNil(t, got) {
assert.Len(t, 7, got) assert.Equals(t, 9, len(got))
for _, o := range got { for _, o := range got {
switch v := o.(type) { switch v := o.(type) {
case *JWK:
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, TypeJWK) assert.Equals(t, v.Type, TypeJWK)
@ -316,6 +317,8 @@ func TestJWK_AuthorizeSign(t *testing.T) {
assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration())
case defaultSANsValidator: case defaultSANsValidator:
assert.Equals(t, []string(v), tt.sans) assert.Equals(t, []string(v), tt.sans)
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default: default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }

@ -10,11 +10,13 @@ import (
"net/http" "net/http"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/pemutil" "go.step.sm/crypto/pemutil"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"github.com/smallstep/certificates/errs"
) )
// NOTE: There can be at most one kubernetes service account provisioner configured // NOTE: There can be at most one kubernetes service account provisioner configured
@ -91,6 +93,7 @@ func (p *K8sSA) GetEncryptedKey() (string, string, bool) {
// Init initializes and validates the fields of a K8sSA type. // Init initializes and validates the fields of a K8sSA type.
func (p *K8sSA) Init(config Config) (err error) { func (p *K8sSA) Init(config Config) (err error) {
switch { switch {
case p.Type == "": case p.Type == "":
return errors.New("provisioner type cannot be empty") return errors.New("provisioner type cannot be empty")
@ -137,7 +140,7 @@ func (p *K8sSA) Init(config Config) (err error) {
p.kauthn = k8s.AuthenticationV1() p.kauthn = k8s.AuthenticationV1()
*/ */
p.ctl, err = NewController(p, p.Claims, config) p.ctl, err = NewController(p, p.Claims, config, p.Options)
return return
} }
@ -231,6 +234,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
} }
return []SignOption{ return []SignOption{
p,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeK8sSA, p.Name, ""), newProvisionerExtensionOption(TypeK8sSA, p.Name, ""),
@ -238,6 +242,7 @@ func (p *K8sSA) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
// validators // validators
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
}, nil }, nil
} }
@ -270,6 +275,7 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
signOptions := []SignOption{templateOptions} signOptions := []SignOption{templateOptions}
return append(signOptions, return append(signOptions,
p,
// Require type, key-id and principals in the SignSSHOptions. // Require type, key-id and principals in the SignSSHOptions.
&sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true},
// Set the validity bounds if not set. // Set the validity bounds if not set.
@ -280,6 +286,8 @@ func (p *K8sSA) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOptio
&sshCertValidityValidator{p.ctl.Claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require and validate all the default fields in the SSH certificate. // Require and validate all the default fields in the SSH certificate.
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), p.ctl.getPolicy().getSSHUser()),
), nil ), nil
} }

@ -280,9 +280,9 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) { if assert.NotNil(t, opts) {
tot := 0
for _, o := range opts { for _, o := range opts {
switch v := o.(type) { switch v := o.(type) {
case *K8sSA:
case certificateOptionsFunc: case certificateOptionsFunc:
case *provisionerExtensionOption: case *provisionerExtensionOption:
assert.Equals(t, v.Type, TypeK8sSA) assert.Equals(t, v.Type, TypeK8sSA)
@ -295,12 +295,13 @@ func TestK8sSA_AuthorizeSign(t *testing.T) {
case *validityValidator: case *validityValidator:
assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration())
assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration())
case *x509NamePolicyValidator:
assert.Equals(t, nil, v.policyEngine)
default: default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
tot++
} }
assert.Equals(t, tot, 5) assert.Equals(t, 7, len(opts))
} }
} }
} }
@ -367,9 +368,10 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
} else { } else {
if assert.Nil(t, tc.err) { if assert.Nil(t, tc.err) {
if assert.NotNil(t, opts) { if assert.NotNil(t, opts) {
tot := 0 assert.Len(t, 8, opts)
for _, o := range opts { for _, o := range opts {
switch v := o.(type) { switch v := o.(type) {
case Interface:
case sshCertificateOptionsFunc: case sshCertificateOptionsFunc:
case *sshCertOptionsRequireValidator: case *sshCertOptionsRequireValidator:
assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true}) assert.Equals(t, v, &sshCertOptionsRequireValidator{CertType: true, KeyID: true, Principals: true})
@ -379,12 +381,13 @@ func TestK8sSA_AuthorizeSSHSign(t *testing.T) {
case *sshCertDefaultValidator: case *sshCertDefaultValidator:
case *sshDefaultDuration: case *sshDefaultDuration:
assert.Equals(t, v.Claimer, tc.p.ctl.Claimer) assert.Equals(t, v.Claimer, tc.p.ctl.Claimer)
case *sshNamePolicyValidator:
assert.Equals(t, nil, v.userPolicyEngine)
assert.Equals(t, nil, v.hostPolicyEngine)
default: default:
assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v)) assert.FatalError(t, fmt.Errorf("unexpected sign option of type %T", v))
} }
tot++
} }
assert.Equals(t, tot, 6)
} }
} }
} }

@ -61,3 +61,16 @@ func MethodFromContext(ctx context.Context) Method {
m, _ := ctx.Value(methodKey{}).(Method) m, _ := ctx.Value(methodKey{}).(Method)
return m return m
} }
type tokenKey struct{}
// NewContextWithToken creates a new context with the given token.
func NewContextWithToken(ctx context.Context, token string) context.Context {
return context.WithValue(ctx, tokenKey{}, token)
}
// TokenFromContext returns the token stored in the given context.
func TokenFromContext(ctx context.Context) (string, bool) {
token, ok := ctx.Value(tokenKey{}).(string)
return token, ok
}

@ -10,12 +10,14 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
nebula "github.com/slackhq/nebula/cert" nebula "github.com/slackhq/nebula/cert"
"github.com/smallstep/certificates/errs"
"go.step.sm/crypto/jose" "go.step.sm/crypto/jose"
"go.step.sm/crypto/sshutil" "go.step.sm/crypto/sshutil"
"go.step.sm/crypto/x25519" "go.step.sm/crypto/x25519"
"go.step.sm/crypto/x509util" "go.step.sm/crypto/x509util"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/smallstep/certificates/errs"
) )
const ( const (
@ -61,7 +63,7 @@ func (p *Nebula) Init(config Config) (err error) {
} }
config.Audiences = config.Audiences.WithFragment(p.GetIDForToken()) config.Audiences = config.Audiences.WithFragment(p.GetIDForToken())
p.ctl, err = NewController(p, p.Claims, config) p.ctl, err = NewController(p, p.Claims, config, p.Options)
return return
} }
@ -144,6 +146,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
} }
return []SignOption{ return []SignOption{
p,
templateOptions, templateOptions,
// modifiers / withOptions // modifiers / withOptions
newProvisionerExtensionOption(TypeNebula, p.Name, ""), newProvisionerExtensionOption(TypeNebula, p.Name, ""),
@ -160,6 +163,7 @@ func (p *Nebula) AuthorizeSign(ctx context.Context, token string) ([]SignOption,
}, },
defaultPublicKeyValidator{}, defaultPublicKeyValidator{},
newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()),
newX509NamePolicyValidator(p.ctl.getPolicy().getX509()),
}, nil }, nil
} }
@ -246,6 +250,7 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
} }
return append(signOptions, return append(signOptions,
p,
templateOptions, templateOptions,
// Checks the validity bounds, and set the validity if has not been set. // Checks the validity bounds, and set the validity if has not been set.
&sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter}, &sshLimitDuration{p.ctl.Claimer, crt.Details.NotAfter},
@ -255,6 +260,8 @@ func (p *Nebula) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOpti
&sshCertValidityValidator{p.ctl.Claimer}, &sshCertValidityValidator{p.ctl.Claimer},
// Require all the fields in the SSH certificate // Require all the fields in the SSH certificate
&sshCertDefaultValidator{}, &sshCertDefaultValidator{},
// Ensure that all principal names are allowed
newSSHNamePolicyValidator(p.ctl.getPolicy().getSSHHost(), nil),
), nil ), nil
} }

@ -38,7 +38,7 @@ func (p *noop) Init(config Config) error {
} }
func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { func (p *noop) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{}, nil return []SignOption{p}, nil
} }
func (p *noop) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error { func (p *noop) AuthorizeRenew(ctx context.Context, cert *x509.Certificate) error {
@ -50,7 +50,7 @@ func (p *noop) AuthorizeRevoke(ctx context.Context, token string) error {
} }
func (p *noop) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) { func (p *noop) AuthorizeSSHSign(ctx context.Context, token string) ([]SignOption, error) {
return []SignOption{}, nil return []SignOption{p}, nil
} }
func (p *noop) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) { func (p *noop) AuthorizeSSHRenew(ctx context.Context, token string) (*ssh.Certificate, error) {

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save