diff --git a/common/os.go b/common/os.go index 9dc81c579..b1e778977 100644 --- a/common/os.go +++ b/common/os.go @@ -93,28 +93,30 @@ func MustWriteFile(filePath string, contents []byte, mode os.FileMode) { } } -// Writes to newBytes to filePath. -// Guaranteed not to lose *both* oldBytes and newBytes, -// (assuming that the OS is perfect) +// WriteFileAtomic writes newBytes to temp and atomically moves to filePath +// when everything else succeeds. func WriteFileAtomic(filePath string, newBytes []byte, mode os.FileMode) error { - // If a file already exists there, copy to filePath+".bak" (overwrite anything) - if _, err := os.Stat(filePath); !os.IsNotExist(err) { - fileBytes, err := ioutil.ReadFile(filePath) - if err != nil { - return fmt.Errorf("Could not read file %v. %v", filePath, err) - } - err = ioutil.WriteFile(filePath+".bak", fileBytes, mode) - if err != nil { - return fmt.Errorf("Could not write file %v. %v", filePath+".bak", err) - } + f, err := ioutil.TempFile("", "") + if err != nil { + return err + } + _, err = f.Write(newBytes) + if err == nil { + err = f.Sync() + } + if closeErr := f.Close(); err == nil { + err = closeErr + } + if permErr := os.Chmod(f.Name(), mode); err == nil { + err = permErr + } + if err == nil { + err = os.Rename(f.Name(), filePath) } - // Write newBytes to filePath.new - err := ioutil.WriteFile(filePath+".new", newBytes, mode) + // any err should result in full cleanup if err != nil { - return fmt.Errorf("Could not write file %v. %v", filePath+".new", err) + os.Remove(f.Name()) } - // Move filePath.new to filePath - err = os.Rename(filePath+".new", filePath) return err } diff --git a/common/os_test.go b/common/os_test.go new file mode 100644 index 000000000..05359e36e --- /dev/null +++ b/common/os_test.go @@ -0,0 +1,29 @@ +package common + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "testing" + "time" +) + +func TestWriteFileAtomic(t *testing.T) { + data := []byte("Becatron") + fname := fmt.Sprintf("/tmp/write-file-atomic-test-%v.txt", time.Now().UnixNano()) + err := WriteFileAtomic(fname, data, 0664) + if err != nil { + t.Fatal(err) + } + rData, err := ioutil.ReadFile(fname) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(data, rData) { + t.Fatalf("data mismatch: %v != %v", data, rData) + } + if err := os.Remove(fname); err != nil { + t.Fatal(err) + } +}