diff --git a/authority/authority.go b/authority/authority.go index 8c60fffc..5a0cf1ab 100644 --- a/authority/authority.go +++ b/authority/authority.go @@ -17,7 +17,7 @@ const legacyAuthority = "step-certificate-authority" // Authority implements the Certificate Authority internal interface. type Authority struct { config *Config - rootX509Crt *x509.Certificate + rootX509Certs []*x509.Certificate intermediateIdentity *x509util.Identity validateOnce bool certificates *sync.Map @@ -79,15 +79,19 @@ func (a *Authority) init() error { } var err error - // First load the root using our modified pem/x509 package. - a.rootX509Crt, err = pemutil.ReadCertificate(a.config.Root) - if err != nil { - return err - } - // Add root certificate to the certificate map - sum := sha256.Sum256(a.rootX509Crt.Raw) - a.certificates.Store(hex.EncodeToString(sum[:]), a.rootX509Crt) + // Load the root certificates and add them to the certificate store + a.rootX509Certs = make([]*x509.Certificate, len(a.config.Root)) + for i, path := range a.config.Root { + crt, err := pemutil.ReadCertificate(path) + if err != nil { + return err + } + // Add root certificate to the certificate map + sum := sha256.Sum256(crt.Raw) + a.certificates.Store(hex.EncodeToString(sum[:]), crt) + a.rootX509Certs[i] = crt + } // Add federated roots for _, path := range a.config.FederatedRoots { diff --git a/authority/authority_test.go b/authority/authority_test.go index ad2f4980..1020f808 100644 --- a/authority/authority_test.go +++ b/authority/authority_test.go @@ -38,7 +38,7 @@ func testAuthority(t *testing.T) *Authority { } c := &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.ca.smallstep.com"}, @@ -68,7 +68,7 @@ func TestAuthorityNew(t *testing.T) { "fail bad root": func(t *testing.T) *newTest { c, err := LoadConfiguration("../ca/testdata/ca.json") assert.FatalError(t, err) - c.Root = "foo" + c.Root = []string{"foo"} return &newTest{ config: c, err: errors.New("open foo failed: no such file or directory"), @@ -105,10 +105,10 @@ func TestAuthorityNew(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - sum := sha256.Sum256(auth.rootX509Crt.Raw) + sum := sha256.Sum256(auth.rootX509Certs[0].Raw) root, ok := auth.certificates.Load(hex.EncodeToString(sum[:])) assert.Fatal(t, ok) - assert.Equals(t, auth.rootX509Crt, root) + assert.Equals(t, auth.rootX509Certs[0], root) assert.True(t, auth.initOnce) assert.NotNil(t, auth.intermediateIdentity) diff --git a/authority/config.go b/authority/config.go index 0ffdbef0..f19fb202 100644 --- a/authority/config.go +++ b/authority/config.go @@ -35,7 +35,7 @@ var ( // Config represents the CA configuration and it's mapped to a JSON object. type Config struct { - Root string `json:"root"` + Root multiString `json:"root"` FederatedRoots []string `json:"federatedRoots"` IntermediateCert string `json:"crt"` IntermediateKey string `json:"key"` @@ -117,7 +117,7 @@ func (c *Config) Validate() error { case c.Address == "": return errors.New("address cannot be empty") - case c.Root == "": + case c.Root.Empties(): return errors.New("root cannot be empty") case c.IntermediateCert == "": diff --git a/authority/config_test.go b/authority/config_test.go index c16b4780..01cea2a1 100644 --- a/authority/config_test.go +++ b/authority/config_test.go @@ -40,7 +40,7 @@ func TestConfigValidate(t *testing.T) { "empty-address": func(t *testing.T) ConfigValidateTest { return ConfigValidateTest{ config: &Config{ - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -54,7 +54,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -81,7 +81,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, Password: "pass", @@ -94,7 +94,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", DNSNames: []string{"test.smallstep.com"}, Password: "pass", @@ -107,7 +107,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", Password: "pass", @@ -120,7 +120,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -134,7 +134,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -149,7 +149,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, @@ -178,7 +178,7 @@ func TestConfigValidate(t *testing.T) { return ConfigValidateTest{ config: &Config{ Address: "127.0.0.1:443", - Root: "testdata/secrets/root_ca.crt", + Root: []string{"testdata/secrets/root_ca.crt"}, IntermediateCert: "testdata/secrets/intermediate_ca.crt", IntermediateKey: "testdata/secrets/intermediate_ca_key", DNSNames: []string{"test.smallstep.com"}, diff --git a/authority/root.go b/authority/root.go index 01710db8..d041ae8f 100644 --- a/authority/root.go +++ b/authority/root.go @@ -25,7 +25,12 @@ func (a *Authority) Root(sum string) (*x509.Certificate, error) { // GetRootCertificate returns the server root certificate. func (a *Authority) GetRootCertificate() *x509.Certificate { - return a.rootX509Crt + return a.rootX509Certs[0] +} + +// GetRootCertificates returns the server root certificates. +func (a *Authority) GetRootCertificates() []*x509.Certificate { + return a.rootX509Certs } // GetFederation returns all the root certificates in the federation. diff --git a/authority/root_test.go b/authority/root_test.go index 0db7e866..d9803d8e 100644 --- a/authority/root_test.go +++ b/authority/root_test.go @@ -37,7 +37,7 @@ func TestRoot(t *testing.T) { } } else { if assert.Nil(t, tc.err) { - assert.Equals(t, crt, a.rootX509Crt) + assert.Equals(t, crt, a.rootX509Certs[0]) } } }) diff --git a/authority/types.go b/authority/types.go index ec8f0d7b..50b632d6 100644 --- a/authority/types.go +++ b/authority/types.go @@ -36,11 +36,10 @@ func (d *duration) UnmarshalJSON(data []byte) (err error) { return } +// multiString represents a type that can be encoded/decoded in JSON as a single +// string or an array of strings. type multiString []string -// FIXME: remove me, avoids deadcode warning -var _ = multiString{} - // First returns the first element of a multiString. It will return an empty // string if the multistring is empty. func (s multiString) First() string { @@ -69,20 +68,24 @@ func (s multiString) Empties() bool { func (s multiString) MarshalJSON() ([]byte, error) { switch len(s) { case 0: - return []byte(""), nil + return []byte(`""`), nil case 1: return json.Marshal(s[0]) default: - return json.Marshal(s) + return json.Marshal([]string(s)) } } // UnmarshalJSON parses a string or a slice and sets it to the multiString. func (s *multiString) UnmarshalJSON(data []byte) error { + if s == nil { + return errors.New("multiString cannot be nil") + } if len(data) == 0 { *s = nil return nil } + // Parse string if data[0] == '"' { var str string if err := json.Unmarshal(data, &str); err != nil { @@ -91,8 +94,11 @@ func (s *multiString) UnmarshalJSON(data []byte) error { *s = []string{str} return nil } - if err := json.Unmarshal(data, s); err != nil { + // Parse array + var ss []string + if err := json.Unmarshal(data, &ss); err != nil { return errors.Wrapf(err, "error unmarshalling %s", data) } + *s = ss return nil } diff --git a/authority/types_test.go b/authority/types_test.go new file mode 100644 index 00000000..620751d3 --- /dev/null +++ b/authority/types_test.go @@ -0,0 +1,103 @@ +package authority + +import ( + "reflect" + "testing" +) + +func Test_multiString_First(t *testing.T) { + tests := []struct { + name string + s multiString + want string + }{ + {"empty", multiString{}, ""}, + {"string", multiString{"one"}, "one"}, + {"slice", multiString{"one", "two"}, "one"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.s.First(); got != tt.want { + t.Errorf("multiString.First() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_multiString_Empties(t *testing.T) { + tests := []struct { + name string + s multiString + want bool + }{ + {"empty", multiString{}, true}, + {"string", multiString{"one"}, false}, + {"empty string", multiString{""}, true}, + {"slice", multiString{"one", "two"}, false}, + {"empty slice", multiString{"one", ""}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.s.Empties(); got != tt.want { + t.Errorf("multiString.Empties() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_multiString_MarshalJSON(t *testing.T) { + tests := []struct { + name string + s multiString + want []byte + wantErr bool + }{ + {"empty", []string{}, []byte(`""`), false}, + {"string", []string{"a string"}, []byte(`"a string"`), false}, + {"slice", []string{"string one", "string two"}, []byte(`["string one","string two"]`), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.s.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("multiString.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("multiString.MarshalJSON() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_multiString_UnmarshalJSON(t *testing.T) { + + type args struct { + data []byte + } + tests := []struct { + name string + s *multiString + args args + want *multiString + wantErr bool + }{ + {"empty", new(multiString), args{[]byte{}}, new(multiString), false}, + {"empty string", new(multiString), args{[]byte(`""`)}, &multiString{""}, false}, + {"string", new(multiString), args{[]byte(`"a string"`)}, &multiString{"a string"}, false}, + {"slice", new(multiString), args{[]byte(`["string one","string two"]`)}, &multiString{"string one", "string two"}, false}, + {"error", new(multiString), args{[]byte(`["123",123]`)}, new(multiString), true}, + {"nil", nil, args{nil}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.s.UnmarshalJSON(tt.args.data); (err != nil) != tt.wantErr { + t.Errorf("multiString.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(tt.s, tt.want) { + t.Errorf("multiString.UnmarshalJSON() = %v, want %v", tt.s, tt.want) + } + }) + } +} diff --git a/ca/ca.go b/ca/ca.go index 8f72984f..07ee3311 100644 --- a/ca/ca.go +++ b/ca/ca.go @@ -176,7 +176,9 @@ func (ca *CA) getTLSConfig(auth *authority.Authority) (*tls.Config, error) { } certPool := x509.NewCertPool() - certPool.AddCert(auth.GetRootCertificate()) + for _, crt := range auth.GetRootCertificates() { + certPool.AddCert(crt) + } // GetCertificate will only be called if the client supplies SNI // information or if tlsConfig.Certificates is empty.