diff --git a/api/api_test.go b/api/api_test.go index cf988593..8090c6d4 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -884,16 +884,12 @@ func Test_Sign(t *testing.T) { CsrPEM: CertificateRequest{csr}, OTT: "foobarzar", }) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) invalid, err := json.Marshal(SignRequest{ CsrPEM: CertificateRequest{csr}, OTT: "", }) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) expected1 := []byte(`{"crt":"` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(certPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) expected2 := []byte(`{"crt":"` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","ca":"` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n","certChain":["` + strings.ReplaceAll(stepCertPEM, "\n", `\n`) + `\n","` + strings.ReplaceAll(rootPEM, "\n", `\n`) + `\n"]}`) diff --git a/ca/ca_test.go b/ca/ca_test.go index 7ad25cc6..a8c173c4 100644 --- a/ca/ca_test.go +++ b/ca/ca_test.go @@ -289,6 +289,9 @@ ZEp7knvU2psWRw== if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var sign api.SignResponse assert.FatalError(t, readJSON(body, &sign)) @@ -325,7 +328,7 @@ ZEp7knvU2psWRw== assert.FatalError(t, err) assert.Equals(t, intermediate, realIntermediate) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } @@ -369,6 +372,9 @@ func TestCAProvisioners(t *testing.T) { if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var resp api.ProvisionersResponse @@ -379,7 +385,7 @@ func TestCAProvisioners(t *testing.T) { assert.FatalError(t, err) assert.Equals(t, a, b) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } @@ -436,12 +442,15 @@ func TestCAProvisionerEncryptedKey(t *testing.T) { if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var ek api.ProvisionerKeyResponse assert.FatalError(t, readJSON(body, &ek)) assert.Equals(t, ek.Key, tc.expectedKey) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } @@ -498,12 +507,15 @@ func TestCARoot(t *testing.T) { if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var root api.RootResponse assert.FatalError(t, readJSON(body, &root)) assert.Equals(t, root.RootPEM.Certificate, rootCrt) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } @@ -641,6 +653,9 @@ func TestCARenew(t *testing.T) { if assert.Equals(t, rr.Code, tc.status) { body := &ClosingBuffer{rr.Body} + resp := &http.Response{ + Body: body, + } if rr.Code < http.StatusBadRequest { var sign api.SignResponse assert.FatalError(t, readJSON(body, &sign)) @@ -673,7 +688,7 @@ func TestCARenew(t *testing.T) { assert.Equals(t, *sign.TLSOptions, authority.DefaultTLSOptions) } else { - err := readError(body) + err := readError(resp) if tc.errMsg == "" { assert.FatalError(t, errors.New("must validate response error")) } diff --git a/ca/client.go b/ca/client.go index 5e2d98c8..8930d8ee 100644 --- a/ca/client.go +++ b/ca/client.go @@ -622,7 +622,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var version api.VersionResponse if err := readJSON(resp.Body, &version); err != nil { @@ -652,7 +652,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var health api.HealthResponse if err := readJSON(resp.Body, &health); err != nil { @@ -687,7 +687,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var root api.RootResponse if err := readJSON(resp.Body, &root); err != nil { @@ -726,7 +726,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -765,7 +765,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -802,7 +802,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -842,7 +842,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -883,7 +883,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var revoke api.RevokeResponse if err := readJSON(resp.Body, &revoke); err != nil { @@ -926,7 +926,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var provisioners api.ProvisionersResponse if err := readJSON(resp.Body, &provisioners); err != nil { @@ -958,7 +958,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var key api.ProvisionerKeyResponse if err := readJSON(resp.Body, &key); err != nil { @@ -988,7 +988,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var roots api.RootsResponse if err := readJSON(resp.Body, &roots); err != nil { @@ -1018,7 +1018,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var federation api.FederationResponse if err := readJSON(resp.Body, &federation); err != nil { @@ -1052,7 +1052,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var sign api.SSHSignResponse if err := readJSON(resp.Body, &sign); err != nil { @@ -1086,7 +1086,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var renew api.SSHRenewResponse if err := readJSON(resp.Body, &renew); err != nil { @@ -1120,7 +1120,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var rekey api.SSHRekeyResponse if err := readJSON(resp.Body, &rekey); err != nil { @@ -1154,7 +1154,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var revoke api.SSHRevokeResponse if err := readJSON(resp.Body, &revoke); err != nil { @@ -1184,7 +1184,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var keys api.SSHRootsResponse if err := readJSON(resp.Body, &keys); err != nil { @@ -1214,7 +1214,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var keys api.SSHRootsResponse if err := readJSON(resp.Body, &keys); err != nil { @@ -1248,7 +1248,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var cfg api.SSHConfigResponse if err := readJSON(resp.Body, &cfg); err != nil { @@ -1287,7 +1287,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var check api.SSHCheckPrincipalResponse if err := readJSON(resp.Body, &check); err != nil { @@ -1316,7 +1316,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var hosts api.SSHGetHostsResponse if err := readJSON(resp.Body, &hosts); err != nil { @@ -1348,7 +1348,7 @@ retry: retried = true goto retry } - return nil, readError(resp.Body) + return nil, readError(resp) } var bastion api.SSHBastionResponse if err := readJSON(resp.Body, &bastion); err != nil { @@ -1516,12 +1516,13 @@ func readProtoJSON(r io.ReadCloser, m proto.Message) error { return protojson.Unmarshal(data, m) } -func readError(r io.ReadCloser) error { - defer r.Close() +func readError(r *http.Response) error { + defer r.Body.Close() apiErr := new(errs.Error) - if err := json.NewDecoder(r).Decode(apiErr); err != nil { + if err := json.NewDecoder(r.Body).Decode(apiErr); err != nil { return err } + apiErr.RequestID = r.Header.Get("X-Request-Id") return apiErr } diff --git a/errs/error.go b/errs/error.go index ba066925..c9ad92a6 100644 --- a/errs/error.go +++ b/errs/error.go @@ -49,10 +49,11 @@ func WithKeyVal(key string, val interface{}) Option { // Error represents the CA API errors. type Error struct { - Status int - Err error - Msg string - Details map[string]interface{} + Status int + Err error + Msg string + Details map[string]interface{} + RequestID string `json:"-"` } // ErrorResponse represents an error in JSON format. diff --git a/test/e2e/requestid_test.go b/test/e2e/requestid_test.go new file mode 100644 index 00000000..7eccb4f4 --- /dev/null +++ b/test/e2e/requestid_test.go @@ -0,0 +1,102 @@ +package e2e + +import ( + "context" + "encoding/json" + "fmt" + "net" + "path/filepath" + "sync" + "testing" + + "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/ca" + "github.com/smallstep/certificates/errs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/minica" + "go.step.sm/crypto/pemutil" +) + +func TestXxx(t *testing.T) { + dir := t.TempDir() + m, err := minica.New(minica.WithName("Step E2E")) + require.NoError(t, err) + + rootFilepath := filepath.Join(dir, "root.crt") + _, err = pemutil.Serialize(m.Root, pemutil.WithFilename(rootFilepath)) + require.NoError(t, err) + + intermediateCertFilepath := filepath.Join(dir, "intermediate.crt") + _, err = pemutil.Serialize(m.Intermediate, pemutil.WithFilename(intermediateCertFilepath)) + require.NoError(t, err) + + intermediateKeyFilepath := filepath.Join(dir, "intermediate.key") + _, err = pemutil.Serialize(m.Signer, pemutil.WithFilename(intermediateKeyFilepath)) + require.NoError(t, err) + + // get a random address to listen on and connect to; currently no nicer way to get one before starting the server + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + randomAddress := l.Addr().String() + err = l.Close() + require.NoError(t, err) + + cfg := &config.Config{ + Root: []string{rootFilepath}, + IntermediateCert: intermediateCertFilepath, + IntermediateKey: intermediateKeyFilepath, + Address: randomAddress, // reuse the address that was just "reserved" + DNSNames: []string{"127.0.0.1", "stepca.localhost"}, + AuthorityConfig: &config.AuthConfig{ + AuthorityID: "stepca-test", + DeploymentType: "standalone-test", + }, + Logger: json.RawMessage(`{"format": "text"}`), + } + c, err := ca.New(cfg) + require.NoError(t, err) + + // instantiate a client for the CA + client, err := ca.NewClient( + fmt.Sprintf("https://%s", randomAddress), + ca.WithRootFile(rootFilepath), + ) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + err = c.Run() + require.Error(t, err) // expect error when server is stopped + }() + + // require OK health response as the baseline + ctx := context.Background() + healthResponse, err := client.HealthWithContext(ctx) + assert.NoError(t, err) + require.Equal(t, "ok", healthResponse.Status) + + // expect an error when retrieving an invalid root + rootResponse, err := client.RootWithContext(ctx, "invalid") + if assert.Error(t, err) { + apiErr := &errs.Error{} + if assert.ErrorAs(t, err, &apiErr) { + assert.Equal(t, 404, apiErr.StatusCode()) + assert.Equal(t, "The requested resource could not be found. Please see the certificate authority logs for more info.", apiErr.Err.Error()) + assert.NotEmpty(t, apiErr.RequestID) + + // TODO: include the below error in the JSON? It's currently only output to the CA logs + //assert.Equal(t, "/root/invalid was not found: certificate with fingerprint invalid was not found", apiErr.Msg) + } + } + assert.Nil(t, rootResponse) + + // done testing; stop and wait for the server to quit + err = c.Stop() + require.NoError(t, err) + + wg.Wait() +}