package acme

import (
	"context"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/json"
	"fmt"
	"net"
	"net/url"
	"testing"
	"time"

	"github.com/pkg/errors"
	"github.com/smallstep/assert"
	"github.com/smallstep/certificates/authority/provisioner"
	"github.com/smallstep/certificates/db"
	"github.com/smallstep/nosql"
	"github.com/smallstep/nosql/database"
)

// certDuration is the length of the NotBefore/NotAfter window produced by
// defaultOrderOps.
var certDuration = 6 * time.Hour

// defaultOrderOps returns the OrderOptions shared by most tests: account
// "accID", two DNS identifiers, and a validity window of exactly
// certDuration starting at clock.Now(). A fresh value (including a fresh
// Identifiers slice) is built on every call, so callers may mutate it.
func defaultOrderOps() OrderOptions {
	// Read the clock once so the window is exactly certDuration long.
	now := clock.Now()
	return OrderOptions{
		AccountID: "accID",
		Identifiers: []Identifier{
			{Type: "dns", Value: "acme.example.com"},
			{Type: "dns", Value: "step.example.com"},
		},
		NotBefore: now,
		NotAfter:  now.Add(certDuration),
	}
}

// newO builds an order from defaultOrderOps against a mock DB whose
// CmpAndSwap always succeeds and whose Get always reports
// database.ErrNotFound.
func newO() (*order, error) {
	mockdb := &db.MockNoSQLDB{
		MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
			return []byte("foo"), true, nil
		},
		MGet: func(bucket, key []byte) ([]byte, error) {
			return nil, database.ErrNotFound
		},
	}
	return newOrder(mockdb, defaultOrderOps())
}

// Test_getOrder checks how getOrder maps DB failures (not found, generic
// error, undecodable payload) to ACME errors, and that a stored order
// round-trips through JSON unchanged.
func Test_getOrder(t *testing.T) {
	type test struct {
		id  string
		db  nosql.DB
		o   *order
		err *Error
	}
	tests := map[string]func(t *testing.T) test{
		"fail/not-found": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			return test{
				o:  o,
				id: o.ID,
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						return nil, database.ErrNotFound
					},
				},
				err: MalformedErr(errors.Errorf("order %s not found: not found", o.ID)),
			}
		},
		"fail/db-error": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			return test{
				o:  o,
				id: o.ID,
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						return nil, errors.New("force")
					},
				},
				err: ServerInternalErr(errors.Errorf("error loading order %s: force", o.ID)),
			}
		},
		"fail/unmarshal-error": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			return test{
				o:  o,
				id: o.ID,
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, key, []byte(o.ID))
						// An empty payload makes json.Unmarshal fail.
						return nil, nil
					},
				},
				err: ServerInternalErr(errors.New("error unmarshaling order: unexpected end of JSON input")),
			}
		},
		"ok": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			b, err := json.Marshal(o)
			assert.FatalError(t, err)
			return test{
				o:  o,
				id: o.ID,
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, key, []byte(o.ID))
						return b, nil
					},
				},
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run(t)
			if o, err := getOrder(tc.db, tc.id); err != nil {
				if assert.NotNil(t, tc.err) {
					ae, ok := err.(*Error)
					// Stop here rather than dereference a nil *Error below.
					assert.Fatal(t, ok)
					assert.HasPrefix(t, ae.Error(), tc.err.Error())
					assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
					assert.Equals(t, ae.Type, tc.err.Type)
				}
			} else {
				if assert.Nil(t, tc.err) {
					assert.Equals(t, tc.o.ID, o.ID)
					assert.Equals(t, tc.o.AccountID, o.AccountID)
					assert.Equals(t, tc.o.Status, o.Status)
					assert.Equals(t, tc.o.Identifiers, o.Identifiers)
					assert.Equals(t, tc.o.Created, o.Created)
					assert.Equals(t, tc.o.Expires, o.Expires)
					assert.Equals(t, tc.o.Authorizations, o.Authorizations)
					assert.Equals(t, tc.o.NotBefore, o.NotBefore)
					assert.Equals(t, tc.o.NotAfter, o.NotAfter)
					assert.Equals(t, tc.o.Certificate, o.Certificate)
					assert.Equals(t, tc.o.Error, o.Error)
				}
			}
		})
	}
}

// TestOrderToACME checks that order.toACME copies ID, status and
// identifiers, builds the finalize (and, when set, certificate) URLs from the
// base URL and provisioner name in ctx, and renders Expires, NotBefore and
// NotAfter as RFC 3339 timestamps.
func TestOrderToACME(t *testing.T) {
	dir := newDirectory("ca.smallstep.com", "acme")
	prov := newProv()
	provName := url.PathEscape(prov.GetName())
	baseURL := &url.URL{Scheme: "https", Host: "test.ca.smallstep.com"}
	ctx := context.WithValue(context.Background(), ProvisionerContextKey, prov)
	ctx = context.WithValue(ctx, BaseURLContextKey, baseURL)

	type test struct {
		o   *order
		err *Error
	}
	tests := map[string]func(t *testing.T) test{
		"ok/no-cert": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			return test{o: o}
		},
		"ok/cert": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusValid
			o.Certificate = "cert-id"
			return test{o: o}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			// Build the case inside the subtest so setup failures are
			// reported against (and only abort) this subtest.
			tc := run(t)
			acmeOrder, err := tc.o.toACME(ctx, nil, dir)
			if err != nil {
				if assert.NotNil(t, tc.err) {
					ae, ok := err.(*Error)
					assert.Fatal(t, ok)
					assert.HasPrefix(t, ae.Error(), tc.err.Error())
					assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
					assert.Equals(t, ae.Type, tc.err.Type)
				}
			} else {
				if assert.Nil(t, tc.err) {
					assert.Equals(t, acmeOrder.ID, tc.o.ID)
					assert.Equals(t, acmeOrder.Status, tc.o.Status)
					assert.Equals(t, acmeOrder.Identifiers, tc.o.Identifiers)
					assert.Equals(t, acmeOrder.Finalize, fmt.Sprintf("%s/acme/%s/order/%s/finalize", baseURL.String(), provName, tc.o.ID))
					if tc.o.Certificate != "" {
						assert.Equals(t, acmeOrder.Certificate, fmt.Sprintf("%s/acme/%s/certificate/%s", baseURL.String(), provName, tc.o.Certificate))
					}

					expiry, err := time.Parse(time.RFC3339, acmeOrder.Expires)
					assert.FatalError(t, err)
					assert.Equals(t, expiry.String(), tc.o.Expires.String())
					nbf, err := time.Parse(time.RFC3339, acmeOrder.NotBefore)
					assert.FatalError(t, err)
					assert.Equals(t, nbf.String(), tc.o.NotBefore.String())
					naf, err := time.Parse(time.RFC3339, acmeOrder.NotAfter)
					assert.FatalError(t, err)
					assert.Equals(t, naf.String(), tc.o.NotAfter.String())
				}
			}
		})
	}
}

// TestOrderSave checks that order.save writes the JSON-encoded order to
// orderTable under its ID via CmpAndSwap, passing the previous encoding (or
// nil) as the expected old value. It also checks how swap errors and failed
// swaps are reported.
func TestOrderSave(t *testing.T) {
	type test struct {
		o, old *order
		db     nosql.DB
		err    *Error
	}
	tests := map[string]func(t *testing.T) test{
		"fail/old-nil/swap-error": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			return test{
				o:   o,
				old: nil,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						return nil, false, errors.New("force")
					},
				},
				err: ServerInternalErr(errors.New("error storing order: force")),
			}
		},
		"fail/old-nil/swap-false": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			return test{
				o:   o,
				old: nil,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						return []byte("foo"), false, nil
					},
				},
				err: ServerInternalErr(errors.New("error storing order; value has changed since last read")),
			}
		},
		"ok/old-nil": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			b, err := json.Marshal(o)
			assert.FatalError(t, err)
			return test{
				o:   o,
				old: nil,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						assert.Equals(t, old, nil)
						assert.Equals(t, b, newval)
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, []byte(o.ID), key)
						return nil, true, nil
					},
				},
			}
		},
		"ok/old-not-nil": func(t *testing.T) test {
			oldo, err := newO()
			assert.FatalError(t, err)
			o, err := newO()
			assert.FatalError(t, err)

			oldb, err := json.Marshal(oldo)
			assert.FatalError(t, err)
			b, err := json.Marshal(o)
			assert.FatalError(t, err)
			return test{
				o:   o,
				old: oldo,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						assert.Equals(t, old, oldb)
						assert.Equals(t, b, newval)
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, []byte(o.ID), key)
						return []byte("foo"), true, nil
					},
				},
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run(t)
			if err := tc.o.save(tc.db, tc.old); err != nil {
				if assert.NotNil(t, tc.err) {
					ae, ok := err.(*Error)
					assert.Fatal(t, ok)
					assert.HasPrefix(t, ae.Error(), tc.err.Error())
					assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
					assert.Equals(t, ae.Type, tc.err.Type)
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

// Test_newOrder checks order creation: rejection of unknown identifier
// types, error propagation from each DB write, the fields of a new order,
// and the NotBefore/NotAfter defaults (backdate/defaultDuration) applied
// when the options leave them zero.
func Test_newOrder(t *testing.T) {
	type test struct {
		ops    OrderOptions
		db     nosql.DB
		err    *Error
		authzs *([]string)
	}
	tests := map[string]func(t *testing.T) test{
		"fail/unexpected-identifier-type": func(t *testing.T) test {
			ops := defaultOrderOps()
			ops.Identifiers[0].Type = "foo"
			return test{
				ops: ops,
				err: MalformedErr(errors.New("unexpected authz type foo")),
			}
		},
		"fail/save-order-error": func(t *testing.T) test {
			// The first 8 CmpAndSwap calls succeed; the 9th (assumed to be the
			// order write, per the other cases below) fails.
			count := 0
			return test{
				ops: defaultOrderOps(),
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						if count >= 8 {
							return nil, false, errors.New("force")
						}
						count++
						return nil, true, nil
					},
				},
				err: ServerInternalErr(errors.New("error storing order: force")),
			}
		},
		"fail/get-orderIDs-error": func(t *testing.T) test {
			count := 0
			ops := defaultOrderOps()
			return test{
				ops: ops,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						if count >= 9 {
							return nil, false, errors.New("force")
						}
						count++
						return nil, true, nil
					},
					MGet: func(bucket, key []byte) ([]byte, error) {
						return nil, errors.New("force")
					},
				},
				err: ServerInternalErr(errors.Errorf("error loading orderIDs for account %s: force", ops.AccountID)),
			}
		},
		"fail/save-orderIDs-error": func(t *testing.T) test {
			count := 0
			var (
				_oid = ""
				oid  = &_oid
			)
			ops := defaultOrderOps()
			return test{
				ops: ops,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						if count >= 9 {
							assert.Equals(t, bucket, ordersByAccountIDTable)
							assert.Equals(t, key, []byte(ops.AccountID))
							return nil, false, errors.New("force")
						} else if count == 8 {
							// Remember the order key so MDel can check it is
							// the one removed after the failure.
							*oid = string(key)
						}
						count++
						return nil, true, nil
					},
					MGet: func(bucket, key []byte) ([]byte, error) {
						return nil, database.ErrNotFound
					},
					MDel: func(bucket, key []byte) error {
						assert.Equals(t, bucket, orderTable)
						assert.Equals(t, key, []byte(*oid))
						return nil
					},
				},
				err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s: force", ops.AccountID)),
			}
		},
		"ok": func(t *testing.T) test {
			count := 0
			authzs := &([]string{})
			var (
				_oid = ""
				oid  = &_oid
			)
			ops := defaultOrderOps()
			return test{
				ops: ops,
				db: &db.MockNoSQLDB{
					// The indices below mirror newOrder's sequence of
					// CmpAndSwap calls: keys at calls 3 and 7 are recorded as
					// the expected authorization IDs (presumably the two authz
					// writes), call 8 as the order ID, and call 9 onwards must
					// be the ordersByAccountIDTable update.
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						if count >= 9 {
							assert.Equals(t, bucket, ordersByAccountIDTable)
							assert.Equals(t, key, []byte(ops.AccountID))
							assert.Equals(t, old, nil)
							newB, err := json.Marshal([]string{*oid})
							assert.FatalError(t, err)
							assert.Equals(t, newval, newB)
						} else if count == 8 {
							*oid = string(key)
						} else if count == 7 {
							*authzs = append(*authzs, string(key))
						} else if count == 3 {
							*authzs = []string{string(key)}
						}
						count++
						return nil, true, nil
					},
					MGet: func(bucket, key []byte) ([]byte, error) {
						return nil, database.ErrNotFound
					},
				},
				authzs: authzs,
			}
		},
		"ok/validity-bounds-not-set": func(t *testing.T) test {
			count := 0
			authzs := &([]string{})
			var (
				_oid = ""
				oid  = &_oid
			)
			ops := defaultOrderOps()
			ops.backdate = time.Minute
			ops.defaultDuration = 12 * time.Hour
			ops.NotBefore = time.Time{}
			ops.NotAfter = time.Time{}
			return test{
				ops: ops,
				db: &db.MockNoSQLDB{
					// Same call-index bookkeeping as the "ok" case.
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						if count >= 9 {
							assert.Equals(t, bucket, ordersByAccountIDTable)
							assert.Equals(t, key, []byte(ops.AccountID))
							assert.Equals(t, old, nil)
							newB, err := json.Marshal([]string{*oid})
							assert.FatalError(t, err)
							assert.Equals(t, newval, newB)
						} else if count == 8 {
							*oid = string(key)
						} else if count == 7 {
							*authzs = append(*authzs, string(key))
						} else if count == 3 {
							*authzs = []string{string(key)}
						}
						count++
						return nil, true, nil
					},
					MGet: func(bucket, key []byte) ([]byte, error) {
						return nil, database.ErrNotFound
					},
				},
				authzs: authzs,
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			// Build the case inside the subtest. The mock closures capture this
			// t and may call assert.FatalError while newOrder runs. FailNow on
			// the parent t from a subtest goroutine is invalid.
			tc := run(t)
			o, err := newOrder(tc.db, tc.ops)
			if err != nil {
				if assert.NotNil(t, tc.err) {
					ae, ok := err.(*Error)
					assert.Fatal(t, ok)
					assert.HasPrefix(t, ae.Error(), tc.err.Error())
					assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
					assert.Equals(t, ae.Type, tc.err.Type)
				}
			} else {
				if assert.Nil(t, tc.err) {
					assert.Equals(t, o.AccountID, tc.ops.AccountID)
					assert.Equals(t, o.Status, StatusPending)
					assert.Equals(t, o.Identifiers, tc.ops.Identifiers)
					assert.Equals(t, o.Error, nil)
					assert.Equals(t, o.Certificate, "")
					assert.Equals(t, o.Authorizations, *tc.authzs)
					assert.True(t, o.Created.Before(time.Now().UTC().Add(time.Minute)))
					assert.True(t, o.Created.After(time.Now().UTC().Add(-1*time.Minute)))

					expiry := o.Created.Add(defaultExpiryDuration)
					assert.True(t, o.Expires.Before(expiry.Add(time.Minute)))
					assert.True(t, o.Expires.After(expiry.Add(-1*time.Minute)))

					nbf := tc.ops.NotBefore
					now := time.Now().UTC()
					if !tc.ops.NotBefore.IsZero() {
						assert.Equals(t, o.NotBefore, tc.ops.NotBefore)
					} else {
						// NotBefore defaulted: it should be about "now - backdate".
						// The expected NotAfter is derived from the un-backdated
						// start.
						nbf = o.NotBefore.Add(tc.ops.backdate)
						assert.True(t, o.NotBefore.Before(now.Add(-tc.ops.backdate+time.Second)))
						assert.True(t, o.NotBefore.Add(tc.ops.backdate+2*time.Second).After(now))
					}
					if !tc.ops.NotAfter.IsZero() {
						assert.Equals(t, o.NotAfter, tc.ops.NotAfter)
					} else {
						naf := nbf.Add(tc.ops.defaultDuration)
						assert.Equals(t, o.NotAfter, naf)
					}
				}
			}
		})
	}
}

// TestOrderIDs_save checks that orderIDs.save writes the JSON-encoded ID
// list to ordersByAccountIDTable under the account ID, and that an empty list
// is written as a nil value. It also checks how swap errors and failed swaps
// are reported.
func TestOrderIDs_save(t *testing.T) {
	accID := "acc-id"
	newOids := func() orderIDs {
		return []string{"1", "2"}
	}
	type test struct {
		oids, old orderIDs
		db        nosql.DB
		err       *Error
	}
	tests := map[string]func(t *testing.T) test{
		"fail/old-nil/swap-error": func(t *testing.T) test {
			return test{
				oids: newOids(),
				old:  nil,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						return nil, false, errors.New("force")
					},
				},
				err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s: force", accID)),
			}
		},
		"fail/old-nil/swap-false": func(t *testing.T) test {
			return test{
				oids: newOids(),
				old:  nil,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						return []byte("foo"), false, nil
					},
				},
				err: ServerInternalErr(errors.Errorf("error storing order IDs for account %s; order IDs changed since last read", accID)),
			}
		},
		"ok/old-nil": func(t *testing.T) test {
			oids := newOids()
			b, err := json.Marshal(oids)
			assert.FatalError(t, err)
			return test{
				oids: oids,
				old:  nil,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						assert.Equals(t, old, nil)
						assert.Equals(t, b, newval)
						assert.Equals(t, bucket, ordersByAccountIDTable)
						assert.Equals(t, key, []byte(accID))
						return nil, true, nil
					},
				},
			}
		},
		"ok/old-not-nil": func(t *testing.T) test {
			oldOids := newOids()
			// newOids returns a full-capacity slice, so this append copies
			// and leaves oldOids untouched.
			oids := append(oldOids, "3")

			oldb, err := json.Marshal(oldOids)
			assert.FatalError(t, err)
			b, err := json.Marshal(oids)
			assert.FatalError(t, err)
			return test{
				oids: oids,
				old:  oldOids,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						assert.Equals(t, old, oldb)
						assert.Equals(t, newval, b)
						assert.Equals(t, bucket, ordersByAccountIDTable)
						assert.Equals(t, key, []byte(accID))
						return nil, true, nil
					},
				},
			}
		},
		"ok/new-empty-saved-as-nil": func(t *testing.T) test {
			oldOids := newOids()
			oids := []string{}

			oldb, err := json.Marshal(oldOids)
			assert.FatalError(t, err)
			return test{
				oids: oids,
				old:  oldOids,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						assert.Equals(t, old, oldb)
						assert.Equals(t, newval, nil)
						assert.Equals(t, bucket, ordersByAccountIDTable)
						assert.Equals(t, key, []byte(accID))
						return nil, true, nil
					},
				},
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run(t)
			if err := tc.oids.save(tc.db, tc.old, accID); err != nil {
				if assert.NotNil(t, tc.err) {
					ae, ok := err.(*Error)
					assert.Fatal(t, ok)
					assert.HasPrefix(t, ae.Error(), tc.err.Error())
					assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
					assert.Equals(t, ae.Type, tc.err.Type)
				}
			} else {
				assert.Nil(t, tc.err)
			}
		})
	}
}

// TestOrderUpdateStatus checks order.updateStatus:
//   - terminal states are returned unchanged;
//   - an unknown status is an error;
//   - an expired order becomes invalid with a "order has expired" error;
//   - the status follows the loaded authorizations (pending while any is
//     pending, invalid once one is invalid).
//
// The result is compared with the expected order via its JSON encoding.
func TestOrderUpdateStatus(t *testing.T) {
	type test struct {
		o, res *order
		err    *Error
		db     nosql.DB
	}
	tests := map[string]func(t *testing.T) test{
		// NOTE(review): despite the "fail/" prefix, the next two cases expect
		// no error, only an unchanged order.
		"fail/already-invalid": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusInvalid
			return test{
				o:   o,
				res: o,
			}
		},
		"fail/already-valid": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusValid
			return test{
				o:   o,
				res: o,
			}
		},
		"fail/unexpected-status": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusDeactivated
			return test{
				o:   o,
				res: o,
				err: ServerInternalErr(errors.New("unrecognized order status: deactivated")),
			}
		},
		"fail/save-error": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Expires = time.Now().UTC().Add(-time.Minute)
			return test{
				o:   o,
				res: o,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						return nil, false, errors.New("force")
					},
				},
				err: ServerInternalErr(errors.New("error storing order: force")),
			}
		},
		"ok/expired": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Expires = time.Now().UTC().Add(-time.Minute)

			_o := *o
			clone := &_o
			clone.Error = MalformedErr(errors.New("order has expired"))
			clone.Status = StatusInvalid
			return test{
				o:   o,
				res: clone,
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						return nil, true, nil
					},
				},
			}
		},
		"fail/get-authz-error": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			return test{
				o:   o,
				res: o,
				db: &db.MockNoSQLDB{
					MGet: func(bucket, key []byte) ([]byte, error) {
						return nil, errors.New("force")
					},
				},
				err: ServerInternalErr(errors.New("error loading authz")),
			}
		},
		"ok/still-pending": func(t *testing.T) test {
			az1, err := newAz()
			assert.FatalError(t, err)
			az2, err := newAz()
			assert.FatalError(t, err)
			az3, err := newAz()
			assert.FatalError(t, err)

			ch1, err := newHTTPCh()
			assert.FatalError(t, err)
			ch2, err := newTLSALPNCh()
			assert.FatalError(t, err)
			ch3, err := newDNSCh()
			assert.FatalError(t, err)

			ch1b, err := json.Marshal(ch1)
			assert.FatalError(t, err)
			ch2b, err := json.Marshal(ch2)
			assert.FatalError(t, err)
			ch3b, err := json.Marshal(ch3)
			assert.FatalError(t, err)

			o, err := newO()
			assert.FatalError(t, err)
			o.Authorizations = []string{az1.getID(), az2.getID(), az3.getID()}

			_az3, ok := az3.(*dnsAuthz)
			assert.Fatal(t, ok)
			_az3.baseAuthz.Status = StatusValid

			b1, err := json.Marshal(az1)
			assert.FatalError(t, err)
			b2, err := json.Marshal(az2)
			assert.FatalError(t, err)
			b3, err := json.Marshal(az3)
			assert.FatalError(t, err)

			count := 0
			return test{
				o:   o,
				res: o,
				db: &db.MockNoSQLDB{
					// Serves, in call order: az1 and three challenge payloads,
					// az2 and three challenge payloads, then az3. Presumably
					// this matches how authz loading fetches each authz's
					// challenges; any extra Get is an error.
					MGet: func(bucket, key []byte) ([]byte, error) {
						var ret []byte
						switch count {
						case 0:
							ret = b1
						case 1:
							ret = ch1b
						case 2:
							ret = ch2b
						case 3:
							ret = ch3b
						case 4:
							ret = b2
						case 5:
							ret = ch1b
						case 6:
							ret = ch2b
						case 7:
							ret = ch3b
						case 8:
							ret = b3
						default:
							return nil, errors.New("unexpected count")
						}
						count++
						return ret, nil
					},
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						return nil, true, nil
					},
				},
			}
		},
		"ok/invalid": func(t *testing.T) test {
			az1, err := newAz()
			assert.FatalError(t, err)
			az2, err := newAz()
			assert.FatalError(t, err)
			az3, err := newAz()
			assert.FatalError(t, err)

			ch1, err := newHTTPCh()
			assert.FatalError(t, err)
			ch2, err := newTLSALPNCh()
			assert.FatalError(t, err)
			ch3, err := newDNSCh()
			assert.FatalError(t, err)

			ch1b, err := json.Marshal(ch1)
			assert.FatalError(t, err)
			ch2b, err := json.Marshal(ch2)
			assert.FatalError(t, err)
			ch3b, err := json.Marshal(ch3)
			assert.FatalError(t, err)

			o, err := newO()
			assert.FatalError(t, err)
			o.Authorizations = []string{az1.getID(), az2.getID(), az3.getID()}

			_az3, ok := az3.(*dnsAuthz)
			assert.Fatal(t, ok)
			_az3.baseAuthz.Status = StatusInvalid

			b1, err := json.Marshal(az1)
			assert.FatalError(t, err)
			b2, err := json.Marshal(az2)
			assert.FatalError(t, err)
			b3, err := json.Marshal(az3)
			assert.FatalError(t, err)

			_o := *o
			clone := &_o
			clone.Status = StatusInvalid

			count := 0
			return test{
				o:   o,
				res: clone,
				db: &db.MockNoSQLDB{
					// Same call-order payload sequence as "ok/still-pending".
					MGet: func(bucket, key []byte) ([]byte, error) {
						var ret []byte
						switch count {
						case 0:
							ret = b1
						case 1:
							ret = ch1b
						case 2:
							ret = ch2b
						case 3:
							ret = ch3b
						case 4:
							ret = b2
						case 5:
							ret = ch1b
						case 6:
							ret = ch2b
						case 7:
							ret = ch3b
						case 8:
							ret = b3
						default:
							return nil, errors.New("unexpected count")
						}
						count++
						return ret, nil
					},
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						return nil, true, nil
					},
				},
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run(t)
			o, err := tc.o.updateStatus(tc.db)
			if err != nil {
				if assert.NotNil(t, tc.err) {
					ae, ok := err.(*Error)
					assert.Fatal(t, ok)
					assert.HasPrefix(t, ae.Error(), tc.err.Error())
					assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
					assert.Equals(t, ae.Type, tc.err.Type)
				}
			} else {
				if assert.Nil(t, tc.err) {
					expB, err := json.Marshal(tc.res)
					assert.FatalError(t, err)
					b, err := json.Marshal(o)
					assert.FatalError(t, err)
					assert.Equals(t, expB, b)
				}
			}
		})
	}
}

// mockSignAuth is a test double for the signing authority used by
// order.finalize. The optional function fields override the corresponding
// methods. Otherwise the methods return the canned values in ret1/ret2/err.
type mockSignAuth struct {
	sign                func(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error)
	loadProvisionerByID func(string) (provisioner.Interface, error)
	ret1, ret2          interface{}
	err                 error
}

// Sign delegates to m.sign when set. Otherwise it returns m.err if non-nil,
// else the chain [ret1, ret2]. In that last case both fields must hold
// *x509.Certificate values, or the type assertions panic.
func (m *mockSignAuth) Sign(csr *x509.CertificateRequest, signOpts provisioner.Options, extraOpts ...provisioner.SignOption) ([]*x509.Certificate, error) {
	if m.sign != nil {
		return m.sign(csr, signOpts, extraOpts...)
	}
	if m.err != nil {
		return nil, m.err
	}
	return []*x509.Certificate{m.ret1.(*x509.Certificate), m.ret2.(*x509.Certificate)}, nil
}

// LoadProvisionerByID delegates to m.loadProvisionerByID when set. Otherwise
// it returns ret1 (which must implement provisioner.Interface) and m.err.
func (m *mockSignAuth) LoadProvisionerByID(id string) (provisioner.Interface, error) {
	if m.loadProvisionerByID != nil {
		return m.loadProvisionerByID(id)
	}
	return m.ret1.(provisioner.Interface), m.err
}

// TestOrderFinalize checks order.finalize:
//   - state gating (invalid, already valid, still pending);
//   - CSR validation: names must match the identifiers, and IP, email and URI
//     SANs are rejected;
//   - error propagation from the provisioner, the signer and the DB;
//   - the final order after a successful issuance.
func TestOrderFinalize(t *testing.T) {
	prov := newProv()
	type test struct {
		o, res *order
		err    *Error
		db     nosql.DB
		csr    *x509.CertificateRequest
		sa     SignAuthority
		prov   Provisioner
	}
	tests := map[string]func(t *testing.T) test{
		"fail/already-invalid": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusInvalid
			return test{
				o:   o,
				err: OrderNotReadyErr(errors.Errorf("order %s has been abandoned", o.ID)),
			}
		},
		"ok/already-valid": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusValid
			o.Certificate = "cert-id"
			return test{
				o:   o,
				res: o,
			}
		},
		"fail/still-pending": func(t *testing.T) test {
			az1, err := newAz()
			assert.FatalError(t, err)
			az2, err := newAz()
			assert.FatalError(t, err)
			az3, err := newAz()
			assert.FatalError(t, err)

			ch1, err := newHTTPCh()
			assert.FatalError(t, err)
			ch2, err := newTLSALPNCh()
			assert.FatalError(t, err)
			ch3, err := newDNSCh()
			assert.FatalError(t, err)

			ch1b, err := json.Marshal(ch1)
			assert.FatalError(t, err)
			ch2b, err := json.Marshal(ch2)
			assert.FatalError(t, err)
			ch3b, err := json.Marshal(ch3)
			assert.FatalError(t, err)

			o, err := newO()
			assert.FatalError(t, err)
			o.Authorizations = []string{az1.getID(), az2.getID(), az3.getID()}

			_az3, ok := az3.(*dnsAuthz)
			assert.Fatal(t, ok)
			_az3.baseAuthz.Status = StatusValid

			b1, err := json.Marshal(az1)
			assert.FatalError(t, err)
			b2, err := json.Marshal(az2)
			assert.FatalError(t, err)
			b3, err := json.Marshal(az3)
			assert.FatalError(t, err)

			count := 0
			return test{
				o:   o,
				res: o,
				db: &db.MockNoSQLDB{
					// Serves, in call order: az1 and three challenge payloads,
					// az2 and three challenge payloads, then az3.
					MGet: func(bucket, key []byte) ([]byte, error) {
						var ret []byte
						switch count {
						case 0:
							ret = b1
						case 1:
							ret = ch1b
						case 2:
							ret = ch2b
						case 3:
							ret = ch3b
						case 4:
							ret = b2
						case 5:
							ret = ch1b
						case 6:
							ret = ch2b
						case 7:
							ret = ch3b
						case 8:
							ret = b3
						default:
							return nil, errors.New("unexpected count")
						}
						count++
						return ret, nil
					},
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						return nil, true, nil
					},
				},
				err: OrderNotReadyErr(errors.Errorf("order %s is not ready", o.ID)),
			}
		},
		"fail/ready/csr-names-match-error": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusReady

			csr := &x509.CertificateRequest{
				Subject: pkix.Name{
					CommonName: "acme.example.com",
				},
				DNSNames: []string{"acme.example.com", "fail.smallstep.com"},
			}
			return test{
				o:   o,
				csr: csr,
				err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")),
			}
		},
		"fail/ready/csr-names-match-error-2": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusReady

			csr := &x509.CertificateRequest{
				Subject: pkix.Name{
					CommonName: "",
				},
				DNSNames: []string{"acme.example.com"},
			}
			return test{
				o:   o,
				csr: csr,
				err: BadCSRErr(errors.Errorf("CSR names do not match identifiers exactly")),
			}
		},
		"fail/ready/no-ipAddresses": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusReady

			csr := &x509.CertificateRequest{
				Subject: pkix.Name{
					CommonName: "",
				},
				DNSNames:    []string{"acme.example.com", "step.example.com"},
				IPAddresses: []net.IP{net.ParseIP("1.1.1.1")},
			}
			return test{
				o:   o,
				csr: csr,
				err: BadCSRErr(errors.Errorf("CSR contains IP Address SANs, but should only contain DNS Names")),
			}
		},
		"fail/ready/no-emailAddresses": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusReady

			csr := &x509.CertificateRequest{
				Subject: pkix.Name{
					CommonName: "",
				},
				DNSNames:       []string{"acme.example.com", "step.example.com"},
				EmailAddresses: []string{"max@smallstep.com", "mariano@smallstep.com"},
			}
			return test{
				o:   o,
				csr: csr,
				err: BadCSRErr(errors.Errorf("CSR contains Email Address SANs, but should only contain DNS Names")),
			}
		},
		"fail/ready/no-URIs": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusReady

			u, err := url.Parse("https://google.com")
			assert.FatalError(t, err)
			csr := &x509.CertificateRequest{
				Subject: pkix.Name{
					CommonName: "",
				},
				DNSNames: []string{"acme.example.com", "step.example.com"},
				URIs:     []*url.URL{u},
			}
			return test{
				o:   o,
				csr: csr,
				err: BadCSRErr(errors.Errorf("CSR contains URI SANs, but should only contain DNS Names")),
			}
		},
		"fail/ready/provisioner-auth-sign-error": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusReady

			csr := &x509.CertificateRequest{
				Subject: pkix.Name{
					CommonName: "acme.example.com",
				},
				DNSNames: []string{"step.example.com", "acme.example.com"},
			}
			return test{
				o:   o,
				csr: csr,
				err: ServerInternalErr(errors.New("error retrieving authorization options from ACME provisioner: force")),
				prov: &MockProvisioner{
					MauthorizeSign: func(ctx context.Context, token string) ([]provisioner.SignOption, error) {
						return nil, errors.New("force")
					},
				},
			}
		},
		"fail/ready/sign-cert-error": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusReady

			csr := &x509.CertificateRequest{
				Subject: pkix.Name{
					CommonName: "acme.example.com",
				},
				DNSNames: []string{"step.example.com", "acme.example.com"},
			}
			return test{
				o:   o,
				csr: csr,
				err: ServerInternalErr(errors.Errorf("error generating certificate for order %s: force", o.ID)),
				sa: &mockSignAuth{
					err: errors.New("force"),
				},
			}
		},
		"fail/ready/store-cert-error": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusReady

			csr := &x509.CertificateRequest{
				Subject: pkix.Name{
					CommonName: "acme.example.com",
				},
				DNSNames: []string{"step.example.com", "acme.example.com"},
			}
			crt := &x509.Certificate{
				Subject: pkix.Name{
					CommonName: "acme.example.com",
				},
			}
			inter := &x509.Certificate{
				Subject: pkix.Name{
					CommonName: "intermediate",
				},
			}

			return test{
				o:   o,
				csr: csr,
				err: ServerInternalErr(errors.Errorf("error storing certificate: force")),
				sa: &mockSignAuth{
					ret1: crt,
					ret2: inter,
				},
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						return nil, false, errors.New("force")
					},
				},
			}
		},
		"fail/ready/store-order-error": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusReady

			csr := &x509.CertificateRequest{
				Subject: pkix.Name{
					CommonName: "acme.example.com",
				},
				DNSNames: []string{"acme.example.com", "step.example.com"},
			}
			crt := &x509.Certificate{
				Subject: pkix.Name{
					CommonName: "acme.example.com",
				},
			}
			inter := &x509.Certificate{
				Subject: pkix.Name{
					CommonName: "intermediate",
				},
			}

			// The first CmpAndSwap (the certificate write, judging by
			// "fail/ready/store-cert-error") succeeds; the second fails.
			count := 0
			return test{
				o:   o,
				csr: csr,
				err: ServerInternalErr(errors.Errorf("error storing order: force")),
				sa: &mockSignAuth{
					ret1: crt,
					ret2: inter,
				},
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						if count == 1 {
							return nil, false, errors.New("force")
						}
						count++
						return nil, true, nil
					},
				},
			}
		},
		"ok/ready/sign": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusReady

			csr := &x509.CertificateRequest{
				Subject: pkix.Name{
					CommonName: "acme.example.com",
				},
				DNSNames: []string{"acme.example.com", "step.example.com"},
			}
			crt := &x509.Certificate{
				Subject: pkix.Name{
					CommonName: "acme.example.com",
				},
			}
			inter := &x509.Certificate{
				Subject: pkix.Name{
					CommonName: "intermediate",
				},
			}

			_o := *o
			clone := &_o
			clone.Status = StatusValid

			count := 0
			return test{
				o:   o,
				res: clone,
				csr: csr,
				sa: &mockSignAuth{
					sign: func(csr *x509.CertificateRequest, pops provisioner.Options, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) {
						assert.Equals(t, len(signOps), 5)
						return []*x509.Certificate{crt, inter}, nil
					},
				},
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						// The first write's key is expected to end up as the
						// order's Certificate ID.
						if count == 0 {
							clone.Certificate = string(key)
						}
						count++
						return nil, true, nil
					},
				},
			}
		},
		"ok/ready/no-sans": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusReady
			o.Identifiers = []Identifier{
				{Type: "dns", Value: "step.example.com"},
			}

			csr := &x509.CertificateRequest{
				Subject: pkix.Name{
					CommonName: "step.example.com",
				},
			}
			crt := &x509.Certificate{
				Subject: pkix.Name{
					CommonName: "step.example.com",
				},
				DNSNames: []string{"step.example.com"},
			}
			inter := &x509.Certificate{
				Subject: pkix.Name{
					CommonName: "intermediate",
				},
			}

			clone := *o
			clone.Status = StatusValid
			count := 0
			return test{
				o:   o,
				res: &clone,
				csr: csr,
				sa: &mockSignAuth{
					sign: func(csr *x509.CertificateRequest, pops provisioner.Options, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) {
						assert.Equals(t, len(signOps), 5)
						return []*x509.Certificate{crt, inter}, nil
					},
				},
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						if count == 0 {
							clone.Certificate = string(key)
						}
						count++
						return nil, true, nil
					},
				},
			}
		},
		"ok/ready/sans-and-name": func(t *testing.T) test {
			o, err := newO()
			assert.FatalError(t, err)
			o.Status = StatusReady

			csr := &x509.CertificateRequest{
				Subject: pkix.Name{
					CommonName: "acme.example.com",
				},
				DNSNames: []string{"step.example.com"},
			}
			crt := &x509.Certificate{
				Subject: pkix.Name{
					CommonName: "acme.example.com",
				},
				DNSNames: []string{"acme.example.com", "step.example.com"},
			}
			inter := &x509.Certificate{
				Subject: pkix.Name{
					CommonName: "intermediate",
				},
			}

			clone := *o
			clone.Status = StatusValid
			count := 0
			return test{
				o:   o,
				res: &clone,
				csr: csr,
				sa: &mockSignAuth{
					sign: func(csr *x509.CertificateRequest, pops provisioner.Options, signOps ...provisioner.SignOption) ([]*x509.Certificate, error) {
						assert.Equals(t, len(signOps), 5)
						return []*x509.Certificate{crt, inter}, nil
					},
				},
				db: &db.MockNoSQLDB{
					MCmpAndSwap: func(bucket, key, old, newval []byte) ([]byte, bool, error) {
						if count == 0 {
							clone.Certificate = string(key)
						}
						count++
						return nil, true, nil
					},
				},
			}
		},
	}
	for name, run := range tests {
		t.Run(name, func(t *testing.T) {
			tc := run(t)
			// Fall back to the shared provisioner unless the case supplies its own.
			p := tc.prov
			if p == nil {
				p = prov
			}
			o, err := tc.o.finalize(tc.db, tc.csr, tc.sa, p)
			if err != nil {
				if assert.NotNil(t, tc.err) {
					ae, ok := err.(*Error)
					assert.Fatal(t, ok)
					assert.HasPrefix(t, ae.Error(), tc.err.Error())
					assert.Equals(t, ae.StatusCode(), tc.err.StatusCode())
					assert.Equals(t, ae.Type, tc.err.Type)
				}
			} else {
				if assert.Nil(t, tc.err) {
					expB, err := json.Marshal(tc.res)
					assert.FatalError(t, err)
					b, err := json.Marshal(o)
					assert.FatalError(t, err)
					assert.Equals(t, expB, b)
				}
			}
		})
	}
}