diff --git a/go.mod b/go.mod index 112732107..101b0e23f 100644 --- a/go.mod +++ b/go.mod @@ -52,7 +52,7 @@ require ( github.com/pomerium/csrf v1.7.0 github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524 github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b - github.com/pomerium/zero-sdk v0.0.0-20230813022804-3bf1f871ab31 + github.com/pomerium/zero-sdk v0.0.0-20230816000855-af1b8165df05 github.com/prometheus/client_golang v1.16.0 github.com/prometheus/client_model v0.4.0 github.com/prometheus/common v0.44.0 diff --git a/go.sum b/go.sum index c06a03aaa..1c19d4b76 100644 --- a/go.sum +++ b/go.sum @@ -657,8 +657,8 @@ github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524 h1:3YQY1sb5 github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524/go.mod h1:7fGbUYJnU8RcxZJvUvhukOIBv1G7LWDAHMfDxAf5+Y0= github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b h1:oll/aOfJudnqFAwCvoXK9+WN2zVjTzHVPLXCggHQmHk= github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b/go.mod h1:KswTenBBh4y1pmhU2dpm8VgJQCgSErCg7OOFTeebrNc= -github.com/pomerium/zero-sdk v0.0.0-20230813022804-3bf1f871ab31 h1:FUoy3dpWd0ECmFHmYVUCJd7lVhQjglZsnrfjvRsVj4w= -github.com/pomerium/zero-sdk v0.0.0-20230813022804-3bf1f871ab31/go.mod h1:cAyfEGM8blUzchYhOWrufuj/6lOF277meB4c/TjMS28= +github.com/pomerium/zero-sdk v0.0.0-20230816000855-af1b8165df05 h1:Rl2df8q+DAd3SsJn9MpXrbo7JRNCDHVaohOyUZ2IJik= +github.com/pomerium/zero-sdk v0.0.0-20230816000855-af1b8165df05/go.mod h1:cAyfEGM8blUzchYhOWrufuj/6lOF277meB4c/TjMS28= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= diff --git a/internal/zero/reconciler/cloud_api.go b/internal/zero/reconciler/cloud_api.go deleted file mode 100644 index baf9fe485..000000000 --- a/internal/zero/reconciler/cloud_api.go +++ /dev/null @@ -1,22 +0,0 @@ -package reconciler - -import ( - "context" - "fmt" -) - -// GetBundles returns the list of bundles that have to be present in the cluster. -func (c *service) RefreshBundleList(ctx context.Context) error { - resp, err := c.config.api.GetClusterResourceBundles(ctx) - if err != nil { - return fmt.Errorf("get bundles: %w", err) - } - - ids := make([]string, 0, len(resp.Bundles)) - for _, v := range resp.Bundles { - ids = append(ids, v.Id) - } - - c.bundles.Set(ids) - return nil -} diff --git a/internal/zero/reconciler/config.go b/internal/zero/reconciler/config.go index 66654fd2c..d9f48546f 100644 --- a/internal/zero/reconciler/config.go +++ b/internal/zero/reconciler/config.go @@ -8,21 +8,17 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/databroker" sdk "github.com/pomerium/zero-sdk" - connect_mux "github.com/pomerium/zero-sdk/connect-mux" ) // reconcilerConfig contains the configuration for the resource bundles reconciler. type reconcilerConfig struct { - api *sdk.API - connectMux *connect_mux.Mux + api *sdk.API databrokerClient databroker.DataBrokerServiceClient databrokerRPS int tmpDir string - minDownloadTTL time.Duration - httpClient *http.Client checkForUpdateIntervalWhenDisconnected time.Duration @@ -49,13 +45,6 @@ func WithAPI(client *sdk.API) Option { } } -// WithConnectMux configures the connect mux. -func WithConnectMux(client *connect_mux.Mux) Option { - return func(cfg *reconcilerConfig) { - cfg.connectMux = client - } -} - // WithDataBrokerClient configures the databroker client. func WithDataBrokerClient(client databroker.DataBrokerServiceClient) Option { return func(cfg *reconcilerConfig) { @@ -63,14 +52,6 @@ func WithDataBrokerClient(client databroker.DataBrokerServiceClient) Option { } } -// WithMinDownloadTTL configures the download URL validity in cache, -// before it would be requested again from the cloud (that also sets it's own TTL to the signed URL). -func WithMinDownloadTTL(ttl time.Duration) Option { - return func(cfg *reconcilerConfig) { - cfg.minDownloadTTL = ttl - } -} - // WithDownloadHTTPClient configures the http client used for downloading files. func WithDownloadHTTPClient(client *http.Client) Option { return func(cfg *reconcilerConfig) { @@ -112,7 +93,6 @@ func newConfig(opts ...Option) *reconcilerConfig { cfg := &reconcilerConfig{} for _, opt := range []Option{ WithTemporaryDirectory(os.TempDir()), - WithMinDownloadTTL(5 * time.Minute), WithDownloadHTTPClient(http.DefaultClient), WithDatabrokerRPSLimit(1_000), WithCheckForUpdateIntervalWhenDisconnected(time.Minute * 5), diff --git a/internal/zero/reconciler/download.go b/internal/zero/reconciler/download.go deleted file mode 100644 index 1ff0c0e56..000000000 --- a/internal/zero/reconciler/download.go +++ /dev/null @@ -1,132 +0,0 @@ -package reconciler - -import ( - "context" - "fmt" - "io" - "net/http" - "time" -) - -type downloadOptions struct { - requestCacheEntry *BundleCacheEntry - responseCacheEntry *BundleCacheEntry -} - -// DownloadOption is an option for downloading a bundle -type DownloadOption func(*downloadOptions) - -// WithCacheEntry sets the cache entry to use for the request. -func WithCacheEntry(entry BundleCacheEntry) DownloadOption { - return func(opts *downloadOptions) { - opts.requestCacheEntry = &entry - } -} - -// WithUpdateCacheEntry updates the cache entry with the values from the response. -func WithUpdateCacheEntry(dst *BundleCacheEntry) DownloadOption { - return func(opts *downloadOptions) { - opts.responseCacheEntry = dst - } -} - -func getDownloadOptions(opts ...DownloadOption) downloadOptions { - var options downloadOptions - for _, opt := range opts { - opt(&options) - } - return options -} - -func (opt *downloadOptions) updateRequest(req *http.Request) { - if opt.requestCacheEntry != nil { - req.Header.Set("If-None-Match", opt.requestCacheEntry.ETag) - req.Header.Set("If-Modified-Since", opt.requestCacheEntry.LastModified.Format(http.TimeFormat)) - } -} - -func (opt *downloadOptions) updateFromResponse(resp *http.Response) error { - if opt.responseCacheEntry == nil { - return nil - } - - if resp.StatusCode == http.StatusNotModified && opt.requestCacheEntry != nil { - *opt.responseCacheEntry = *opt.requestCacheEntry - return nil - } - - if resp.StatusCode != http.StatusOK { - return nil - } - - return updateBundleCacheEntryFromResponse(opt.responseCacheEntry, resp.Header) -} - -// DownloadBundleIfChanged downloads the bundle if it has changed. -func (c *service) DownloadBundleIfChanged( - ctx context.Context, - dst io.Writer, - bundleKey string, - opts ...DownloadOption, -) error { - options := getDownloadOptions(opts...) - - url, err := c.config.api.DownloadClusterResourceBundle(ctx, bundleKey, c.config.minDownloadTTL) - if err != nil { - return fmt.Errorf("get download url: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) - if err != nil { - return fmt.Errorf("new request: %w", err) - } - options.updateRequest(req) - - resp, err := c.config.httpClient.Do(req) - if err != nil { - return fmt.Errorf("do request: %w", err) - } - defer resp.Body.Close() - - err = options.updateFromResponse(resp) - if err != nil { - return fmt.Errorf("response: %w", err) - } - - if resp.StatusCode == http.StatusNotModified { - return nil - } - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected response: %d/%s", resp.StatusCode, resp.Status) - } - - _, err = io.Copy(dst, resp.Body) - if err != nil { - return fmt.Errorf("write file: %w", err) - } - - return nil -} - -func updateBundleCacheEntryFromResponse(dst *BundleCacheEntry, headers http.Header) error { - txt := headers.Get("Last-Modified") - if txt == "" { - return fmt.Errorf("missing last-modified header") - } - - lastModified, err := time.Parse(http.TimeFormat, txt) - if err != nil { - return fmt.Errorf("parse last modified: %w", err) - } - - etag := headers.Get("ETag") - if etag == "" { - return fmt.Errorf("missing etag header") - } - - dst.LastModified = lastModified - dst.ETag = etag - - return nil -} diff --git a/internal/zero/reconciler/download_cache.go b/internal/zero/reconciler/download_cache.go index 5d99a8752..4d3803f36 100644 --- a/internal/zero/reconciler/download_cache.go +++ b/internal/zero/reconciler/download_cache.go @@ -3,7 +3,19 @@ package reconciler import ( "context" "errors" - "time" + "fmt" + + "github.com/hashicorp/go-multierror" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" + zero_sdk "github.com/pomerium/zero-sdk" ) // BundleCacheEntry is a cache entry for a bundle @@ -16,29 +28,139 @@ import ( // also it works in case of multiple instances, as it uses // the databroker database as a shared cache. type BundleCacheEntry struct { - ETag string - LastModified time.Time - RecordTypes []string + zero_sdk.DownloadConditional + RecordTypes []string } +const ( + bundleCacheEntryRecordType = "pomerium.io/BundleCacheEntry" +) + var ( // ErrBundleCacheEntryNotFound is returned when a bundle cache entry is not found ErrBundleCacheEntryNotFound = errors.New("bundle cache entry not found") ) -// Equals returns true if the two cache entries are equal -func (c *BundleCacheEntry) Equals(other BundleCacheEntry) bool { - return c.ETag == other.ETag && c.LastModified.Equal(other.LastModified) -} - // GetBundleCacheEntry gets a bundle cache entry from the databroker -func (c *service) GetBundleCacheEntry(_ context.Context, _ string, _ *BundleCacheEntry) error { - // TODO: implement - return ErrBundleCacheEntryNotFound +func (c *service) GetBundleCacheEntry(ctx context.Context, id string) (*BundleCacheEntry, error) { + record, err := c.config.databrokerClient.Get(ctx, &databroker.GetRequest{ + Type: bundleCacheEntryRecordType, + Id: id, + }) + if err != nil && status.Code(err) == codes.NotFound { + return nil, ErrBundleCacheEntryNotFound + } else if err != nil { + return nil, fmt.Errorf("get bundle cache entry: %w", err) + } + + var dst BundleCacheEntry + data := record.GetRecord().GetData() + err = dst.FromAny(data) + if err != nil { + log.Ctx(ctx).Error().Err(err). + Str("bundle-id", id). + Str("data", protojson.Format(data)). + Msg("unmarshal bundle cache entry") + // we would allow it to be overwritten by the update process + return nil, ErrBundleCacheEntryNotFound + } + + return &dst, nil } // SetBundleCacheEntry sets a bundle cache entry in the databroker -func (c *service) SetBundleCacheEntry(_ context.Context, _ string, _ BundleCacheEntry) error { - // TODO: implement +func (c *service) SetBundleCacheEntry(ctx context.Context, id string, src BundleCacheEntry) error { + val, err := src.ToAny() + if err != nil { + return fmt.Errorf("marshal bundle cache entry: %w", err) + } + resp, err := c.config.databrokerClient.Put(ctx, &databroker.PutRequest{ + Records: []*databroker.Record{ + { + Type: bundleCacheEntryRecordType, + Id: id, + Data: val, + }, + }, + }) + if err != nil { + return fmt.Errorf("set bundle cache entry: %w", err) + } + log.Ctx(ctx).Info(). + Str("bundle-id", id). + Str("sent", protojson.Format(val)). + Str("got", protojson.Format(resp.GetRecord().GetData())). + Msg("set bundle cache entry") return nil } + +// ToAny marshals a BundleCacheEntry into an anypb.Any +func (r *BundleCacheEntry) ToAny() (*anypb.Any, error) { + err := r.Validate() + if err != nil { + return nil, fmt.Errorf("validate: %w", err) + } + + types := make([]*structpb.Value, 0, len(r.RecordTypes)) + for _, t := range r.RecordTypes { + types = append(types, structpb.NewStringValue(t)) + } + + return protoutil.NewAny(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + "etag": structpb.NewStringValue(r.ETag), + "last_modified": structpb.NewStringValue(r.LastModified), + "record_types": structpb.NewListValue(&structpb.ListValue{Values: types}), + }, + }), nil +} + +// FromAny unmarshals an anypb.Any into a BundleCacheEntry +func (r *BundleCacheEntry) FromAny(any *anypb.Any) error { + var s structpb.Struct + err := any.UnmarshalTo(&s) + if err != nil { + return fmt.Errorf("unmarshal struct: %w", err) + } + fmt.Println("****", protojson.Format(any), protojson.Format(&s)) + + r.ETag = s.GetFields()["etag"].GetStringValue() + r.LastModified = s.GetFields()["last_modified"].GetStringValue() + + for _, v := range s.GetFields()["record_types"].GetListValue().GetValues() { + r.RecordTypes = append(r.RecordTypes, v.GetStringValue()) + } + + err = r.Validate() + if err != nil { + return fmt.Errorf("validate: %w", err) + } + return nil +} + +// Validate validates a BundleCacheEntry +func (r *BundleCacheEntry) Validate() error { + var errs *multierror.Error + if len(r.RecordTypes) == 0 { + errs = multierror.Append(errs, errors.New("record_types is required")) + } + if err := r.DownloadConditional.Validate(); err != nil { + errs = multierror.Append(errs, err) + } + return errs.ErrorOrNil() +} + +// GetDownloadConditional returns conditional download information +func (r *BundleCacheEntry) GetDownloadConditional() *zero_sdk.DownloadConditional { + if r == nil { + return nil + } + cond := r.DownloadConditional + return &cond +} + +// Equals returns true if the two cache entries are equal +func (r *BundleCacheEntry) Equals(other *BundleCacheEntry) bool { + return r != nil && other != nil && + r.ETag == other.ETag && r.LastModified == other.LastModified +} diff --git a/internal/zero/reconciler/download_cache_test.go b/internal/zero/reconciler/download_cache_test.go new file mode 100644 index 000000000..f0d608a98 --- /dev/null +++ b/internal/zero/reconciler/download_cache_test.go @@ -0,0 +1,29 @@ +package reconciler_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/internal/zero/reconciler" + zero_sdk "github.com/pomerium/zero-sdk" +) + +func TestCacheEntryProto(t *testing.T) { + t.Parallel() + + original := reconciler.BundleCacheEntry{ + DownloadConditional: zero_sdk.DownloadConditional{ + ETag: "etag value", + LastModified: "2009-02-13 18:31:30 -0500 EST", + }, + RecordTypes: []string{"one", "two"}, + } + originalProto, err := original.ToAny() + require.NoError(t, err) + var unmarshaled reconciler.BundleCacheEntry + err = unmarshaled.FromAny(originalProto) + require.NoError(t, err) + assert.True(t, original.Equals(&unmarshaled)) +} diff --git a/internal/zero/reconciler/service.go b/internal/zero/reconciler/service.go index 1925d0ed1..b10908d76 100644 --- a/internal/zero/reconciler/service.go +++ b/internal/zero/reconciler/service.go @@ -52,7 +52,7 @@ func Run(ctx context.Context, opts ...Option) error { // it may be later optimized by splitting between download and reconciliation process, // as we would get more resource bundles beyond the config. func (c *service) watchUpdates(ctx context.Context) error { - return c.config.connectMux.Watch(ctx, + return c.config.api.Watch(ctx, connect_mux.WithOnConnected(func(ctx context.Context) { c.triggerFullUpdate(true) }), diff --git a/internal/zero/reconciler/sync.go b/internal/zero/reconciler/sync.go index 42b11fee3..0e406d404 100644 --- a/internal/zero/reconciler/sync.go +++ b/internal/zero/reconciler/sync.go @@ -16,7 +16,6 @@ import ( "errors" "fmt" "io" - "os" "time" "github.com/pomerium/pomerium/internal/log" @@ -29,22 +28,27 @@ func (c *service) SyncLoop(ctx context.Context) error { defer ticker.Stop() for { - ticker.Reset(c.periodicUpdateInterval.Load()) + dur := c.periodicUpdateInterval.Load() + ticker.Reset(dur) + log.Ctx(ctx).Info().Str("duration", dur.String()).Msg("*** next sync cycle ***") select { case <-ctx.Done(): return ctx.Err() case <-c.bundleSyncRequest: + log.Ctx(ctx).Info().Msg("bundle sync triggered") err := c.syncBundles(ctx) if err != nil { return fmt.Errorf("reconciler: sync bundles: %w", err) } case <-c.fullSyncRequest: + log.Ctx(ctx).Info().Msg("full sync triggered") err := c.syncAll(ctx) if err != nil { return fmt.Errorf("reconciler: sync all: %w", err) } case <-ticker.C: + log.Ctx(ctx).Info().Msg("periodic sync triggered") err := c.syncAll(ctx) if err != nil { return fmt.Errorf("reconciler: sync all: %w", err) @@ -72,7 +76,7 @@ func (c *service) syncBundleList(ctx context.Context) error { // refresh bundle list, // ignoring other signals while we're retrying return retry.Retry(ctx, - "refresh bundle list", c.RefreshBundleList, + "refresh bundle list", c.refreshBundleList, retry.WithWatch("refresh bundle list", c.fullSyncRequest, nil), retry.WithWatch("bundle update", c.bundleSyncRequest, nil), ) @@ -83,7 +87,7 @@ func (c *service) syncBundleList(ctx context.Context) error { func (c *service) syncBundles(ctx context.Context) error { return retry.Retry(ctx, "sync bundles", c.trySyncBundles, - retry.WithWatch("refresh bundle list", c.fullSyncRequest, c.RefreshBundleList), + retry.WithWatch("refresh bundle list", c.fullSyncRequest, c.refreshBundleList), retry.WithWatch("bundle update", c.bundleSyncRequest, nil), ) } @@ -110,38 +114,41 @@ func (c *service) trySyncBundles(ctx context.Context) error { // This is only persisted in the databroker after all records are successfully synced. // That allows us to ignore any changes based on the same bundle state, without need to re-check all records between bundle and databroker. func (c *service) syncBundle(ctx context.Context, key string) error { - var cached, changed BundleCacheEntry - opts := []DownloadOption{ - WithUpdateCacheEntry(&changed), - } - - err := c.GetBundleCacheEntry(ctx, key, &cached) - if err == nil { - opts = append(opts, WithCacheEntry(cached)) - } else if err != nil && !errors.Is(err, ErrBundleCacheEntryNotFound) { + cached, err := c.GetBundleCacheEntry(ctx, key) + if err != nil && !errors.Is(err, ErrBundleCacheEntryNotFound) { return fmt.Errorf("get bundle cache entry: %w", err) } // download is much faster compared to databroker sync, // so we don't use pipe but rather download to a temp file and then sync it to databroker - - fd, err := os.CreateTemp(c.config.tmpDir, fmt.Sprintf("pomerium-bundle-%s", key)) + fd, err := c.GetTmpFile(key) if err != nil { - return fmt.Errorf("create temp file: %w", err) + return fmt.Errorf("get tmp file: %w", err) } - defer fd.Close() - defer os.Remove(fd.Name()) + defer func() { + if err := fd.Close(); err != nil { + log.Ctx(ctx).Error().Err(err).Msg("close tmp file") + } + }() - err = c.DownloadBundleIfChanged(ctx, fd, key, opts...) + conditional := cached.GetDownloadConditional() + log.Ctx(ctx).Info().Str("id", key).Any("conditional", conditional).Msg("downloading bundle") + + result, err := c.config.api.DownloadClusterResourceBundle(ctx, fd, key, conditional) if err != nil { return fmt.Errorf("download bundle: %w", err) } - if changed.Equals(cached) { + if result.NotModified { log.Ctx(ctx).Info().Str("bundle", key).Msg("bundle not changed") return nil } + log.Ctx(ctx).Info().Str("bundle", key). + Interface("cached-entry", cached). + Interface("current-entry", result.DownloadConditional). + Msg("bundle changed") + _, err = fd.Seek(0, io.SeekStart) if err != nil { return fmt.Errorf("seek to start: %w", err) @@ -151,16 +158,19 @@ func (c *service) syncBundle(ctx context.Context, key string) error { if err != nil { return fmt.Errorf("apply bundle to databroker: %w", err) } - changed.RecordTypes = bundleRecordTypes + current := BundleCacheEntry{ + DownloadConditional: *result.DownloadConditional, + RecordTypes: bundleRecordTypes, + } log.Ctx(ctx).Info(). Str("bundle", key). Strs("record_types", bundleRecordTypes). - Str("etag", changed.ETag). - Time("last_modified", changed.LastModified). + Str("etag", current.ETag). + Str("last_modified", current.LastModified). Msg("bundle synced") - err = c.SetBundleCacheEntry(ctx, key, changed) + err = c.SetBundleCacheEntry(ctx, key, current) if err != nil { return fmt.Errorf("set bundle cache entry: %w", err) } @@ -198,3 +208,18 @@ func (c *service) syncBundleToDatabroker(ctx context.Context, src io.Reader) ([] return bundleRecords.RecordTypes(), nil } + +func (c *service) refreshBundleList(ctx context.Context) error { + resp, err := c.config.api.GetClusterResourceBundles(ctx) + if err != nil { + return fmt.Errorf("get bundles: %w", err) + } + + ids := make([]string, 0, len(resp.Bundles)) + for _, v := range resp.Bundles { + ids = append(ids, v.Id) + } + + c.bundles.Set(ids) + return nil +} diff --git a/internal/zero/reconciler/tmpfile.go b/internal/zero/reconciler/tmpfile.go new file mode 100644 index 000000000..527186bd6 --- /dev/null +++ b/internal/zero/reconciler/tmpfile.go @@ -0,0 +1,40 @@ +package reconciler + +import ( + "fmt" + "io" + "os" + + "github.com/hashicorp/go-multierror" +) + +// ReadWriteSeekCloser is a file that can be read, written, seeked, and closed. +type ReadWriteSeekCloser interface { + io.ReadWriteSeeker + io.Closer +} + +// GetTmpFile returns a temporary file for the reconciler to use. +// TODO: encrypt contents to ensure encryption at rest +func (c *service) GetTmpFile(key string) (ReadWriteSeekCloser, error) { + fd, err := os.CreateTemp(c.config.tmpDir, fmt.Sprintf("pomerium-bundle-%s", key)) + if err != nil { + return nil, fmt.Errorf("create temp file: %w", err) + } + return &tmpFile{File: fd}, nil +} + +type tmpFile struct { + *os.File +} + +func (f *tmpFile) Close() error { + var errs *multierror.Error + if err := f.File.Close(); err != nil { + errs = multierror.Append(errs, err) + } + if err := os.Remove(f.File.Name()); err != nil { + errs = multierror.Append(errs, err) + } + return errs.ErrorOrNil() +}