diff --git a/commands.go b/commands.go index ec98535..b2c388e 100644 --- a/commands.go +++ b/commands.go @@ -51,7 +51,7 @@ func SaveOrPreview(options *Options, hostfile *hostess.Hostfile) error { } if err := hostfile.Save(); err != nil { - return fmt.Errorf("Unable to write to %s. Maybe you need to sudo? (error: %s)", hostess.GetHostsPath(), err) + return fmt.Errorf("Unable to write to %s. (error: %s)", hostess.GetHostsPath(), err) } return nil diff --git a/hostess/hostfile.go b/hostess/hostfile.go index d423390..281b338 100644 --- a/hostess/hostfile.go +++ b/hostess/hostfile.go @@ -196,7 +196,26 @@ func (h *Hostfile) Format() []byte { // Save writes the Hostfile to disk to /etc/hosts or to the location specified // by the HOSTESS_PATH environment variable (if set). func (h *Hostfile) Save() error { - file, err := os.OpenFile(h.Path, os.O_RDWR|os.O_APPEND|os.O_TRUNC, 0644) + var file *os.File + var err error + + // TODO break platform-specific code into separate functions + // Windows wants the file to be truncated before it's opened. Then we re- + // write the entire file contents. Truncating up front is risky but I don't + // know of a better way to do it. + if runtime.GOOS == "windows" { + if err := os.Truncate(h.Path, 0); err != nil { + return err + } + + file, err = os.OpenFile(h.Path, os.O_RDWR, 0644) + } else { + // TODO use atomic write-and-rename on Unix + // I think an earlier version of the program did this but it did not + // work on Windows so it was rolled back. We can probably get that code + // from history. + file, err = os.OpenFile(h.Path, os.O_RDWR|os.O_APPEND|os.O_TRUNC, 0644) + } if err != nil { return err } diff --git a/hostess/hostfile_test.go b/hostess/hostfile_test.go index 667343e..21244dc 100644 --- a/hostess/hostfile_test.go +++ b/hostess/hostfile_test.go @@ -2,11 +2,15 @@ package hostess_test import ( "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" "runtime" "strings" "testing" - hostess2 "github.com/cbednarski/hostess/hostess" + "github.com/cbednarski/hostess/hostess" ) const ipv4Pass = ` @@ -40,7 +44,7 @@ func Diff(expected, actual string) string { } func TestGetHostsPath(t *testing.T) { - path := hostess2.GetHostsPath() + path := hostess.GetHostsPath() var expected string if runtime.GOOS == "windows" { expected = "C:\\Windows\\System32\\drivers\\etc\\hosts" @@ -63,14 +67,14 @@ func TestFormatHostfile(t *testing.T) { 10.37.12.18 devsite.com m.devsite.com ` - hostfile := hostess2.NewHostfile() + hostfile := hostess.NewHostfile() hostfile.Path = "./hosts" - hostfile.Hosts.Add(hostess2.MustHostname("localhost", "127.0.0.1", true)) - hostfile.Hosts.Add(hostess2.MustHostname("ip-10-37-12-18", "127.0.1.1", true)) - hostfile.Hosts.Add(hostess2.MustHostname("devsite", "127.0.0.1", true)) - hostfile.Hosts.Add(hostess2.MustHostname("google.com", "8.8.8.8", false)) - hostfile.Hosts.Add(hostess2.MustHostname("devsite.com", "10.37.12.18", true)) - hostfile.Hosts.Add(hostess2.MustHostname("m.devsite.com", "10.37.12.18", true)) + hostfile.Hosts.Add(hostess.MustHostname("localhost", "127.0.0.1", true)) + hostfile.Hosts.Add(hostess.MustHostname("ip-10-37-12-18", "127.0.1.1", true)) + hostfile.Hosts.Add(hostess.MustHostname("devsite", "127.0.0.1", true)) + hostfile.Hosts.Add(hostess.MustHostname("google.com", "8.8.8.8", false)) + hostfile.Hosts.Add(hostess.MustHostname("devsite.com", "10.37.12.18", true)) + hostfile.Hosts.Add(hostess.MustHostname("m.devsite.com", "10.37.12.18", true)) f := string(hostfile.Format()) if f != expected { t.Errorf("Hostfile output is not formatted correctly: %s", Diff(expected, f)) @@ -81,7 +85,7 @@ func TestTrimWS(t *testing.T) { const expected = ` candy ` - actual := hostess2.TrimWS(expected) + actual := hostess.TrimWS(expected) if actual != "candy" { t.Errorf("Output was not trimmed correctly: %s", Diff(expected, actual)) } @@ -89,7 +93,7 @@ func TestTrimWS(t *testing.T) { func TestParseLineBlank(t *testing.T) { // Blank line - hosts, err := hostess2.ParseLine("") + hosts, err := hostess.ParseLine("") expected := "line is blank" if err.Error() != expected { t.Errorf("Expected error %q; found %q", expected, err.Error()) @@ -101,7 +105,7 @@ func TestParseLineBlank(t *testing.T) { func TestParseLineComment(t *testing.T) { // Comment - hosts, err := hostess2.ParseLine("# The following lines are desirable for IPv6 capable hosts") + hosts, err := hostess.ParseLine("# The following lines are desirable for IPv6 capable hosts") if err == nil { t.Error(err) } @@ -112,7 +116,7 @@ func TestParseLineComment(t *testing.T) { func TestParseLineOneWordComment(t *testing.T) { // Single word comment - hosts, err := hostess2.ParseLine("#blah") + hosts, err := hostess.ParseLine("#blah") if err != nil { t.Error(err) } @@ -122,23 +126,23 @@ func TestParseLineOneWordComment(t *testing.T) { } func TestParseLineBasicHostnameComment(t *testing.T) { - hosts, err := hostess2.ParseLine("#66.33.99.11 test.domain.com") + hosts, err := hostess.ParseLine("#66.33.99.11 test.domain.com") if err != nil { t.Error(err) } - if !hosts.Contains(hostess2.MustHostname("test.domain.com", "66.33.99.11", false)) || + if !hosts.Contains(hostess.MustHostname("test.domain.com", "66.33.99.11", false)) || len(hosts) != 1 { t.Error("Expected to find test.domain.com (disabled)") } } func TestParseLineMultiHostnameComment(t *testing.T) { - hosts, err := hostess2.ParseLine("# 66.33.99.11 test.domain.com domain.com") + hosts, err := hostess.ParseLine("# 66.33.99.11 test.domain.com domain.com") if err != nil { t.Error(err) } - if !hosts.Contains(hostess2.MustHostname("test.domain.com", "66.33.99.11", false)) || - !hosts.Contains(hostess2.MustHostname("domain.com", "66.33.99.11", false)) || + if !hosts.Contains(hostess.MustHostname("test.domain.com", "66.33.99.11", false)) || + !hosts.Contains(hostess.MustHostname("domain.com", "66.33.99.11", false)) || len(hosts) != 2 { t.Error("Expected to find domain.com and test.domain.com (disabled)") t.Errorf("Found %+v", hosts) @@ -147,13 +151,13 @@ func TestParseLineMultiHostnameComment(t *testing.T) { func TestParseLineMultiHostname(t *testing.T) { // Not Commented stuff - hosts, err := hostess2.ParseLine("255.255.255.255 broadcasthost test.domain.com domain.com") + hosts, err := hostess.ParseLine("255.255.255.255 broadcasthost test.domain.com domain.com") if err != nil { t.Error(err) } - if !hosts.Contains(hostess2.MustHostname("broadcasthost", "255.255.255.255", true)) || - !hosts.Contains(hostess2.MustHostname("test.domain.com", "255.255.255.255", true)) || - !hosts.Contains(hostess2.MustHostname("domain.com", "255.255.255.255", true)) || + if !hosts.Contains(hostess.MustHostname("broadcasthost", "255.255.255.255", true)) || + !hosts.Contains(hostess.MustHostname("test.domain.com", "255.255.255.255", true)) || + !hosts.Contains(hostess.MustHostname("domain.com", "255.255.255.255", true)) || len(hosts) != 3 { t.Error("Expected to find broadcasthost, domain.com, and test.domain.com (enabled)") } @@ -161,29 +165,29 @@ func TestParseLineMultiHostname(t *testing.T) { func TestParseLineIPv6A(t *testing.T) { // Ipv6 stuff - hosts, err := hostess2.ParseLine("::1 localhost") + hosts, err := hostess.ParseLine("::1 localhost") if err != nil { t.Error(err) } - if !hosts.Contains(hostess2.MustHostname("localhost", "::1", true)) || + if !hosts.Contains(hostess.MustHostname("localhost", "::1", true)) || len(hosts) != 1 { t.Error("Expected to find localhost ipv6 (enabled)") } } func TestParseLineIPv6B(t *testing.T) { - hosts, err := hostess2.ParseLine("ff02::1 ip6-allnodes") + hosts, err := hostess.ParseLine("ff02::1 ip6-allnodes") if err != nil { t.Error(err) } - if !hosts.Contains(hostess2.MustHostname("ip6-allnodes", "ff02::1", true)) || + if !hosts.Contains(hostess.MustHostname("ip6-allnodes", "ff02::1", true)) || len(hosts) != 1 { t.Error("Expected to find ip6-allnodes ipv6 (enabled)") } } func TestLoadHostfile(t *testing.T) { - hostfile := hostess2.NewHostfile() + hostfile := hostess.NewHostfile() hostfile.Read() if !strings.Contains(string(hostfile.GetData()), domain) { t.Errorf("Expected to find %s", domain) @@ -194,9 +198,36 @@ func TestLoadHostfile(t *testing.T) { if runtime.GOOS == "windows" { on = false } - hostname := hostess2.MustHostname(domain, ip, on) + hostname := hostess.MustHostname(domain, ip, on) found := hostfile.Hosts.Contains(hostname) if !found { t.Errorf("Expected to find %#v", hostname) } } + +func TestSaveHostfile(t *testing.T) { + fixture, err := os.Open(filepath.Join("testdata", "hostfile1")) + if err != nil { + t.Fatal(err) + } + + tempfile, err := ioutil.TempFile("", "hostess-test-*") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tempfile.Name()) + + if _, err := io.Copy(tempfile, fixture); err != nil { + t.Fatal(err) + } + + hostfile := hostess.NewHostfile() + hostfile.Path = tempfile.Name() + if err := hostfile.Read(); err != nil { + t.Fatal(err) + } + + if err := hostfile.Save(); err != nil { + t.Fatal(err) + } +}