diff --git a/authority/ssh.go b/authority/ssh.go index d5d6ce45..1b1645e4 100644 --- a/authority/ssh.go +++ b/authority/ssh.go @@ -102,10 +102,14 @@ func (a *Authority) GetSSHConfig(ctx context.Context, typ string, data map[strin return nil, err } - // Backwards compatibility for version of the cli older than v0.18.0 - if o.Name == "step_includes.tpl" && (data == nil || data[templates.SSHTemplateVersionKey] != "v2") { - o.Type = templates.File - o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/") + // Backwards compatibility for version of the cli older than v0.18.0. + // Before v0.18.0 we were not passing any value for SSHTemplateVersionKey + // from the cli. + if o.Name == "step_includes.tpl" { + if val, ok := data[templates.SSHTemplateVersionKey]; !ok || val == "" { + o.Type = templates.File + o.Path = strings.TrimPrefix(o.Path, "${STEPPATH}/") + } } output = append(output, o) diff --git a/authority/ssh_test.go b/authority/ssh_test.go index 41df8576..994d015f 100644 --- a/authority/ssh_test.go +++ b/authority/ssh_test.go @@ -501,6 +501,32 @@ func TestAuthority_GetSSHConfig(t *testing.T) { {Name: "sshd_config.tpl", Type: templates.File, Comment: "#", Path: "/etc/ssh/sshd_config", Content: []byte("Match all\n\tTrustedUserCAKeys /etc/ssh/ca.pub\n\tHostCertificate /etc/ssh/ssh_host_ecdsa_key-cert.pub\n\tHostKey /etc/ssh/ssh_host_ecdsa_key")}, } + tmplConfigUserIncludes := &templates.Templates{ + SSH: &templates.SSHTemplates{ + User: []templates.Template{ + {Name: "step_includes.tpl", Type: templates.PrependLine, TemplatePath: "./testdata/templates/step_includes.tpl", Path: "${STEPPATH}/ssh/includes", Comment: "#"}, + }, + }, + Data: map[string]interface{}{ + "Step": &templates.Step{ + SSH: templates.StepSSH{ + UserKey: user, + HostKey: host, + }, + }, + }, + } + + userOutputEmptyData := []templates.Output{ + {Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"/ssh/config\"\n")}, + } + userOutputWithoutTemplateVersion := []templates.Output{ + {Name: "step_includes.tpl", Type: templates.File, Comment: "#", Path: "ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")}, + } + userOutputWithTemplateVersion := []templates.Output{ + {Name: "step_includes.tpl", Type: templates.PrependLine, Comment: "#", Path: "${STEPPATH}/ssh/includes", Content: []byte("Include \"/home/user/.step/ssh/config\"\n")}, + } + tmplConfigErr := &templates.Templates{ SSH: &templates.SSHTemplates{ User: []templates.Template{ @@ -542,6 +568,9 @@ func TestAuthority_GetSSHConfig(t *testing.T) { {"host", fields{tmplConfig, nil, hostSigner}, args{"host", nil}, hostOutput, false}, {"userWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithUserData, false}, {"hostWithData", fields{tmplConfigWithUserData, userSigner, hostSigner}, args{"host", map[string]string{"Certificate": "ssh_host_ecdsa_key-cert.pub", "Key": "ssh_host_ecdsa_key"}}, hostOutputWithUserData, false}, + {"userIncludesEmptyData", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", nil}, userOutputEmptyData, false}, + {"userIncludesWithoutTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step"}}, userOutputWithoutTemplateVersion, false}, + {"userIncludesWithTemplateVersion", fields{tmplConfigUserIncludes, userSigner, hostSigner}, args{"user", map[string]string{"StepPath": "/home/user/.step", "StepSSHTemplateVersion": "v2"}}, userOutputWithTemplateVersion, false}, {"disabled", fields{tmplConfig, nil, nil}, args{"host", nil}, nil, true}, {"badType", fields{tmplConfig, userSigner, hostSigner}, args{"bad", nil}, nil, true}, {"userError", fields{tmplConfigErr, userSigner, hostSigner}, args{"user", nil}, nil, true}, diff --git a/authority/testdata/templates/step_includes.tpl b/authority/testdata/templates/step_includes.tpl new file mode 100644 index 00000000..8c481bd8 --- /dev/null +++ b/authority/testdata/templates/step_includes.tpl @@ -0,0 +1 @@ +{{- if or .User.GOOS "none" | eq "windows" }}Include "{{ .User.StepPath | replace "\\" "/" | trimPrefix "C:" }}/ssh/config"{{- else }}Include "{{.User.StepPath}}/ssh/config"{{- end }} diff --git a/ca/identity/client.go b/ca/identity/client.go index 6f862115..97389b2d 100644 --- a/ca/identity/client.go +++ b/ca/identity/client.go @@ -10,7 +10,6 @@ import ( "net/url" "github.com/pkg/errors" - "go.step.sm/cli-utils/step" ) // Client wraps http.Client with a transport using the step root and identity. @@ -28,7 +27,7 @@ func (c *Client) ResolveReference(ref *url.URL) *url.URL { // $STEPPATH/config/defaults.json and the identity defined in // $STEPPATH/config/identity.json func LoadClient() (*Client, error) { - defaultsFile := step.DefaultsFile() + defaultsFile := DefaultsFile() b, err := ioutil.ReadFile(defaultsFile) if err != nil { return nil, errors.Wrapf(err, "error reading %s", defaultsFile) @@ -54,7 +53,7 @@ func LoadClient() (*Client, error) { return nil, err } if err := identity.Validate(); err != nil { - return nil, errors.Wrapf(err, "error validating %s", step.IdentityFile()) + return nil, errors.Wrapf(err, "error validating %s", IdentityFile()) } if kind := identity.Kind(); kind != MutualTLS { return nil, errors.Errorf("unsupported identity %s: only mTLS is currently supported", kind) diff --git a/ca/identity/client_test.go b/ca/identity/client_test.go index 402ec7b8..40a35766 100644 --- a/ca/identity/client_test.go +++ b/ca/identity/client_test.go @@ -11,6 +11,12 @@ import ( "testing" ) +func returnInput(val string) func() string { + return func() string { + return val + } +} + func TestClient(t *testing.T) { oldIdentityFile := IdentityFile oldDefaultsFile := DefaultsFile @@ -19,8 +25,8 @@ func TestClient(t *testing.T) { DefaultsFile = oldDefaultsFile }() - IdentityFile = "testdata/config/identity.json" - DefaultsFile = "testdata/config/defaults.json" + IdentityFile = returnInput("testdata/config/identity.json") + DefaultsFile = returnInput("testdata/config/defaults.json") client, err := LoadClient() if err != nil { @@ -140,36 +146,36 @@ func TestLoadClient(t *testing.T) { wantErr bool }{ {"ok", func() { - IdentityFile = "testdata/config/identity.json" - DefaultsFile = "testdata/config/defaults.json" + IdentityFile = returnInput("testdata/config/identity.json") + DefaultsFile = returnInput("testdata/config/defaults.json") }, expected, false}, {"fail identity", func() { - IdentityFile = "testdata/config/missing.json" - DefaultsFile = "testdata/config/defaults.json" + IdentityFile = returnInput("testdata/config/missing.json") + DefaultsFile = returnInput("testdata/config/defaults.json") }, nil, true}, {"fail identity", func() { - IdentityFile = "testdata/config/fail.json" - DefaultsFile = "testdata/config/defaults.json" + IdentityFile = returnInput("testdata/config/fail.json") + DefaultsFile = returnInput("testdata/config/defaults.json") }, nil, true}, {"fail defaults", func() { - IdentityFile = "testdata/config/identity.json" - DefaultsFile = "testdata/config/missing.json" + IdentityFile = returnInput("testdata/config/identity.json") + DefaultsFile = returnInput("testdata/config/missing.json") }, nil, true}, {"fail defaults", func() { - IdentityFile = "testdata/config/identity.json" - DefaultsFile = "testdata/config/fail.json" + IdentityFile = returnInput("testdata/config/identity.json") + DefaultsFile = returnInput("testdata/config/fail.json") }, nil, true}, {"fail ca", func() { - IdentityFile = "testdata/config/identity.json" - DefaultsFile = "testdata/config/badca.json" + IdentityFile = returnInput("testdata/config/identity.json") + DefaultsFile = returnInput("testdata/config/badca.json") }, nil, true}, {"fail root", func() { - IdentityFile = "testdata/config/identity.json" - DefaultsFile = "testdata/config/badroot.json" + IdentityFile = returnInput("testdata/config/identity.json") + DefaultsFile = returnInput("testdata/config/badroot.json") }, nil, true}, {"fail type", func() { - IdentityFile = "testdata/config/badIdentity.json" - DefaultsFile = "testdata/config/defaults.json" + IdentityFile = returnInput("testdata/config/badIdentity.json") + DefaultsFile = returnInput("testdata/config/defaults.json") }, nil, true}, } for _, tt := range tests { diff --git a/ca/identity/identity.go b/ca/identity/identity.go index c294d982..c9bf765d 100644 --- a/ca/identity/identity.go +++ b/ca/identity/identity.go @@ -39,6 +39,19 @@ const TunnelTLS Type = "tTLS" // DefaultLeeway is the duration for matching not before claims. const DefaultLeeway = 1 * time.Minute +var ( + identityDir = step.IdentityPath + configDir = step.ConfigPath + + // IdentityFile contains a pointer to a function that outputs the location of + // the identity file. + IdentityFile = step.IdentityFile + + // DefaultsFile contains a prointer a function that outputs the location of the + // defaults configuration file. + DefaultsFile = step.DefaultsFile +) + // Identity represents the identity file that can be used to authenticate with // the CA. type Identity struct { @@ -68,25 +81,17 @@ func LoadIdentity(filename string) (*Identity, error) { // LoadDefaultIdentity loads the default identity. func LoadDefaultIdentity() (*Identity, error) { - return LoadIdentity(step.IdentityFile()) -} - -func profileConfigDir() string { - return filepath.Join(step.Path(), "config") -} - -func profileIdentityDir() string { - return filepath.Join(step.Path(), "identity") + return LoadIdentity(IdentityFile()) } // WriteDefaultIdentity writes the given certificates and key and the // identity.json pointing to the new files. func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) error { - if err := os.MkdirAll(profileConfigDir(), 0700); err != nil { + if err := os.MkdirAll(configDir(), 0700); err != nil { return errors.Wrap(err, "error creating config directory") } - identityDir := profileIdentityDir() + identityDir := identityDir() if err := os.MkdirAll(identityDir, 0700); err != nil { return errors.Wrap(err, "error creating identity directory") } @@ -123,7 +128,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er }); err != nil { return errors.Wrap(err, "error writing identity json") } - if err := ioutil.WriteFile(step.IdentityFile(), buf.Bytes(), 0600); err != nil { + if err := ioutil.WriteFile(IdentityFile(), buf.Bytes(), 0600); err != nil { return errors.Wrap(err, "error writing identity certificate") } @@ -132,7 +137,7 @@ func WriteDefaultIdentity(certChain []api.Certificate, key crypto.PrivateKey) er // WriteIdentityCertificate writes the identity certificate to disk. func WriteIdentityCertificate(certChain []api.Certificate) error { - filename := filepath.Join(profileIdentityDir(), "identity.crt") + filename := filepath.Join(identityDir(), "identity.crt") return writeCertificate(filename, certChain) } @@ -315,7 +320,7 @@ func (i *Identity) Renew(client Renewer) error { return errors.Wrap(err, "error encoding identity certificate") } } - certFilename := filepath.Join(profileIdentityDir(), "identity.crt") + certFilename := filepath.Join(identityDir(), "identity.crt") if err := ioutil.WriteFile(certFilename, buf.Bytes(), 0600); err != nil { return errors.Wrap(err, "error writing identity certificate") } diff --git a/ca/identity/identity_test.go b/ca/identity/identity_test.go index ce64768c..d3b1d541 100644 --- a/ca/identity/identity_test.go +++ b/ca/identity/identity_test.go @@ -33,9 +33,9 @@ func TestLoadDefaultIdentity(t *testing.T) { want *Identity wantErr bool }{ - {"ok", func() { IdentityFile = "testdata/config/identity.json" }, expected, false}, - {"fail read", func() { IdentityFile = "testdata/config/missing.json" }, nil, true}, - {"fail unmarshal", func() { IdentityFile = "testdata/config/fail.json" }, nil, true}, + {"ok", func() { IdentityFile = returnInput("testdata/config/identity.json") }, expected, false}, + {"fail read", func() { IdentityFile = returnInput("testdata/config/missing.json") }, nil, true}, + {"fail unmarshal", func() { IdentityFile = returnInput("testdata/config/fail.json") }, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -217,9 +217,9 @@ func TestWriteDefaultIdentity(t *testing.T) { certChain = append(certChain, api.Certificate{Certificate: c}) } - configDir = filepath.Join(tmpDir, "config") - identityDir = filepath.Join(tmpDir, "identity") - IdentityFile = filepath.Join(tmpDir, "config", "identity.json") + configDir = returnInput(filepath.Join(tmpDir, "config")) + identityDir = returnInput(filepath.Join(tmpDir, "identity")) + IdentityFile = returnInput(filepath.Join(tmpDir, "config", "identity.json")) type args struct { certChain []api.Certificate @@ -233,27 +233,27 @@ func TestWriteDefaultIdentity(t *testing.T) { }{ {"ok", func() {}, args{certChain, key}, false}, {"fail mkdir config", func() { - configDir = filepath.Join(tmpDir, "identity", "identity.crt") - identityDir = filepath.Join(tmpDir, "identity") + configDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt")) + identityDir = returnInput(filepath.Join(tmpDir, "identity")) }, args{certChain, key}, true}, {"fail mkdir identity", func() { - configDir = filepath.Join(tmpDir, "config") - identityDir = filepath.Join(tmpDir, "identity", "identity.crt") + configDir = returnInput(filepath.Join(tmpDir, "config")) + identityDir = returnInput(filepath.Join(tmpDir, "identity", "identity.crt")) }, args{certChain, key}, true}, {"fail certificate", func() { - configDir = filepath.Join(tmpDir, "config") - identityDir = filepath.Join(tmpDir, "bad-dir") - os.MkdirAll(identityDir, 0600) + configDir = returnInput(filepath.Join(tmpDir, "config")) + identityDir = returnInput(filepath.Join(tmpDir, "bad-dir")) + os.MkdirAll(identityDir(), 0600) }, args{certChain, key}, true}, {"fail key", func() { - configDir = filepath.Join(tmpDir, "config") - identityDir = filepath.Join(tmpDir, "identity") + configDir = returnInput(filepath.Join(tmpDir, "config")) + identityDir = returnInput(filepath.Join(tmpDir, "identity")) }, args{certChain, "badKey"}, true}, {"fail write identity", func() { - configDir = filepath.Join(tmpDir, "bad-dir") - identityDir = filepath.Join(tmpDir, "identity") - IdentityFile = filepath.Join(configDir, "identity.json") - os.MkdirAll(configDir, 0600) + configDir = returnInput(filepath.Join(tmpDir, "bad-dir")) + identityDir = returnInput(filepath.Join(tmpDir, "identity")) + IdentityFile = returnInput(filepath.Join(configDir(), "identity.json")) + os.MkdirAll(configDir(), 0600) }, args{certChain, key}, true}, } @@ -377,7 +377,7 @@ func TestIdentity_Renew(t *testing.T) { } oldIdentityDir := identityDir - identityDir = "testdata/identity" + identityDir = returnInput("testdata/identity") defer func() { identityDir = oldIdentityDir os.RemoveAll(tmpDir) @@ -432,8 +432,8 @@ func TestIdentity_Renew(t *testing.T) { {"fail renew", func() {}, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{fail}, true}, {"fail certificate", func() {}, fields{"mTLS", "testdata/certs/server.crt", "testdata/identity/identity_key"}, args{ok}, true}, {"fail write identity", func() { - identityDir = filepath.Join(tmpDir, "bad-dir") - os.MkdirAll(identityDir, 0600) + identityDir = returnInput(filepath.Join(tmpDir, "bad-dir")) + os.MkdirAll(identityDir(), 0600) }, fields{"mTLS", "testdata/identity/identity.crt", "testdata/identity/identity_key"}, args{ok}, true}, } for _, tt := range tests {