diff --git a/config/envoyconfig/filemgr/config.go b/config/envoyconfig/filemgr/config.go index 802d9046f..cd6d13994 100644 --- a/config/envoyconfig/filemgr/config.go +++ b/config/envoyconfig/filemgr/config.go @@ -1,10 +1,9 @@ package filemgr import ( - "os" "path/filepath" - "github.com/google/uuid" + "github.com/pomerium/pomerium/internal/fileutil" ) type config struct { @@ -23,11 +22,7 @@ func WithCacheDir(cacheDir string) Option { func newConfig(options ...Option) *config { cfg := new(config) - cacheDir, err := os.UserCacheDir() - if err != nil { - cacheDir = filepath.Join(os.TempDir(), uuid.New().String()) - } - WithCacheDir(filepath.Join(cacheDir, "pomerium", "envoy", "files"))(cfg) + WithCacheDir(filepath.Join(fileutil.CacheDir(), "envoy", "files"))(cfg) for _, o := range options { o(cfg) } diff --git a/config/envoyconfig/filemgr/filemgr.go b/config/envoyconfig/filemgr/filemgr.go index cd1ccdbb1..95d8162bf 100644 --- a/config/envoyconfig/filemgr/filemgr.go +++ b/config/envoyconfig/filemgr/filemgr.go @@ -8,6 +8,7 @@ import ( envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + "github.com/pomerium/pomerium/internal/fileutil" "github.com/pomerium/pomerium/internal/log" ) @@ -45,7 +46,7 @@ func (mgr *Manager) BytesDataSource(fileName string, data []byte) *envoy_config_ filePath := filepath.Join(mgr.cfg.cacheDir, fileName) if _, err := os.Stat(filePath); os.IsNotExist(err) { - err = os.WriteFile(filePath, data, 0o600) + err = fileutil.WriteFileAtomically(filePath, data, 0o600) if err != nil { log.Error().Err(err).Msg("filemgr: error writing cache file, falling back to inline bytes") return inlineBytes(data) diff --git a/config/options.go b/config/options.go index c708b9c1d..cbf2ff2d8 100644 --- a/config/options.go +++ b/config/options.go @@ -13,7 +13,6 @@ import ( "net/http" "net/url" "os" - "path/filepath" "reflect" "strings" "time" @@ -29,6 +28,7 @@ import ( "github.com/pomerium/csrf" "github.com/pomerium/pomerium/config/otelconfig" "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/internal/fileutil" "github.com/pomerium/pomerium/internal/hashutil" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" @@ -302,7 +302,7 @@ var defaultOptions = Options{ AuthenticateCallbackPath: "/oauth2/callback", AutocertOptions: AutocertOptions{ - Folder: dataDir(), + Folder: fileutil.DataDir(), }, DataBrokerStorageType: "memory", SkipXffAppend: false, @@ -1814,18 +1814,6 @@ func valueOrFromFileBase64(value string, valueFile string) *string { return &encoded } -func dataDir() string { - homeDir, _ := os.UserHomeDir() - if homeDir == "" { - homeDir = "." - } - baseDir := filepath.Join(homeDir, ".local", "share") - if xdgData := os.Getenv("XDG_DATA_HOME"); xdgData != "" { - baseDir = xdgData - } - return filepath.Join(baseDir, "pomerium") -} - func compareByteSliceSlice(a, b [][]byte) int { sz := min(len(a), len(b)) for i := 0; i < sz; i++ { diff --git a/internal/fileutil/atomic.go b/internal/fileutil/atomic.go new file mode 100644 index 000000000..c515f440e --- /dev/null +++ b/internal/fileutil/atomic.go @@ -0,0 +1,58 @@ +package fileutil + +import ( + "os" + "path/filepath" +) + +// WriteFileAtomically writes to a file path atomically. It does this by creating a temporary +// file in the same directory and then renaming it. If anything goes wrong the temporary +// file is deleted. +func WriteFileAtomically(filePath string, data []byte, mode os.FileMode) error { + f, err := os.CreateTemp(filepath.Dir(filePath), filepath.Base(filePath)+".tmp") + if err != nil { + return err + } + tmpPath := f.Name() + + err = writeFileAndClose(f, data, mode) + if err != nil { + _ = os.Remove(tmpPath) + return err + } + + err = os.Rename(tmpPath, filePath) + if err != nil { + _ = os.Remove(tmpPath) + return err + } + + return nil +} + +func writeFileAndClose(f *os.File, data []byte, mode os.FileMode) error { + _, err := f.Write(data) + if err != nil { + _ = f.Close() + return err + } + + err = f.Sync() + if err != nil { + _ = f.Close() + return err + } + + err = f.Chmod(mode) + if err != nil { + _ = f.Close() + return err + } + + err = f.Close() + if err != nil { + return err + } + + return nil +} diff --git a/internal/fileutil/atomic_test.go b/internal/fileutil/atomic_test.go new file mode 100644 index 000000000..b25bb07ed --- /dev/null +++ b/internal/fileutil/atomic_test.go @@ -0,0 +1,30 @@ +package fileutil_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/internal/fileutil" +) + +func TestWriteFileAtomically(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + + require.NoError(t, fileutil.WriteFileAtomically(filepath.Join(dir, "temp1.txt"), []byte("TEST"), 0o600)) + + entries, err := os.ReadDir(dir) + require.NoError(t, err) + + names := make([]string, len(entries)) + for i := range entries { + names[i] = entries[i].Name() + } + + assert.Equal(t, []string{"temp1.txt"}, names) +} diff --git a/internal/fileutil/directories.go b/internal/fileutil/directories.go new file mode 100644 index 000000000..31985482a --- /dev/null +++ b/internal/fileutil/directories.go @@ -0,0 +1,36 @@ +package fileutil + +import ( + "os" + "path/filepath" + + "github.com/rs/zerolog/log" +) + +// CacheDir returns $XDG_CACHE_HOME/pomerium, or $HOME/.cache/pomerium, or /tmp/pomerium/cache +func CacheDir() string { + dir, err := os.UserCacheDir() + if err == nil { + dir = filepath.Join(dir, "pomerium") + } else { + dir = filepath.Join(os.TempDir(), "pomerium", "cache") + log.Error().Msgf("user cache directory not set, defaulting to %s", dir) + } + return dir +} + +// DataDir returns $XDG_DATA_HOME/pomerium, or $HOME/.local/share/pomerium, or /tmp/pomerium/data +func DataDir() string { + dir := os.Getenv("XDG_DATA_HOME") + if dir != "" { + dir = filepath.Join(dir, "pomerium") + } else { + if home, err := os.UserHomeDir(); err == nil { + dir = filepath.Join(home, ".local", "share", "pomerium") + } else { + dir = filepath.Join(os.TempDir(), "pomerium", "data") + } + log.Error().Msgf("user data directory not set, defaulting to %s", dir) + } + return dir +} diff --git a/internal/zero/cmd/env.go b/internal/zero/cmd/env.go index 1cdda2365..2c031bd3b 100644 --- a/internal/zero/cmd/env.go +++ b/internal/zero/cmd/env.go @@ -6,6 +6,8 @@ import ( "path/filepath" "github.com/spf13/viper" + + "github.com/pomerium/pomerium/internal/fileutil" ) const ( @@ -63,11 +65,7 @@ func getBootstrapConfigFileName() (string, error) { if filename := os.Getenv(BootstrapConfigFileName); filename != "" { return filename, nil } - cacheDir, err := os.UserCacheDir() - if err != nil { - return "", err - } - dir := filepath.Join(cacheDir, "pomerium") + dir := fileutil.CacheDir() if err := os.MkdirAll(dir, 0o700); err != nil { return "", fmt.Errorf("error creating cache directory: %w", err) }