diff --git a/acme/challenge_test.go b/acme/challenge_test.go index c05b25e7..e1b6816a 100644 --- a/acme/challenge_test.go +++ b/acme/challenge_test.go @@ -29,6 +29,18 @@ import ( "github.com/smallstep/assert" ) +type mockClient struct { + get func(url string) (*http.Response, error) + lookupTxt func(name string) ([]string, error) + tlsDial func(network, addr string, config *tls.Config) (*tls.Conn, error) +} + +func (m *mockClient) Get(url string) (*http.Response, error) { return m.get(url) } +func (m *mockClient) LookupTxt(name string) ([]string, error) { return m.lookupTxt(name) } +func (m *mockClient) TLSDial(network, addr string, config *tls.Config) (*tls.Conn, error) { + return m.tlsDial(network, addr, config) +} + func Test_storeError(t *testing.T) { type test struct { ch *Challenge @@ -229,7 +241,7 @@ func TestKeyAuthorization(t *testing.T) { func TestChallenge_Validate(t *testing.T) { type test struct { ch *Challenge - vo *ValidateChallengeOptions + vc Client jwk *jose.JSONWebKey db DB srv *httptest.Server @@ -273,8 +285,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -309,8 +321,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -344,8 +356,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -381,8 +393,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -416,8 +428,8 @@ func TestChallenge_Validate(t *testing.T) { } return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, @@ -466,8 +478,8 @@ func TestChallenge_Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -493,7 +505,8 @@ func TestChallenge_Validate(t *testing.T) { defer tc.srv.Close() } - if err := tc.ch.Validate(context.Background(), tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := tc.ch.Validate(ctx, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: @@ -524,7 +537,7 @@ func (errReader) Close() error { func TestHTTP01Validate(t *testing.T) { type test struct { - vo *ValidateChallengeOptions + vc Client ch *Challenge jwk *jose.JSONWebKey db DB @@ -541,8 +554,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -575,8 +588,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return nil, errors.New("force") }, }, @@ -608,8 +621,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadRequest, Body: errReader(0), @@ -645,8 +658,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ StatusCode: http.StatusBadRequest, Body: errReader(0), @@ -681,8 +694,8 @@ func TestHTTP01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: errReader(0), }, nil @@ -704,8 +717,8 @@ func TestHTTP01Validate(t *testing.T) { jwk.Key = "foo" return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil @@ -730,8 +743,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil @@ -772,8 +785,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString("foo")), }, nil @@ -815,8 +828,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), }, nil @@ -857,8 +870,8 @@ func TestHTTP01Validate(t *testing.T) { assert.FatalError(t, err) return test{ ch: ch, - vo: &ValidateChallengeOptions{ - HTTPGet: func(url string) (*http.Response, error) { + vc: &mockClient{ + get: func(url string) (*http.Response, error) { return &http.Response{ Body: io.NopCloser(bytes.NewBufferString(expKeyAuth)), }, nil @@ -887,7 +900,8 @@ func TestHTTP01Validate(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if err := http01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := http01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: @@ -911,7 +925,7 @@ func TestDNS01Validate(t *testing.T) { fulldomain := "*.zap.internal" domain := strings.TrimPrefix(fulldomain, "*.") type test struct { - vo *ValidateChallengeOptions + vc Client ch *Challenge jwk *jose.JSONWebKey db DB @@ -928,8 +942,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -963,8 +977,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return nil, errors.New("force") }, }, @@ -1001,8 +1015,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo"}, nil }, }, @@ -1026,8 +1040,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", "bar"}, nil }, }, @@ -1068,8 +1082,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", "bar"}, nil }, }, @@ -1111,8 +1125,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", expected}, nil }, }, @@ -1156,8 +1170,8 @@ func TestDNS01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - LookupTxt: func(url string) ([]string, error) { + vc: &mockClient{ + lookupTxt: func(url string) ([]string, error) { return []string{"foo", expected}, nil }, }, @@ -1186,7 +1200,8 @@ func TestDNS01Validate(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tc := run(t) - if err := dns01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := dns01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: @@ -1206,6 +1221,8 @@ func TestDNS01Validate(t *testing.T) { } } +type tlsDialer func(network, addr string, config *tls.Config) (conn *tls.Conn, err error) + func newTestTLSALPNServer(validationCert *tls.Certificate) (*httptest.Server, tlsDialer) { srv := httptest.NewUnstartedServer(http.NewServeMux()) @@ -1309,7 +1326,7 @@ func TestTLSALPN01Validate(t *testing.T) { } } type test struct { - vo *ValidateChallengeOptions + vc Client ch *Challenge jwk *jose.JSONWebKey db DB @@ -1321,8 +1338,8 @@ func TestTLSALPN01Validate(t *testing.T) { ch := makeTLSCh() return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, @@ -1351,8 +1368,8 @@ func TestTLSALPN01Validate(t *testing.T) { ch := makeTLSCh() return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return nil, errors.New("force") }, }, @@ -1384,8 +1401,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1413,8 +1430,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.Client(&noopConn{}, config), nil }, }, @@ -1443,8 +1460,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.Client(&noopConn{}, config), nil }, }, @@ -1479,8 +1496,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) }, }, @@ -1516,8 +1533,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { + vc: &mockClient{ + tlsDial: func(network, addr string, config *tls.Config) (*tls.Conn, error) { return tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", srv.Listener.Addr().String(), config) }, }, @@ -1562,8 +1579,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1605,8 +1622,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1649,8 +1666,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1692,8 +1709,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1736,8 +1753,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, srv: srv, jwk: jwk, @@ -1758,8 +1775,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1797,8 +1814,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1841,8 +1858,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1884,8 +1901,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1924,8 +1941,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -1963,8 +1980,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2008,8 +2025,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2054,8 +2071,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2100,8 +2117,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2144,8 +2161,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2189,8 +2206,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2226,8 +2243,8 @@ func TestTLSALPN01Validate(t *testing.T) { return test{ ch: ch, - vo: &ValidateChallengeOptions{ - TLSDial: tlsDial, + vc: &mockClient{ + tlsDial: tlsDial, }, db: &MockDB{ MockUpdateChallenge: func(ctx context.Context, updch *Challenge) error { @@ -2253,7 +2270,8 @@ func TestTLSALPN01Validate(t *testing.T) { defer tc.srv.Close() } - if err := tlsalpn01Validate(context.Background(), tc.ch, tc.db, tc.jwk, tc.vo); err != nil { + ctx := NewClientContext(context.Background(), tc.vc) + if err := tlsalpn01Validate(ctx, tc.ch, tc.db, tc.jwk); err != nil { if assert.NotNil(t, tc.err) { switch k := err.(type) { case *Error: diff --git a/acme/linker.go b/acme/linker.go index 6e9110c2..bddc21f1 100644 --- a/acme/linker.go +++ b/acme/linker.go @@ -206,6 +206,11 @@ func (l *linker) Middleware(next http.Handler) http.Handler { // GetLink is a helper for GetLinkExplicit. func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) string { + var name string + if p, ok := ProvisionerFromContext(ctx); ok { + name = p.GetName() + } + var u url.URL if baseURL := baseURLFromContext(ctx); baseURL != nil { u = *baseURL @@ -217,8 +222,7 @@ func (l *linker) GetLink(ctx context.Context, typ LinkType, inputs ...string) st u.Host = l.dns } - p := MustProvisionerFromContext(ctx) - u.Path = l.prefix + GetUnescapedPathSuffix(typ, p.GetName(), inputs...) + u.Path = l.prefix + GetUnescapedPathSuffix(typ, name, inputs...) return u.String() } diff --git a/acme/linker_test.go b/acme/linker_test.go index 1946dd88..b85d1a53 100644 --- a/acme/linker_test.go +++ b/acme/linker_test.go @@ -5,16 +5,34 @@ import ( "fmt" "net/url" "testing" + "time" "github.com/smallstep/assert" + "github.com/smallstep/certificates/authority/provisioner" ) -func TestLinker_GetUnescapedPathSuffix(t *testing.T) { - dns := "ca.smallstep.com" - prefix := "acme" - linker := NewLinker(dns, prefix) +func mockProvisioner(t *testing.T) Provisioner { + t.Helper() + var defaultDisableRenewal = false + + // Initialize provisioners + p := &provisioner.ACME{ + Type: "ACME", + Name: "test@acme-provisioner.com", + } + if err := p.Init(provisioner.Config{Claims: provisioner.Claims{ + MinTLSDur: &provisioner.Duration{Duration: 5 * time.Minute}, + MaxTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DefaultTLSDur: &provisioner.Duration{Duration: 24 * time.Hour}, + DisableRenewal: &defaultDisableRenewal, + }}); err != nil { + fmt.Printf("%v", err) + } + return p +} - getPath := linker.GetUnescapedPathSuffix +func TestGetUnescapedPathSuffix(t *testing.T) { + getPath := GetUnescapedPathSuffix assert.Equals(t, getPath(NewNonceLinkType, "{provisionerID}"), "/{provisionerID}/new-nonce") assert.Equals(t, getPath(DirectoryLinkType, "{provisionerID}"), "/{provisionerID}/directory") @@ -31,9 +49,9 @@ func TestLinker_GetUnescapedPathSuffix(t *testing.T) { } func TestLinker_DNS(t *testing.T) { - prov := newProv() + prov := mockProvisioner(t) escProvName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) type test struct { name string dns string @@ -116,19 +134,19 @@ func TestLinker_GetLink(t *testing.T) { linker := NewLinker(dns, prefix) id := "1234" - prov := newProv() + prov := mockProvisioner(t) escProvName := url.PathEscape(prov.GetName()) baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - ctx := context.WithValue(context.Background(), provisionerContextKey, prov) - ctx = context.WithValue(ctx, baseURLContextKey, baseURL) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) // No provisioner and no BaseURL from request assert.Equals(t, linker.GetLink(context.Background(), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", "")) // Provisioner: yes, BaseURL: no - assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerContextKey, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) + assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), provisionerKey{}, prov), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://ca.smallstep.com", escProvName)) // Provisioner: no, BaseURL: yes - assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLContextKey, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", "")) + assert.Equals(t, linker.GetLink(context.WithValue(context.Background(), baseURLKey{}, baseURL), NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", "https://test.ca.smallstep.com", "")) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) assert.Equals(t, linker.GetLink(ctx, NewNonceLinkType), fmt.Sprintf("%s/acme/%s/new-nonce", baseURL, escProvName)) @@ -162,10 +180,10 @@ func TestLinker_GetLink(t *testing.T) { func TestLinker_LinkOrder(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) oid := "orderID" certID := "certID" @@ -227,10 +245,10 @@ func TestLinker_LinkOrder(t *testing.T) { func TestLinker_LinkAccount(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) accID := "accountID" linkerPrefix := "acme" @@ -259,10 +277,10 @@ func TestLinker_LinkAccount(t *testing.T) { func TestLinker_LinkChallenge(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) chID := "chID" azID := "azID" @@ -292,10 +310,10 @@ func TestLinker_LinkChallenge(t *testing.T) { func TestLinker_LinkAuthorization(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) chID0 := "chID-0" chID1 := "chID-1" @@ -334,10 +352,10 @@ func TestLinker_LinkAuthorization(t *testing.T) { func TestLinker_LinkOrdersByAccountID(t *testing.T) { baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"} - prov := newProv() + prov := mockProvisioner(t) provName := url.PathEscape(prov.GetName()) - ctx := context.WithValue(context.Background(), baseURLContextKey, baseURL) - ctx = context.WithValue(ctx, provisionerContextKey, prov) + ctx := NewProvisionerContext(context.Background(), prov) + ctx = context.WithValue(ctx, baseURLKey{}, baseURL) linkerPrefix := "acme" l := NewLinker("dns", linkerPrefix)