diff --git a/ca/bootstrap_test.go b/ca/bootstrap_test.go index 3449b45a..9b78d0ee 100644 --- a/ca/bootstrap_test.go +++ b/ca/bootstrap_test.go @@ -156,8 +156,8 @@ func TestBootstrap(t *testing.T) { if !reflect.DeepEqual(got.endpoint, tt.want.endpoint) { t.Errorf("Bootstrap() endpoint = %v, want %v", got.endpoint, tt.want.endpoint) } - gotTR := got.client.Transport.(*http.Transport) - wantTR := tt.want.client.Transport.(*http.Transport) + gotTR := got.client.GetTransport().(*http.Transport) + wantTR := tt.want.client.GetTransport().(*http.Transport) if !reflect.DeepEqual(gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) { t.Errorf("Bootstrap() certPool = %v, want %v", gotTR.TLSClientConfig.RootCAs, wantTR.TLSClientConfig.RootCAs) } diff --git a/ca/client.go b/ca/client.go index 2a8e9ca8..0267dfa3 100644 --- a/ca/client.go +++ b/ca/client.go @@ -32,6 +32,58 @@ import ( "gopkg.in/square/go-jose.v2/jwt" ) +// UserAgent will set the User-Agent header in the client requests. +var UserAgent = "step-http-client/1.0" + +type uaClient struct { + Client *http.Client +} + +func newClient(transport http.RoundTripper) *uaClient { + return &uaClient{ + Client: &http.Client{ + Transport: transport, + }, + } +} + +func newInsecureClient() *uaClient { + return &uaClient{ + Client: &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, + }, + } +} + +func (c *uaClient) GetTransport() http.RoundTripper { + return c.Client.Transport +} + +func (c *uaClient) SetTransport(tr http.RoundTripper) { + c.Client.Transport = tr +} + +func (c *uaClient) Get(url string) (*http.Response, error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, errors.Wrapf(err, "new request GET %s failed", url) + } + req.Header.Set("User-Agent", UserAgent) + return c.Client.Do(req) +} + +func (c *uaClient) Post(url, contentType string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequest("POST", url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", contentType) + req.Header.Set("User-Agent", UserAgent) + return c.Client.Do(req) +} + // RetryFunc defines the method used to retry a request. If it returns true, the // request will be retried once. type RetryFunc func(code int) bool @@ -354,7 +406,7 @@ func WithProvisionerLimit(limit int) ProvisionerOption { // Client implements an HTTP client for the CA server. type Client struct { - client *http.Client + client *uaClient endpoint *url.URL retryFunc RetryFunc opts []ClientOption @@ -377,9 +429,7 @@ func NewClient(endpoint string, opts ...ClientOption) (*Client, error) { } return &Client{ - client: &http.Client{ - Transport: tr, - }, + client: newClient(tr), endpoint: u, retryFunc: o.retryFunc, opts: opts, @@ -398,7 +448,7 @@ func (c *Client) retryOnError(r *http.Response) bool { return false } r.Body.Close() - c.client.Transport = tr + c.client.SetTransport(tr) return true } } @@ -408,7 +458,7 @@ func (c *Client) retryOnError(r *http.Response) bool { // GetRootCAs returns the RootCAs certificate pool from the configured // transport. func (c *Client) GetRootCAs() *x509.CertPool { - switch t := c.client.Transport.(type) { + switch t := c.client.GetTransport().(type) { case *http.Transport: if t.TLSClientConfig != nil { return t.TLSClientConfig.RootCAs @@ -426,7 +476,7 @@ func (c *Client) GetRootCAs() *x509.CertPool { // SetTransport updates the transport of the internal HTTP client. func (c *Client) SetTransport(tr http.RoundTripper) { - c.client.Transport = tr + c.client.SetTransport(tr) } // Version performs the version request to the CA and returns the @@ -486,7 +536,7 @@ func (c *Client) Root(sha256Sum string) (*api.RootResponse, error) { sha256Sum = strings.ToLower(strings.Replace(sha256Sum, "-", "", -1)) u := c.endpoint.ResolveReference(&url.URL{Path: "/root/" + sha256Sum}) retry: - resp, err := getInsecureClient().Get(u.String()) + resp, err := newInsecureClient().Get(u.String()) if err != nil { return nil, errors.Wrapf(err, "client GET %s failed", u) } @@ -573,10 +623,10 @@ func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.Revo if err != nil { return nil, errors.Wrap(err, "error marshaling request") } - var client *http.Client + var client *uaClient retry: if tr != nil { - client = &http.Client{Transport: tr} + client = newClient(tr) } else { client = c.client } @@ -1082,14 +1132,6 @@ func createCertificateRequest(commonName string, sans []string, key crypto.Priva return &api.CertificateRequest{CertificateRequest: cr}, key, nil } -func getInsecureClient() *http.Client { - return &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }, - } -} - // getRootCAPath returns the path where the root CA is stored based on the // STEPPATH environment variable. func getRootCAPath() string {