Merge pull request #1211 from smallstep/herman/ca-client-context-methods

Add `WithContext` methods to the CA client
pull/1217/head
Herman Slatman 1 year ago committed by GitHub
commit b13b527d18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -61,7 +61,7 @@ func Bootstrap(token string) (*Client, error) {
// }
// resp, err := client.Get("https://internal.smallstep.com")
func BootstrapClient(ctx context.Context, token string, options ...TLSOption) (*http.Client, error) {
b, err := createBootstrap(token)
b, err := createBootstrap(token) //nolint:contextcheck // deeply nested context; temporary
if err != nil {
return nil, err
}
@ -120,7 +120,7 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio
return nil, errors.New("server TLSConfig is already set")
}
b, err := createBootstrap(token)
b, err := createBootstrap(token) //nolint:contextcheck // deeply nested context; temporary
if err != nil {
return nil, err
}
@ -169,7 +169,7 @@ func BootstrapServer(ctx context.Context, token string, base *http.Server, optio
// ... // register services
// srv.Serve(lis)
func BootstrapListener(ctx context.Context, token string, inner net.Listener, options ...TLSOption) (net.Listener, error) {
b, err := createBootstrap(token)
b, err := createBootstrap(token) //nolint:contextcheck // deeply nested context; temporary
if err != nil {
return nil, err
}

@ -2,6 +2,7 @@ package ca
import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
@ -75,7 +76,11 @@ func (c *uaClient) SetTransport(tr http.RoundTripper) {
}
func (c *uaClient) Get(u string) (*http.Response, error) {
req, err := http.NewRequest("GET", u, http.NoBody)
return c.GetWithContext(context.Background(), u)
}
func (c *uaClient) GetWithContext(ctx context.Context, u string) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, "GET", u, http.NoBody)
if err != nil {
return nil, errors.Wrapf(err, "create GET %s request failed", u)
}
@ -84,7 +89,11 @@ func (c *uaClient) Get(u string) (*http.Response, error) {
}
func (c *uaClient) Post(u, contentType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequest("POST", u, body)
return c.PostWithContext(context.Background(), u, contentType, body)
}
func (c *uaClient) PostWithContext(ctx context.Context, u, contentType string, body io.Reader) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, "POST", u, body)
if err != nil {
return nil, errors.Wrapf(err, "create POST %s request failed", u)
}
@ -581,18 +590,24 @@ func (c *Client) SetTransport(tr http.RoundTripper) {
c.client.SetTransport(tr)
}
// Version performs the version request to the CA and returns the
// Version performs the version request to the CA with an empty context and returns the
// api.VersionResponse struct.
func (c *Client) Version() (*api.VersionResponse, error) {
return c.VersionWithContext(context.Background())
}
// VersionWithContext performs the version request to the CA with the provided context
// and returns the api.VersionResponse struct.
func (c *Client) VersionWithContext(ctx context.Context) (*api.VersionResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/version"})
retry:
resp, err := c.client.Get(u.String())
resp, err := c.client.GetWithContext(ctx, u.String())
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -605,18 +620,24 @@ retry:
return &version, nil
}
// Health performs the health request to the CA and returns the
// api.HealthResponse struct.
// Health performs the health request to the CA with an empty context
// and returns the api.HealthResponse struct.
func (c *Client) Health() (*api.HealthResponse, error) {
return c.HealthWithContext(context.Background())
}
// HealthWithContext performs the health request to the CA with the provided context
// and returns the api.HealthResponse struct.
func (c *Client) HealthWithContext(ctx context.Context) (*api.HealthResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/health"})
retry:
resp, err := c.client.Get(u.String())
resp, err := c.client.GetWithContext(ctx, u.String())
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -629,21 +650,29 @@ retry:
return &health, nil
}
// Root performs the root request to the CA with the given SHA256 and returns
// the api.RootResponse struct. It uses an insecure client, but it checks the
// resulting root certificate with the given SHA256, returning an error if they
// do not match.
// Root performs the root request to the CA with an empty context and the provided
// SHA256 and returns the api.RootResponse struct. It uses an insecure client, but
// it checks the resulting root certificate with the given SHA256, returning an error
// if they do not match.
func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) {
return c.RootWithContext(context.Background(), sha256Sum)
}
// RootWithContext performs the root request to the CA with an empty context and the provided
// SHA256 and returns the api.RootResponse struct. It uses an insecure client, but
// it checks the resulting root certificate with the given SHA256, returning an error
// if they do not match.
func (c *Client) RootWithContext(ctx context.Context, sha256Sum string) (*api.RootResponse, error) {
var retried bool
sha256Sum = strings.ToLower(strings.ReplaceAll(sha256Sum, "-", ""))
u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum})
retry:
resp, err := newInsecureClient().Get(u.String())
resp, err := newInsecureClient().GetWithContext(ctx, u.String())
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -661,9 +690,15 @@ retry:
return &root, nil
}
// Sign performs the sign request to the CA and returns the api.SignResponse
// struct.
// Sign performs the sign request to the CA with an empty context and returns
// the api.SignResponse struct.
func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
return c.SignWithContext(context.Background(), req)
}
// SignWithContext performs the sign request to the CA with the provided context
// and returns the api.SignResponse struct.
func (c *Client) SignWithContext(ctx context.Context, req *api.SignRequest) (*api.SignResponse, error) {
var retried bool
body, err := json.Marshal(req)
if err != nil {
@ -671,12 +706,12 @@ func (c *Client) Sign(req *api.SignRequest) (*api.SignResponse, error) {
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/sign"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -692,19 +727,30 @@ retry:
return &sign, nil
}
// Renew performs the renew request to the CA and returns the api.SignResponse
// struct.
// Renew performs the renew request to the CA with an empty context and
// returns the api.SignResponse struct.
func (c *Client) Renew(tr http.RoundTripper) (*api.SignResponse, error) {
return c.RenewWithContext(context.Background(), tr)
}
// RenewWithContext performs the renew request to the CA with the provided context
// and returns the api.SignResponse struct.
func (c *Client) RenewWithContext(ctx context.Context, tr http.RoundTripper) (*api.SignResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
client := &http.Client{Transport: tr}
retry:
resp, err := client.Post(u.String(), "application/json", http.NoBody)
req, err := http.NewRequestWithContext(ctx, "POST", u.String(), http.NoBody)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -718,12 +764,19 @@ retry:
}
// RenewWithToken performs the renew request to the CA with the given
// authorization token and returns the api.SignResponse struct. This method is
// generally used to renew an expired certificate.
// authorization token and and empty context and returns the api.SignResponse struct.
// This method is generally used to renew an expired certificate.
func (c *Client) RenewWithToken(token string) (*api.SignResponse, error) {
return c.RenewWithTokenAndContext(context.Background(), token)
}
// RenewWithTokenAndContext performs the renew request to the CA with the given
// authorization token and context and returns the api.SignResponse struct.
// This method is generally used to renew an expired certificate.
func (c *Client) RenewWithTokenAndContext(ctx context.Context, token string) (*api.SignResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/renew"})
req, err := http.NewRequest("POST", u.String(), http.NoBody)
req, err := http.NewRequestWithContext(ctx, "POST", u.String(), http.NoBody)
if err != nil {
return nil, errors.Wrapf(err, "create POST %s request failed", u)
}
@ -734,7 +787,7 @@ retry:
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -747,24 +800,34 @@ retry:
return &sign, nil
}
// Rekey performs the rekey request to the CA and returns the api.SignResponse
// struct.
// Rekey performs the rekey request to the CA with an empty context and
// returns the api.SignResponse struct.
func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) {
return c.RekeyWithContext(context.Background(), req, tr)
}
// RekeyWithContext performs the rekey request to the CA with the provided context
// and returns the api.SignResponse struct.
func (c *Client) RekeyWithContext(ctx context.Context, req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) {
var retried bool
body, err := json.Marshal(req)
if err != nil {
return nil, errors.Wrap(err, "error marshaling request")
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"})
client := &http.Client{Transport: tr}
retry:
resp, err := client.Post(u.String(), "application/json", bytes.NewReader(body))
httpReq, err := http.NewRequestWithContext(ctx, "POST", u.String(), bytes.NewReader(body))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := client.Do(httpReq)
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -777,9 +840,15 @@ retry:
return &sign, nil
}
// Revoke performs the revoke request to the CA and returns the api.RevokeResponse
// struct.
// Revoke performs the revoke request to the CA with an empty context and returns
// the api.RevokeResponse struct.
func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) {
return c.RevokeWithContext(context.Background(), req, tr)
}
// RevokeWithContext performs the revoke request to the CA with the provided context and
// returns the api.RevokeResponse struct.
func (c *Client) RevokeWithContext(ctx context.Context, req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) {
var retried bool
body, err := json.Marshal(req)
if err != nil {
@ -794,12 +863,12 @@ retry:
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/revoke"})
resp, err := client.Post(u.String(), "application/json", bytes.NewReader(body))
resp, err := client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -812,12 +881,21 @@ retry:
return &revoke, nil
}
// Provisioners performs the provisioners request to the CA and returns the
// api.ProvisionersResponse struct with a map of provisioners.
// Provisioners performs the provisioners request to the CA with an empty context
// and returns the api.ProvisionersResponse struct with a map of provisioners.
//
// ProvisionerOption WithProvisionerCursor and WithProvisionLimit can be used to
// paginate the provisioners.
func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersResponse, error) {
return c.ProvisionersWithContext(context.Background(), opts...)
}
// ProvisionersWithContext performs the provisioners request to the CA with the provided context
// and returns the api.ProvisionersResponse struct with a map of provisioners.
//
// ProvisionerOption WithProvisionerCursor and WithProvisionLimit can be used to
// paginate the provisioners.
func (c *Client) ProvisionersWithContext(ctx context.Context, opts ...ProvisionerOption) (*api.ProvisionersResponse, error) {
var retried bool
o := new(ProvisionerOptions)
if err := o.Apply(opts); err != nil {
@ -828,12 +906,12 @@ func (c *Client) Provisioners(opts ...ProvisionerOption) (*api.ProvisionersRespo
RawQuery: o.rawQuery(),
})
retry:
resp, err := c.client.Get(u.String())
resp, err := c.client.GetWithContext(ctx, u.String())
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -846,19 +924,26 @@ retry:
return &provisioners, nil
}
// ProvisionerKey performs the request to the CA to get the encrypted key for
// the given provisioner kid and returns the api.ProvisionerKeyResponse struct
// with the encrypted key.
// ProvisionerKey performs the request to the CA with an empty context to get
// the encrypted key for the given provisioner kid and returns the api.ProvisionerKeyResponse
// struct with the encrypted key.
func (c *Client) ProvisionerKey(kid string) (*api.ProvisionerKeyResponse, error) {
return c.ProvisionerKeyWithContext(context.Background(), kid)
}
// ProvisionerKeyWithContext performs the request to the CA with the provided context to get
// the encrypted key for the given provisioner kid and returns the api.ProvisionerKeyResponse
// struct with the encrypted key.
func (c *Client) ProvisionerKeyWithContext(ctx context.Context, kid string) (*api.ProvisionerKeyResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/provisioners/" + kid + "/encrypted-key"})
retry:
resp, err := c.client.Get(u.String())
resp, err := c.client.GetWithContext(ctx, u.String())
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -871,18 +956,24 @@ retry:
return &key, nil
}
// Roots performs the get roots request to the CA and returns the
// api.RootsResponse struct.
// Roots performs the get roots request to the CA with an empty context
// and returns the api.RootsResponse struct.
func (c *Client) Roots() (*api.RootsResponse, error) {
return c.RootsWithContext(context.Background())
}
// RootsWithContext performs the get roots request to the CA with the provided context
// and returns the api.RootsResponse struct.
func (c *Client) RootsWithContext(ctx context.Context) (*api.RootsResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/roots"})
retry:
resp, err := c.client.Get(u.String())
resp, err := c.client.GetWithContext(ctx, u.String())
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -895,18 +986,24 @@ retry:
return &roots, nil
}
// Federation performs the get federation request to the CA and returns the
// api.FederationResponse struct.
// Federation performs the get federation request to the CA with an empty context
// and returns the api.FederationResponse struct.
func (c *Client) Federation() (*api.FederationResponse, error) {
return c.FederationWithContext(context.Background())
}
// FederationWithContext performs the get federation request to the CA with the provided context
// and returns the api.FederationResponse struct.
func (c *Client) FederationWithContext(ctx context.Context) (*api.FederationResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/federation"})
retry:
resp, err := c.client.Get(u.String())
resp, err := c.client.GetWithContext(ctx, u.String())
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -919,9 +1016,15 @@ retry:
return &federation, nil
}
// SSHSign performs the POST /ssh/sign request to the CA and returns the
// api.SSHSignResponse struct.
// SSHSign performs the POST /ssh/sign request to the CA with an empty context
// and returns the api.SSHSignResponse struct.
func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error) {
return c.SSHSignWithContext(context.Background(), req)
}
// SSHSignWithContext performs the POST /ssh/sign request to the CA with the provided context
// and returns the api.SSHSignResponse struct.
func (c *Client) SSHSignWithContext(ctx context.Context, req *api.SSHSignRequest) (*api.SSHSignResponse, error) {
var retried bool
body, err := json.Marshal(req)
if err != nil {
@ -929,12 +1032,12 @@ func (c *Client) SSHSign(req *api.SSHSignRequest) (*api.SSHSignResponse, error)
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/sign"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -947,9 +1050,15 @@ retry:
return &sign, nil
}
// SSHRenew performs the POST /ssh/renew request to the CA and returns the
// api.SSHRenewResponse struct.
// SSHRenew performs the POST /ssh/renew request to the CA with an empty context
// and returns the api.SSHRenewResponse struct.
func (c *Client) SSHRenew(req *api.SSHRenewRequest) (*api.SSHRenewResponse, error) {
return c.SSHRenewWithContext(context.Background(), req)
}
// SSHRenewWithContext performs the POST /ssh/renew request to the CA with the provided context
// and returns the api.SSHRenewResponse struct.
func (c *Client) SSHRenewWithContext(ctx context.Context, req *api.SSHRenewRequest) (*api.SSHRenewResponse, error) {
var retried bool
body, err := json.Marshal(req)
if err != nil {
@ -957,12 +1066,12 @@ func (c *Client) SSHRenew(req *api.SSHRenewRequest) (*api.SSHRenewResponse, erro
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/renew"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -975,9 +1084,15 @@ retry:
return &renew, nil
}
// SSHRekey performs the POST /ssh/rekey request to the CA and returns the
// api.SSHRekeyResponse struct.
// SSHRekey performs the POST /ssh/rekey request to the CA with an empty context
// and returns the api.SSHRekeyResponse struct.
func (c *Client) SSHRekey(req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, error) {
return c.SSHRekeyWithContext(context.Background(), req)
}
// SSHRekeyWithContext performs the POST /ssh/rekey request to the CA with the provided context
// and returns the api.SSHRekeyResponse struct.
func (c *Client) SSHRekeyWithContext(ctx context.Context, req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, error) {
var retried bool
body, err := json.Marshal(req)
if err != nil {
@ -985,12 +1100,12 @@ func (c *Client) SSHRekey(req *api.SSHRekeyRequest) (*api.SSHRekeyResponse, erro
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/rekey"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -1003,9 +1118,15 @@ retry:
return &rekey, nil
}
// SSHRevoke performs the POST /ssh/revoke request to the CA and returns the
// api.SSHRevokeResponse struct.
// SSHRevoke performs the POST /ssh/revoke request to the CA with an empty context
// and returns the api.SSHRevokeResponse struct.
func (c *Client) SSHRevoke(req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, error) {
return c.SSHRevokeWithContext(context.Background(), req)
}
// SSHRevokeWithContext performs the POST /ssh/revoke request to the CA with the provided context
// and returns the api.SSHRevokeResponse struct.
func (c *Client) SSHRevokeWithContext(ctx context.Context, req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, error) {
var retried bool
body, err := json.Marshal(req)
if err != nil {
@ -1013,12 +1134,12 @@ func (c *Client) SSHRevoke(req *api.SSHRevokeRequest) (*api.SSHRevokeResponse, e
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/revoke"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -1031,18 +1152,24 @@ retry:
return &revoke, nil
}
// SSHRoots performs the GET /ssh/roots request to the CA and returns the
// api.SSHRootsResponse struct.
// SSHRoots performs the GET /ssh/roots request to the CA with an empty context
// and returns the api.SSHRootsResponse struct.
func (c *Client) SSHRoots() (*api.SSHRootsResponse, error) {
return c.SSHRootsWithContext(context.Background())
}
// SSHRootsWithContext performs the GET /ssh/roots request to the CA with the provided context
// and returns the api.SSHRootsResponse struct.
func (c *Client) SSHRootsWithContext(ctx context.Context) (*api.SSHRootsResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/roots"})
retry:
resp, err := c.client.Get(u.String())
resp, err := c.client.GetWithContext(ctx, u.String())
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -1055,18 +1182,24 @@ retry:
return &keys, nil
}
// SSHFederation performs the get /ssh/federation request to the CA and returns
// the api.SSHRootsResponse struct.
// SSHFederation performs the get /ssh/federation request to the CA with an empty context
// and returns the api.SSHRootsResponse struct.
func (c *Client) SSHFederation() (*api.SSHRootsResponse, error) {
return c.SSHFederationWithContext(context.Background())
}
// SSHFederationWithContext performs the get /ssh/federation request to the CA with the provided context
// and returns the api.SSHRootsResponse struct.
func (c *Client) SSHFederationWithContext(ctx context.Context) (*api.SSHRootsResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/federation"})
retry:
resp, err := c.client.Get(u.String())
resp, err := c.client.GetWithContext(ctx, u.String())
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -1079,9 +1212,15 @@ retry:
return &keys, nil
}
// SSHConfig performs the POST /ssh/config request to the CA to get the ssh
// configuration templates.
// SSHConfig performs the POST /ssh/config request to the CA with an empty context
// to get the ssh configuration templates.
func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, error) {
return c.SSHConfigWithContext(context.Background(), req)
}
// SSHConfigWithContext performs the POST /ssh/config request to the CA with the provided context
// to get the ssh configuration templates.
func (c *Client) SSHConfigWithContext(ctx context.Context, req *api.SSHConfigRequest) (*api.SSHConfigResponse, error) {
var retried bool
body, err := json.Marshal(req)
if err != nil {
@ -1089,12 +1228,12 @@ func (c *Client) SSHConfig(req *api.SSHConfigRequest) (*api.SSHConfigResponse, e
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/config"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -1107,9 +1246,15 @@ retry:
return &cfg, nil
}
// SSHCheckHost performs the POST /ssh/check-host request to the CA with the
// given principal.
// SSHCheckHost performs the POST /ssh/check-host request to the CA with an empty context,
// the principal and a token and returns the api.SSHCheckPrincipalResponse.
func (c *Client) SSHCheckHost(principal, token string) (*api.SSHCheckPrincipalResponse, error) {
return c.SSHCheckHostWithContext(context.Background(), principal, token)
}
// SSHCheckHostWithContext performs the POST /ssh/check-host request to the CA with the provided context,
// principal and token and returns the api.SSHCheckPrincipalResponse.
func (c *Client) SSHCheckHostWithContext(ctx context.Context, principal, token string) (*api.SSHCheckPrincipalResponse, error) {
var retried bool
body, err := json.Marshal(&api.SSHCheckPrincipalRequest{
Type: provisioner.SSHHostCert,
@ -1122,12 +1267,12 @@ func (c *Client) SSHCheckHost(principal, token string) (*api.SSHCheckPrincipalRe
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/check-host"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -1141,17 +1286,22 @@ retry:
return &check, nil
}
// SSHGetHosts performs the GET /ssh/get-hosts request to the CA.
// SSHGetHosts performs the GET /ssh/get-hosts request to the CA with an empty context.
func (c *Client) SSHGetHosts() (*api.SSHGetHostsResponse, error) {
return c.SSHGetHostsWithContext(context.Background())
}
// SSHGetHostsWithContext performs the GET /ssh/get-hosts request to the CA with the provided context.
func (c *Client) SSHGetHostsWithContext(ctx context.Context) (*api.SSHGetHostsResponse, error) {
var retried bool
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/hosts"})
retry:
resp, err := c.client.Get(u.String())
resp, err := c.client.GetWithContext(ctx, u.String())
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -1164,8 +1314,13 @@ retry:
return &hosts, nil
}
// SSHBastion performs the POST /ssh/bastion request to the CA.
// SSHBastion performs the POST /ssh/bastion request to the CA with an empty context.
func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse, error) {
return c.SSHBastionWithContext(context.Background(), req)
}
// SSHBastionWithContext performs the POST /ssh/bastion request to the CA with the provided context.
func (c *Client) SSHBastionWithContext(ctx context.Context, req *api.SSHBastionRequest) (*api.SSHBastionResponse, error) {
var retried bool
body, err := json.Marshal(req)
if err != nil {
@ -1173,12 +1328,12 @@ func (c *Client) SSHBastion(req *api.SSHBastionRequest) (*api.SSHBastionResponse
}
u := c.endpoint.ResolveReference(&url.URL{Path: "/ssh/bastion"})
retry:
resp, err := c.client.Post(u.String(), "application/json", bytes.NewReader(body))
resp, err := c.client.PostWithContext(ctx, u.String(), "application/json", bytes.NewReader(body))
if err != nil {
return nil, clientError(err)
}
if resp.StatusCode >= 400 {
if !retried && c.retryOnError(resp) {
if !retried && c.retryOnError(resp) { //nolint:contextcheck // deeply nested context; retry using the same context
retried = true
goto retry
}
@ -1192,11 +1347,16 @@ retry:
}
// RootFingerprint is a helper method that returns the current root fingerprint.
// It does an health connection and gets the fingerprint from the TLS verified
// chains.
// It does an health connection and gets the fingerprint from the TLS verified chains.
func (c *Client) RootFingerprint() (string, error) {
return c.RootFingerprintWithContext(context.Background())
}
// RootFingerprintWithContext is a helper method that returns the current root fingerprint.
// It does an health connection and gets the fingerprint from the TLS verified chains.
func (c *Client) RootFingerprintWithContext(ctx context.Context) (string, error) {
u := c.endpoint.ResolveReference(&url.URL{Path: "/health"})
resp, err := c.client.Get(u.String())
resp, err := c.client.GetWithContext(ctx, u.String())
if err != nil {
return "", clientError(err)
}

@ -135,7 +135,7 @@ func (c *Client) getClientTLSConfig(ctx context.Context, sign *api.SignResponse,
//nolint:staticcheck // Use mutable tls.Config on renew
tr.DialTLS = c.buildDialTLS(tlsCtx)
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) //nolint:contextcheck // deeply nested context
// Update client transport
c.SetTransport(tr)
@ -183,7 +183,7 @@ func (c *Client) GetServerTLSConfig(ctx context.Context, sign *api.SignResponse,
//nolint:staticcheck // Use mutable tls.Config on renew
tr.DialTLS = c.buildDialTLS(tlsCtx)
// tr.DialTLSContext = c.buildDialTLSContext(tlsCtx)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk)
renewer.RenewCertificate = getRenewFunc(tlsCtx, c, tr, pk) //nolint:contextcheck // deeply nested context
// Update client transport
c.SetTransport(tr)

@ -1,6 +1,7 @@
package stepcas
import (
"context"
"net/url"
"strings"
"time"
@ -37,7 +38,7 @@ type stepIssuer interface {
}
// newStepIssuer returns the configured step issuer.
func newStepIssuer(caURL *url.URL, client *ca.Client, iss *apiv1.CertificateIssuer) (stepIssuer, error) {
func newStepIssuer(ctx context.Context, caURL *url.URL, client *ca.Client, iss *apiv1.CertificateIssuer) (stepIssuer, error) {
if err := validateCertificateIssuer(iss); err != nil {
return nil, err
}
@ -46,7 +47,7 @@ func newStepIssuer(caURL *url.URL, client *ca.Client, iss *apiv1.CertificateIssu
case "x5c":
return newX5CIssuer(caURL, iss)
case "jwk":
return newJWKIssuer(caURL, client, iss)
return newJWKIssuer(ctx, caURL, client, iss)
default:
return nil, errors.Errorf("stepCAS `certificateIssuer.type` %s is not supported", iss.Type)
}

@ -1,6 +1,7 @@
package stepcas
import (
"context"
"net/url"
"reflect"
"testing"
@ -118,7 +119,7 @@ func Test_newStepIssuer(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := newStepIssuer(tt.args.caURL, tt.args.client, tt.args.iss)
got, err := newStepIssuer(context.TODO(), tt.args.caURL, tt.args.client, tt.args.iss)
if (err != nil) != tt.wantErr {
t.Errorf("newStepIssuer() error = %v, wantErr %v", err, tt.wantErr)
return

@ -1,6 +1,7 @@
package stepcas
import (
"context"
"crypto"
"encoding/json"
"net/url"
@ -21,13 +22,13 @@ type jwkIssuer struct {
signer jose.Signer
}
func newJWKIssuer(caURL *url.URL, client *ca.Client, cfg *apiv1.CertificateIssuer) (*jwkIssuer, error) {
func newJWKIssuer(ctx context.Context, caURL *url.URL, client *ca.Client, cfg *apiv1.CertificateIssuer) (*jwkIssuer, error) {
var err error
var signer jose.Signer
// Read the key from the CA if not provided.
// Or read it from a PEM file.
if cfg.Key == "" {
p, err := findProvisioner(client, provisioner.TypeJWK, cfg.Provisioner)
p, err := findProvisioner(ctx, client, provisioner.TypeJWK, cfg.Provisioner)
if err != nil {
return nil, err
}
@ -144,10 +145,10 @@ func newJWKSignerFromEncryptedKey(kid, key, password string) (jose.Signer, error
return newJoseSigner(signer, so)
}
func findProvisioner(client *ca.Client, typ provisioner.Type, name string) (provisioner.Interface, error) {
func findProvisioner(ctx context.Context, client *ca.Client, typ provisioner.Type, name string) (provisioner.Interface, error) {
cursor := ""
for {
ps, err := client.Provisioners(ca.WithProvisionerCursor(cursor))
ps, err := client.ProvisionersWithContext(ctx, ca.WithProvisionerCursor(cursor))
if err != nil {
return nil, err
}

@ -43,7 +43,7 @@ func New(ctx context.Context, opts apiv1.Options) (*StepCAS, error) {
}
// Create client.
client, err := ca.NewClient(opts.CertificateAuthority, ca.WithRootSHA256(opts.CertificateAuthorityFingerprint))
client, err := ca.NewClient(opts.CertificateAuthority, ca.WithRootSHA256(opts.CertificateAuthorityFingerprint)) //nolint:contextcheck // deeply nested context
if err != nil {
return nil, err
}
@ -52,7 +52,7 @@ func New(ctx context.Context, opts apiv1.Options) (*StepCAS, error) {
// Create configured issuer unless we only want to use GetCertificateAuthority.
// This avoid the request for the password if not provided.
if !opts.IsCAGetter {
if iss, err = newStepIssuer(caURL, client, opts.CertificateIssuer); err != nil {
if iss, err = newStepIssuer(ctx, caURL, client, opts.CertificateIssuer); err != nil {
return nil, err
}
}

@ -245,7 +245,7 @@ func testJWKIssuer(t *testing.T, caURL *url.URL, password string) *jwkIssuer {
key = testEncryptedKeyPath
password = testPassword
}
jwk, err := newJWKIssuer(caURL, client, &apiv1.CertificateIssuer{
jwk, err := newJWKIssuer(context.TODO(), caURL, client, &apiv1.CertificateIssuer{
Type: "jwk",
Provisioner: "ra@doe.org",
Key: key,

Loading…
Cancel
Save