mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-23 11:39:32 +02:00
use centralied api
This commit is contained in:
parent
4e5993488e
commit
01c17844df
10 changed files with 259 additions and 217 deletions
2
go.mod
2
go.mod
|
@ -52,7 +52,7 @@ require (
|
|||
github.com/pomerium/csrf v1.7.0
|
||||
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524
|
||||
github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b
|
||||
github.com/pomerium/zero-sdk v0.0.0-20230813022804-3bf1f871ab31
|
||||
github.com/pomerium/zero-sdk v0.0.0-20230816000855-af1b8165df05
|
||||
github.com/prometheus/client_golang v1.16.0
|
||||
github.com/prometheus/client_model v0.4.0
|
||||
github.com/prometheus/common v0.44.0
|
||||
|
|
4
go.sum
4
go.sum
|
@ -657,8 +657,8 @@ github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524 h1:3YQY1sb5
|
|||
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524/go.mod h1:7fGbUYJnU8RcxZJvUvhukOIBv1G7LWDAHMfDxAf5+Y0=
|
||||
github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b h1:oll/aOfJudnqFAwCvoXK9+WN2zVjTzHVPLXCggHQmHk=
|
||||
github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b/go.mod h1:KswTenBBh4y1pmhU2dpm8VgJQCgSErCg7OOFTeebrNc=
|
||||
github.com/pomerium/zero-sdk v0.0.0-20230813022804-3bf1f871ab31 h1:FUoy3dpWd0ECmFHmYVUCJd7lVhQjglZsnrfjvRsVj4w=
|
||||
github.com/pomerium/zero-sdk v0.0.0-20230813022804-3bf1f871ab31/go.mod h1:cAyfEGM8blUzchYhOWrufuj/6lOF277meB4c/TjMS28=
|
||||
github.com/pomerium/zero-sdk v0.0.0-20230816000855-af1b8165df05 h1:Rl2df8q+DAd3SsJn9MpXrbo7JRNCDHVaohOyUZ2IJik=
|
||||
github.com/pomerium/zero-sdk v0.0.0-20230816000855-af1b8165df05/go.mod h1:cAyfEGM8blUzchYhOWrufuj/6lOF277meB4c/TjMS28=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
package reconciler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GetBundles returns the list of bundles that have to be present in the cluster.
|
||||
func (c *service) RefreshBundleList(ctx context.Context) error {
|
||||
resp, err := c.config.api.GetClusterResourceBundles(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get bundles: %w", err)
|
||||
}
|
||||
|
||||
ids := make([]string, 0, len(resp.Bundles))
|
||||
for _, v := range resp.Bundles {
|
||||
ids = append(ids, v.Id)
|
||||
}
|
||||
|
||||
c.bundles.Set(ids)
|
||||
return nil
|
||||
}
|
|
@ -8,21 +8,17 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
sdk "github.com/pomerium/zero-sdk"
|
||||
connect_mux "github.com/pomerium/zero-sdk/connect-mux"
|
||||
)
|
||||
|
||||
// reconcilerConfig contains the configuration for the resource bundles reconciler.
|
||||
type reconcilerConfig struct {
|
||||
api *sdk.API
|
||||
connectMux *connect_mux.Mux
|
||||
api *sdk.API
|
||||
|
||||
databrokerClient databroker.DataBrokerServiceClient
|
||||
databrokerRPS int
|
||||
|
||||
tmpDir string
|
||||
|
||||
minDownloadTTL time.Duration
|
||||
|
||||
httpClient *http.Client
|
||||
|
||||
checkForUpdateIntervalWhenDisconnected time.Duration
|
||||
|
@ -49,13 +45,6 @@ func WithAPI(client *sdk.API) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithConnectMux configures the connect mux.
|
||||
func WithConnectMux(client *connect_mux.Mux) Option {
|
||||
return func(cfg *reconcilerConfig) {
|
||||
cfg.connectMux = client
|
||||
}
|
||||
}
|
||||
|
||||
// WithDataBrokerClient configures the databroker client.
|
||||
func WithDataBrokerClient(client databroker.DataBrokerServiceClient) Option {
|
||||
return func(cfg *reconcilerConfig) {
|
||||
|
@ -63,14 +52,6 @@ func WithDataBrokerClient(client databroker.DataBrokerServiceClient) Option {
|
|||
}
|
||||
}
|
||||
|
||||
// WithMinDownloadTTL configures the download URL validity in cache,
|
||||
// before it would be requested again from the cloud (that also sets it's own TTL to the signed URL).
|
||||
func WithMinDownloadTTL(ttl time.Duration) Option {
|
||||
return func(cfg *reconcilerConfig) {
|
||||
cfg.minDownloadTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
// WithDownloadHTTPClient configures the http client used for downloading files.
|
||||
func WithDownloadHTTPClient(client *http.Client) Option {
|
||||
return func(cfg *reconcilerConfig) {
|
||||
|
@ -112,7 +93,6 @@ func newConfig(opts ...Option) *reconcilerConfig {
|
|||
cfg := &reconcilerConfig{}
|
||||
for _, opt := range []Option{
|
||||
WithTemporaryDirectory(os.TempDir()),
|
||||
WithMinDownloadTTL(5 * time.Minute),
|
||||
WithDownloadHTTPClient(http.DefaultClient),
|
||||
WithDatabrokerRPSLimit(1_000),
|
||||
WithCheckForUpdateIntervalWhenDisconnected(time.Minute * 5),
|
||||
|
|
|
@ -1,132 +0,0 @@
|
|||
package reconciler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type downloadOptions struct {
|
||||
requestCacheEntry *BundleCacheEntry
|
||||
responseCacheEntry *BundleCacheEntry
|
||||
}
|
||||
|
||||
// DownloadOption is an option for downloading a bundle
|
||||
type DownloadOption func(*downloadOptions)
|
||||
|
||||
// WithCacheEntry sets the cache entry to use for the request.
|
||||
func WithCacheEntry(entry BundleCacheEntry) DownloadOption {
|
||||
return func(opts *downloadOptions) {
|
||||
opts.requestCacheEntry = &entry
|
||||
}
|
||||
}
|
||||
|
||||
// WithUpdateCacheEntry updates the cache entry with the values from the response.
|
||||
func WithUpdateCacheEntry(dst *BundleCacheEntry) DownloadOption {
|
||||
return func(opts *downloadOptions) {
|
||||
opts.responseCacheEntry = dst
|
||||
}
|
||||
}
|
||||
|
||||
func getDownloadOptions(opts ...DownloadOption) downloadOptions {
|
||||
var options downloadOptions
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
func (opt *downloadOptions) updateRequest(req *http.Request) {
|
||||
if opt.requestCacheEntry != nil {
|
||||
req.Header.Set("If-None-Match", opt.requestCacheEntry.ETag)
|
||||
req.Header.Set("If-Modified-Since", opt.requestCacheEntry.LastModified.Format(http.TimeFormat))
|
||||
}
|
||||
}
|
||||
|
||||
func (opt *downloadOptions) updateFromResponse(resp *http.Response) error {
|
||||
if opt.responseCacheEntry == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusNotModified && opt.requestCacheEntry != nil {
|
||||
*opt.responseCacheEntry = *opt.requestCacheEntry
|
||||
return nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
return updateBundleCacheEntryFromResponse(opt.responseCacheEntry, resp.Header)
|
||||
}
|
||||
|
||||
// DownloadBundleIfChanged downloads the bundle if it has changed.
|
||||
func (c *service) DownloadBundleIfChanged(
|
||||
ctx context.Context,
|
||||
dst io.Writer,
|
||||
bundleKey string,
|
||||
opts ...DownloadOption,
|
||||
) error {
|
||||
options := getDownloadOptions(opts...)
|
||||
|
||||
url, err := c.config.api.DownloadClusterResourceBundle(ctx, bundleKey, c.config.minDownloadTTL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get download url: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("new request: %w", err)
|
||||
}
|
||||
options.updateRequest(req)
|
||||
|
||||
resp, err := c.config.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("do request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
err = options.updateFromResponse(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusNotModified {
|
||||
return nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("unexpected response: %d/%s", resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
_, err = io.Copy(dst, resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateBundleCacheEntryFromResponse(dst *BundleCacheEntry, headers http.Header) error {
|
||||
txt := headers.Get("Last-Modified")
|
||||
if txt == "" {
|
||||
return fmt.Errorf("missing last-modified header")
|
||||
}
|
||||
|
||||
lastModified, err := time.Parse(http.TimeFormat, txt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse last modified: %w", err)
|
||||
}
|
||||
|
||||
etag := headers.Get("ETag")
|
||||
if etag == "" {
|
||||
return fmt.Errorf("missing etag header")
|
||||
}
|
||||
|
||||
dst.LastModified = lastModified
|
||||
dst.ETag = etag
|
||||
|
||||
return nil
|
||||
}
|
|
@ -3,7 +3,19 @@ package reconciler
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
zero_sdk "github.com/pomerium/zero-sdk"
|
||||
)
|
||||
|
||||
// BundleCacheEntry is a cache entry for a bundle
|
||||
|
@ -16,29 +28,139 @@ import (
|
|||
// also it works in case of multiple instances, as it uses
|
||||
// the databroker database as a shared cache.
|
||||
type BundleCacheEntry struct {
|
||||
ETag string
|
||||
LastModified time.Time
|
||||
RecordTypes []string
|
||||
zero_sdk.DownloadConditional
|
||||
RecordTypes []string
|
||||
}
|
||||
|
||||
const (
|
||||
bundleCacheEntryRecordType = "pomerium.io/BundleCacheEntry"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrBundleCacheEntryNotFound is returned when a bundle cache entry is not found
|
||||
ErrBundleCacheEntryNotFound = errors.New("bundle cache entry not found")
|
||||
)
|
||||
|
||||
// Equals returns true if the two cache entries are equal
|
||||
func (c *BundleCacheEntry) Equals(other BundleCacheEntry) bool {
|
||||
return c.ETag == other.ETag && c.LastModified.Equal(other.LastModified)
|
||||
}
|
||||
|
||||
// GetBundleCacheEntry gets a bundle cache entry from the databroker
|
||||
func (c *service) GetBundleCacheEntry(_ context.Context, _ string, _ *BundleCacheEntry) error {
|
||||
// TODO: implement
|
||||
return ErrBundleCacheEntryNotFound
|
||||
func (c *service) GetBundleCacheEntry(ctx context.Context, id string) (*BundleCacheEntry, error) {
|
||||
record, err := c.config.databrokerClient.Get(ctx, &databroker.GetRequest{
|
||||
Type: bundleCacheEntryRecordType,
|
||||
Id: id,
|
||||
})
|
||||
if err != nil && status.Code(err) == codes.NotFound {
|
||||
return nil, ErrBundleCacheEntryNotFound
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("get bundle cache entry: %w", err)
|
||||
}
|
||||
|
||||
var dst BundleCacheEntry
|
||||
data := record.GetRecord().GetData()
|
||||
err = dst.FromAny(data)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).
|
||||
Str("bundle-id", id).
|
||||
Str("data", protojson.Format(data)).
|
||||
Msg("unmarshal bundle cache entry")
|
||||
// we would allow it to be overwritten by the update process
|
||||
return nil, ErrBundleCacheEntryNotFound
|
||||
}
|
||||
|
||||
return &dst, nil
|
||||
}
|
||||
|
||||
// SetBundleCacheEntry sets a bundle cache entry in the databroker
|
||||
func (c *service) SetBundleCacheEntry(_ context.Context, _ string, _ BundleCacheEntry) error {
|
||||
// TODO: implement
|
||||
func (c *service) SetBundleCacheEntry(ctx context.Context, id string, src BundleCacheEntry) error {
|
||||
val, err := src.ToAny()
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal bundle cache entry: %w", err)
|
||||
}
|
||||
resp, err := c.config.databrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Records: []*databroker.Record{
|
||||
{
|
||||
Type: bundleCacheEntryRecordType,
|
||||
Id: id,
|
||||
Data: val,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("set bundle cache entry: %w", err)
|
||||
}
|
||||
log.Ctx(ctx).Info().
|
||||
Str("bundle-id", id).
|
||||
Str("sent", protojson.Format(val)).
|
||||
Str("got", protojson.Format(resp.GetRecord().GetData())).
|
||||
Msg("set bundle cache entry")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToAny marshals a BundleCacheEntry into an anypb.Any
|
||||
func (r *BundleCacheEntry) ToAny() (*anypb.Any, error) {
|
||||
err := r.Validate()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate: %w", err)
|
||||
}
|
||||
|
||||
types := make([]*structpb.Value, 0, len(r.RecordTypes))
|
||||
for _, t := range r.RecordTypes {
|
||||
types = append(types, structpb.NewStringValue(t))
|
||||
}
|
||||
|
||||
return protoutil.NewAny(&structpb.Struct{
|
||||
Fields: map[string]*structpb.Value{
|
||||
"etag": structpb.NewStringValue(r.ETag),
|
||||
"last_modified": structpb.NewStringValue(r.LastModified),
|
||||
"record_types": structpb.NewListValue(&structpb.ListValue{Values: types}),
|
||||
},
|
||||
}), nil
|
||||
}
|
||||
|
||||
// FromAny unmarshals an anypb.Any into a BundleCacheEntry
|
||||
func (r *BundleCacheEntry) FromAny(any *anypb.Any) error {
|
||||
var s structpb.Struct
|
||||
err := any.UnmarshalTo(&s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal struct: %w", err)
|
||||
}
|
||||
fmt.Println("****", protojson.Format(any), protojson.Format(&s))
|
||||
|
||||
r.ETag = s.GetFields()["etag"].GetStringValue()
|
||||
r.LastModified = s.GetFields()["last_modified"].GetStringValue()
|
||||
|
||||
for _, v := range s.GetFields()["record_types"].GetListValue().GetValues() {
|
||||
r.RecordTypes = append(r.RecordTypes, v.GetStringValue())
|
||||
}
|
||||
|
||||
err = r.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("validate: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate validates a BundleCacheEntry
|
||||
func (r *BundleCacheEntry) Validate() error {
|
||||
var errs *multierror.Error
|
||||
if len(r.RecordTypes) == 0 {
|
||||
errs = multierror.Append(errs, errors.New("record_types is required"))
|
||||
}
|
||||
if err := r.DownloadConditional.Validate(); err != nil {
|
||||
errs = multierror.Append(errs, err)
|
||||
}
|
||||
return errs.ErrorOrNil()
|
||||
}
|
||||
|
||||
// GetDownloadConditional returns conditional download information
|
||||
func (r *BundleCacheEntry) GetDownloadConditional() *zero_sdk.DownloadConditional {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
cond := r.DownloadConditional
|
||||
return &cond
|
||||
}
|
||||
|
||||
// Equals returns true if the two cache entries are equal
|
||||
func (r *BundleCacheEntry) Equals(other *BundleCacheEntry) bool {
|
||||
return r != nil && other != nil &&
|
||||
r.ETag == other.ETag && r.LastModified == other.LastModified
|
||||
}
|
||||
|
|
29
internal/zero/reconciler/download_cache_test.go
Normal file
29
internal/zero/reconciler/download_cache_test.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package reconciler_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/zero/reconciler"
|
||||
zero_sdk "github.com/pomerium/zero-sdk"
|
||||
)
|
||||
|
||||
func TestCacheEntryProto(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
original := reconciler.BundleCacheEntry{
|
||||
DownloadConditional: zero_sdk.DownloadConditional{
|
||||
ETag: "etag value",
|
||||
LastModified: "2009-02-13 18:31:30 -0500 EST",
|
||||
},
|
||||
RecordTypes: []string{"one", "two"},
|
||||
}
|
||||
originalProto, err := original.ToAny()
|
||||
require.NoError(t, err)
|
||||
var unmarshaled reconciler.BundleCacheEntry
|
||||
err = unmarshaled.FromAny(originalProto)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, original.Equals(&unmarshaled))
|
||||
}
|
|
@ -52,7 +52,7 @@ func Run(ctx context.Context, opts ...Option) error {
|
|||
// it may be later optimized by splitting between download and reconciliation process,
|
||||
// as we would get more resource bundles beyond the config.
|
||||
func (c *service) watchUpdates(ctx context.Context) error {
|
||||
return c.config.connectMux.Watch(ctx,
|
||||
return c.config.api.Watch(ctx,
|
||||
connect_mux.WithOnConnected(func(ctx context.Context) {
|
||||
c.triggerFullUpdate(true)
|
||||
}),
|
||||
|
|
|
@ -16,7 +16,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -29,22 +28,27 @@ func (c *service) SyncLoop(ctx context.Context) error {
|
|||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
ticker.Reset(c.periodicUpdateInterval.Load())
|
||||
dur := c.periodicUpdateInterval.Load()
|
||||
ticker.Reset(dur)
|
||||
log.Ctx(ctx).Info().Str("duration", dur.String()).Msg("*** next sync cycle ***")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-c.bundleSyncRequest:
|
||||
log.Ctx(ctx).Info().Msg("bundle sync triggered")
|
||||
err := c.syncBundles(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reconciler: sync bundles: %w", err)
|
||||
}
|
||||
case <-c.fullSyncRequest:
|
||||
log.Ctx(ctx).Info().Msg("full sync triggered")
|
||||
err := c.syncAll(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reconciler: sync all: %w", err)
|
||||
}
|
||||
case <-ticker.C:
|
||||
log.Ctx(ctx).Info().Msg("periodic sync triggered")
|
||||
err := c.syncAll(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reconciler: sync all: %w", err)
|
||||
|
@ -72,7 +76,7 @@ func (c *service) syncBundleList(ctx context.Context) error {
|
|||
// refresh bundle list,
|
||||
// ignoring other signals while we're retrying
|
||||
return retry.Retry(ctx,
|
||||
"refresh bundle list", c.RefreshBundleList,
|
||||
"refresh bundle list", c.refreshBundleList,
|
||||
retry.WithWatch("refresh bundle list", c.fullSyncRequest, nil),
|
||||
retry.WithWatch("bundle update", c.bundleSyncRequest, nil),
|
||||
)
|
||||
|
@ -83,7 +87,7 @@ func (c *service) syncBundleList(ctx context.Context) error {
|
|||
func (c *service) syncBundles(ctx context.Context) error {
|
||||
return retry.Retry(ctx,
|
||||
"sync bundles", c.trySyncBundles,
|
||||
retry.WithWatch("refresh bundle list", c.fullSyncRequest, c.RefreshBundleList),
|
||||
retry.WithWatch("refresh bundle list", c.fullSyncRequest, c.refreshBundleList),
|
||||
retry.WithWatch("bundle update", c.bundleSyncRequest, nil),
|
||||
)
|
||||
}
|
||||
|
@ -110,38 +114,41 @@ func (c *service) trySyncBundles(ctx context.Context) error {
|
|||
// This is only persisted in the databroker after all records are successfully synced.
|
||||
// That allows us to ignore any changes based on the same bundle state, without need to re-check all records between bundle and databroker.
|
||||
func (c *service) syncBundle(ctx context.Context, key string) error {
|
||||
var cached, changed BundleCacheEntry
|
||||
opts := []DownloadOption{
|
||||
WithUpdateCacheEntry(&changed),
|
||||
}
|
||||
|
||||
err := c.GetBundleCacheEntry(ctx, key, &cached)
|
||||
if err == nil {
|
||||
opts = append(opts, WithCacheEntry(cached))
|
||||
} else if err != nil && !errors.Is(err, ErrBundleCacheEntryNotFound) {
|
||||
cached, err := c.GetBundleCacheEntry(ctx, key)
|
||||
if err != nil && !errors.Is(err, ErrBundleCacheEntryNotFound) {
|
||||
return fmt.Errorf("get bundle cache entry: %w", err)
|
||||
}
|
||||
|
||||
// download is much faster compared to databroker sync,
|
||||
// so we don't use pipe but rather download to a temp file and then sync it to databroker
|
||||
|
||||
fd, err := os.CreateTemp(c.config.tmpDir, fmt.Sprintf("pomerium-bundle-%s", key))
|
||||
fd, err := c.GetTmpFile(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp file: %w", err)
|
||||
return fmt.Errorf("get tmp file: %w", err)
|
||||
}
|
||||
defer fd.Close()
|
||||
defer os.Remove(fd.Name())
|
||||
defer func() {
|
||||
if err := fd.Close(); err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("close tmp file")
|
||||
}
|
||||
}()
|
||||
|
||||
err = c.DownloadBundleIfChanged(ctx, fd, key, opts...)
|
||||
conditional := cached.GetDownloadConditional()
|
||||
log.Ctx(ctx).Info().Str("id", key).Any("conditional", conditional).Msg("downloading bundle")
|
||||
|
||||
result, err := c.config.api.DownloadClusterResourceBundle(ctx, fd, key, conditional)
|
||||
if err != nil {
|
||||
return fmt.Errorf("download bundle: %w", err)
|
||||
}
|
||||
|
||||
if changed.Equals(cached) {
|
||||
if result.NotModified {
|
||||
log.Ctx(ctx).Info().Str("bundle", key).Msg("bundle not changed")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Ctx(ctx).Info().Str("bundle", key).
|
||||
Interface("cached-entry", cached).
|
||||
Interface("current-entry", result.DownloadConditional).
|
||||
Msg("bundle changed")
|
||||
|
||||
_, err = fd.Seek(0, io.SeekStart)
|
||||
if err != nil {
|
||||
return fmt.Errorf("seek to start: %w", err)
|
||||
|
@ -151,16 +158,19 @@ func (c *service) syncBundle(ctx context.Context, key string) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("apply bundle to databroker: %w", err)
|
||||
}
|
||||
changed.RecordTypes = bundleRecordTypes
|
||||
current := BundleCacheEntry{
|
||||
DownloadConditional: *result.DownloadConditional,
|
||||
RecordTypes: bundleRecordTypes,
|
||||
}
|
||||
|
||||
log.Ctx(ctx).Info().
|
||||
Str("bundle", key).
|
||||
Strs("record_types", bundleRecordTypes).
|
||||
Str("etag", changed.ETag).
|
||||
Time("last_modified", changed.LastModified).
|
||||
Str("etag", current.ETag).
|
||||
Str("last_modified", current.LastModified).
|
||||
Msg("bundle synced")
|
||||
|
||||
err = c.SetBundleCacheEntry(ctx, key, changed)
|
||||
err = c.SetBundleCacheEntry(ctx, key, current)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set bundle cache entry: %w", err)
|
||||
}
|
||||
|
@ -198,3 +208,18 @@ func (c *service) syncBundleToDatabroker(ctx context.Context, src io.Reader) ([]
|
|||
|
||||
return bundleRecords.RecordTypes(), nil
|
||||
}
|
||||
|
||||
func (c *service) refreshBundleList(ctx context.Context) error {
|
||||
resp, err := c.config.api.GetClusterResourceBundles(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get bundles: %w", err)
|
||||
}
|
||||
|
||||
ids := make([]string, 0, len(resp.Bundles))
|
||||
for _, v := range resp.Bundles {
|
||||
ids = append(ids, v.Id)
|
||||
}
|
||||
|
||||
c.bundles.Set(ids)
|
||||
return nil
|
||||
}
|
||||
|
|
40
internal/zero/reconciler/tmpfile.go
Normal file
40
internal/zero/reconciler/tmpfile.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package reconciler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
// ReadWriteSeekCloser is a file that can be read, written, seeked, and closed.
|
||||
type ReadWriteSeekCloser interface {
|
||||
io.ReadWriteSeeker
|
||||
io.Closer
|
||||
}
|
||||
|
||||
// GetTmpFile returns a temporary file for the reconciler to use.
|
||||
// TODO: encrypt contents to ensure encryption at rest
|
||||
func (c *service) GetTmpFile(key string) (ReadWriteSeekCloser, error) {
|
||||
fd, err := os.CreateTemp(c.config.tmpDir, fmt.Sprintf("pomerium-bundle-%s", key))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
return &tmpFile{File: fd}, nil
|
||||
}
|
||||
|
||||
type tmpFile struct {
|
||||
*os.File
|
||||
}
|
||||
|
||||
func (f *tmpFile) Close() error {
|
||||
var errs *multierror.Error
|
||||
if err := f.File.Close(); err != nil {
|
||||
errs = multierror.Append(errs, err)
|
||||
}
|
||||
if err := os.Remove(f.File.Name()); err != nil {
|
||||
errs = multierror.Append(errs, err)
|
||||
}
|
||||
return errs.ErrorOrNil()
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue