diff --git a/authority/tls.go b/authority/tls.go index bc160ad0..b7b2f936 100644 --- a/authority/tls.go +++ b/authority/tls.go @@ -263,7 +263,7 @@ func (a *Authority) Rekey(oldCert *x509.Certificate, pk crypto.PublicKey) ([]*x5 } fullchain := append([]*x509.Certificate{resp.Certificate}, resp.CertificateChain...) - if err = a.storeCertificate(fullchain); err != nil { + if err = a.storeRenewedCertificate(oldCert, fullchain); err != nil { if err != db.ErrNotImplemented { return nil, errs.Wrap(http.StatusInternalServerError, err, "authority.Rekey; error storing certificate in db", opts...) } @@ -287,6 +287,19 @@ func (a *Authority) storeCertificate(fullchain []*x509.Certificate) error { return a.db.StoreCertificate(fullchain[0]) } +// storeRenewedCertificate allows to use an extension of the db.AuthDB interface +// that can log if a certificate has been renewed or rekeyed. +// +// TODO: at some point we should implement this in the standard implementation. +func (a *Authority) storeRenewedCertificate(oldCert *x509.Certificate, fullchain []*x509.Certificate) error { + if s, ok := a.db.(interface { + StoreRenewedCertificate(*x509.Certificate, ...*x509.Certificate) error + }); ok { + return s.StoreRenewedCertificate(oldCert, fullchain...) + } + return a.db.StoreCertificate(fullchain[0]) +} + // RevokeOptions are the options for the Revoke API. type RevokeOptions struct { Serial string diff --git a/ca/client.go b/ca/client.go index 19f758f1..2292c41e 100644 --- a/ca/client.go +++ b/ca/client.go @@ -616,6 +616,36 @@ retry: return &sign, nil } +// Rekey performs the rekey request to the CA and returns the api.SignResponse +// struct. +func (c *Client) Rekey(req *api.RekeyRequest, tr http.RoundTripper) (*api.SignResponse, error) { + var retried bool + body, err := json.Marshal(req) + if err != nil { + return nil, errors.Wrap(err, "error marshaling request") + } + + u := c.endpoint.ResolveReference(&url.URL{Path: "/rekey"}) + client := &http.Client{Transport: tr} +retry: + resp, err := client.Post(u.String(), "application/json", bytes.NewReader(body)) + if err != nil { + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Rekey; client POST %s failed", u) + } + if resp.StatusCode >= 400 { + if !retried && c.retryOnError(resp) { + retried = true + goto retry + } + return nil, readError(resp.Body) + } + var sign api.SignResponse + if err := readJSON(resp.Body, &sign); err != nil { + return nil, errs.Wrapf(http.StatusInternalServerError, err, "client.Rekey; error reading %s", u) + } + return &sign, nil +} + // Revoke performs the revoke request to the CA and returns the api.RevokeResponse // struct. func (c *Client) Revoke(req *api.RevokeRequest, tr http.RoundTripper) (*api.RevokeResponse, error) { diff --git a/ca/client_test.go b/ca/client_test.go index dbba4d4c..30669e6e 100644 --- a/ca/client_test.go +++ b/ca/client_test.go @@ -529,6 +529,75 @@ func TestClient_Renew(t *testing.T) { } } +func TestClient_Rekey(t *testing.T) { + ok := &api.SignResponse{ + ServerPEM: api.Certificate{Certificate: parseCertificate(certPEM)}, + CaPEM: api.Certificate{Certificate: parseCertificate(rootPEM)}, + CertChainPEM: []api.Certificate{ + {Certificate: parseCertificate(certPEM)}, + {Certificate: parseCertificate(rootPEM)}, + }, + } + + request := &api.RekeyRequest{ + CsrPEM: api.CertificateRequest{CertificateRequest: parseCertificateRequest(csrPEM)}, + } + + tests := []struct { + name string + request *api.RekeyRequest + response interface{} + responseCode int + wantErr bool + err error + }{ + {"ok", request, ok, 200, false, nil}, + {"unauthorized", request, errs.Unauthorized("force"), 401, true, errors.New(errs.UnauthorizedDefaultMsg)}, + {"empty request", &api.RekeyRequest{}, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + {"nil request", nil, errs.BadRequest("force"), 400, true, errors.New(errs.BadRequestDefaultMsg)}, + } + + srv := httptest.NewServer(nil) + defer srv.Close() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, err := NewClient(srv.URL, WithTransport(http.DefaultTransport)) + if err != nil { + t.Errorf("NewClient() error = %v", err) + return + } + + srv.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + api.JSONStatus(w, tt.response, tt.responseCode) + }) + + got, err := c.Rekey(tt.request, nil) + if (err != nil) != tt.wantErr { + fmt.Printf("%+v", err) + t.Errorf("Client.Renew() error = %v, wantErr %v", err, tt.wantErr) + return + } + + switch { + case err != nil: + if got != nil { + t.Errorf("Client.Renew() = %v, want nil", got) + } + + sc, ok := err.(errs.StatusCoder) + assert.Fatal(t, ok, "error does not implement StatusCoder interface") + assert.Equals(t, sc.StatusCode(), tt.responseCode) + assert.HasPrefix(t, tt.err.Error(), err.Error()) + default: + if !reflect.DeepEqual(got, tt.response) { + t.Errorf("Client.Renew() = %v, want %v", got, tt.response) + } + } + }) + } +} + func TestClient_Provisioners(t *testing.T) { ok := &api.ProvisionersResponse{ Provisioners: provisioner.List{},