diff --git a/hostfile.go b/hostfile.go index 04ca657..0efad13 100644 --- a/hostfile.go +++ b/hostfile.go @@ -4,8 +4,10 @@ import ( "fmt" "io/ioutil" "os" + "path/filepath" "runtime" "strings" + "time" ) const defaultOSX = ` @@ -190,15 +192,35 @@ func (h *Hostfile) Format() []byte { } // Save writes the Hostfile to disk to /etc/hosts or to the location specified -// by the HOSTESS_PATH environment variable (if set). +// by the HOSTESS_PATH environment variable (if set). We'll try to write a +// temporary file and then move it over the old file. This gives us an atomic +// update and ensures that if there is an I/O error or bug we dont truncate the +// hosts file and leave it empty. func (h *Hostfile) Save() error { - file, err := os.OpenFile(h.Path, os.O_RDWR|os.O_APPEND|os.O_TRUNC, 0644) + // TODO replace this with ioutil.TempFile + tempPath := filepath.Join(os.TempDir(), fmt.Sprintf("hostess.tmp.%d", time.Now().Unix())) + tempFile, err := os.OpenFile(tempPath, os.O_RDWR|os.O_CREATE, 0644) if err != nil { return err } - defer file.Close() - _, err = file.Write(h.Format()) + // Write a temp file + if _, err = tempFile.Write(h.Format()); err != nil { + tempFile.Close() + os.Remove(tempPath) + return err + } - return err + // Flush write buffers and close the file + if err := tempFile.Close(); err != nil { + os.Remove(tempPath) // cleanup code + return err + } + + // Atomic update + if err := os.Rename(tempPath, h.Path); err != nil { + return err + } + + return nil } diff --git a/hostfile_test.go b/hostfile_test.go index e144334..876b65c 100644 --- a/hostfile_test.go +++ b/hostfile_test.go @@ -2,9 +2,13 @@ package hostess_test import ( "fmt" + "io/ioutil" + "os" + "path/filepath" "runtime" "strings" "testing" + "time" "github.com/cbednarski/hostess" ) @@ -200,3 +204,31 @@ func TestLoadHostfile(t *testing.T) { t.Errorf("Expected to find %#v", hostname) } } + +// We're going to test saving the hosts file to a temporary path. In order to +// verify atomic save behavior we'll first write a fixture file and then save +// over it. +func TestSave(t *testing.T) { + // TODO replace this with ioutil.TempFile + tempPath := filepath.Join(os.TempDir(), fmt.Sprintf("hostess.test.%d", time.Now().Unix())) + fixturePath := filepath.Join("test-fixtures", "hostfile1") + + data, err := ioutil.ReadFile(fixturePath) + if err != nil { + t.Fatalf("Failed reading fixture file: %s", err) + } + + if err := ioutil.WriteFile(tempPath, data, 0644); err != nil { + t.Fatalf("Failed writing temporary hosts file: %s", err) + } + defer os.Remove(tempPath) + + if err := os.Setenv("HOSTESS_PATH", tempPath); err != nil { + t.Fatalf("Failed to set HOSTESS_PATH to %q: %s", tempPath, err) + } + hostfile, _ := hostess.LoadHostfile() + t.Log(hostfile.Path) + if err := hostfile.Save(); err != nil { + t.Fatalf("Failed saving hosts file %q: %s", tempPath, err) + } +} \ No newline at end of file