use centralied api

This commit is contained in:
Denis Mishin 2023-08-15 20:36:17 -04:00
parent 4e5993488e
commit 01c17844df
10 changed files with 259 additions and 217 deletions

2
go.mod
View file

@ -52,7 +52,7 @@ require (
github.com/pomerium/csrf v1.7.0
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524
github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b
github.com/pomerium/zero-sdk v0.0.0-20230813022804-3bf1f871ab31
github.com/pomerium/zero-sdk v0.0.0-20230816000855-af1b8165df05
github.com/prometheus/client_golang v1.16.0
github.com/prometheus/client_model v0.4.0
github.com/prometheus/common v0.44.0

4
go.sum
View file

@ -657,8 +657,8 @@ github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524 h1:3YQY1sb5
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524/go.mod h1:7fGbUYJnU8RcxZJvUvhukOIBv1G7LWDAHMfDxAf5+Y0=
github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b h1:oll/aOfJudnqFAwCvoXK9+WN2zVjTzHVPLXCggHQmHk=
github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b/go.mod h1:KswTenBBh4y1pmhU2dpm8VgJQCgSErCg7OOFTeebrNc=
github.com/pomerium/zero-sdk v0.0.0-20230813022804-3bf1f871ab31 h1:FUoy3dpWd0ECmFHmYVUCJd7lVhQjglZsnrfjvRsVj4w=
github.com/pomerium/zero-sdk v0.0.0-20230813022804-3bf1f871ab31/go.mod h1:cAyfEGM8blUzchYhOWrufuj/6lOF277meB4c/TjMS28=
github.com/pomerium/zero-sdk v0.0.0-20230816000855-af1b8165df05 h1:Rl2df8q+DAd3SsJn9MpXrbo7JRNCDHVaohOyUZ2IJik=
github.com/pomerium/zero-sdk v0.0.0-20230816000855-af1b8165df05/go.mod h1:cAyfEGM8blUzchYhOWrufuj/6lOF277meB4c/TjMS28=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=

View file

@ -1,22 +0,0 @@
package reconciler
import (
"context"
"fmt"
)
// GetBundles returns the list of bundles that have to be present in the cluster.
func (c *service) RefreshBundleList(ctx context.Context) error {
resp, err := c.config.api.GetClusterResourceBundles(ctx)
if err != nil {
return fmt.Errorf("get bundles: %w", err)
}
ids := make([]string, 0, len(resp.Bundles))
for _, v := range resp.Bundles {
ids = append(ids, v.Id)
}
c.bundles.Set(ids)
return nil
}

View file

@ -8,21 +8,17 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/databroker"
sdk "github.com/pomerium/zero-sdk"
connect_mux "github.com/pomerium/zero-sdk/connect-mux"
)
// reconcilerConfig contains the configuration for the resource bundles reconciler.
type reconcilerConfig struct {
api *sdk.API
connectMux *connect_mux.Mux
api *sdk.API
databrokerClient databroker.DataBrokerServiceClient
databrokerRPS int
tmpDir string
minDownloadTTL time.Duration
httpClient *http.Client
checkForUpdateIntervalWhenDisconnected time.Duration
@ -49,13 +45,6 @@ func WithAPI(client *sdk.API) Option {
}
}
// WithConnectMux configures the connect mux.
func WithConnectMux(client *connect_mux.Mux) Option {
return func(cfg *reconcilerConfig) {
cfg.connectMux = client
}
}
// WithDataBrokerClient configures the databroker client.
func WithDataBrokerClient(client databroker.DataBrokerServiceClient) Option {
return func(cfg *reconcilerConfig) {
@ -63,14 +52,6 @@ func WithDataBrokerClient(client databroker.DataBrokerServiceClient) Option {
}
}
// WithMinDownloadTTL configures the download URL validity in cache,
// before it would be requested again from the cloud (that also sets it's own TTL to the signed URL).
func WithMinDownloadTTL(ttl time.Duration) Option {
return func(cfg *reconcilerConfig) {
cfg.minDownloadTTL = ttl
}
}
// WithDownloadHTTPClient configures the http client used for downloading files.
func WithDownloadHTTPClient(client *http.Client) Option {
return func(cfg *reconcilerConfig) {
@ -112,7 +93,6 @@ func newConfig(opts ...Option) *reconcilerConfig {
cfg := &reconcilerConfig{}
for _, opt := range []Option{
WithTemporaryDirectory(os.TempDir()),
WithMinDownloadTTL(5 * time.Minute),
WithDownloadHTTPClient(http.DefaultClient),
WithDatabrokerRPSLimit(1_000),
WithCheckForUpdateIntervalWhenDisconnected(time.Minute * 5),

View file

@ -1,132 +0,0 @@
package reconciler
import (
"context"
"fmt"
"io"
"net/http"
"time"
)
type downloadOptions struct {
requestCacheEntry *BundleCacheEntry
responseCacheEntry *BundleCacheEntry
}
// DownloadOption is an option for downloading a bundle
type DownloadOption func(*downloadOptions)
// WithCacheEntry sets the cache entry to use for the request.
func WithCacheEntry(entry BundleCacheEntry) DownloadOption {
return func(opts *downloadOptions) {
opts.requestCacheEntry = &entry
}
}
// WithUpdateCacheEntry updates the cache entry with the values from the response.
func WithUpdateCacheEntry(dst *BundleCacheEntry) DownloadOption {
return func(opts *downloadOptions) {
opts.responseCacheEntry = dst
}
}
func getDownloadOptions(opts ...DownloadOption) downloadOptions {
var options downloadOptions
for _, opt := range opts {
opt(&options)
}
return options
}
func (opt *downloadOptions) updateRequest(req *http.Request) {
if opt.requestCacheEntry != nil {
req.Header.Set("If-None-Match", opt.requestCacheEntry.ETag)
req.Header.Set("If-Modified-Since", opt.requestCacheEntry.LastModified.Format(http.TimeFormat))
}
}
func (opt *downloadOptions) updateFromResponse(resp *http.Response) error {
if opt.responseCacheEntry == nil {
return nil
}
if resp.StatusCode == http.StatusNotModified && opt.requestCacheEntry != nil {
*opt.responseCacheEntry = *opt.requestCacheEntry
return nil
}
if resp.StatusCode != http.StatusOK {
return nil
}
return updateBundleCacheEntryFromResponse(opt.responseCacheEntry, resp.Header)
}
// DownloadBundleIfChanged downloads the bundle if it has changed.
func (c *service) DownloadBundleIfChanged(
ctx context.Context,
dst io.Writer,
bundleKey string,
opts ...DownloadOption,
) error {
options := getDownloadOptions(opts...)
url, err := c.config.api.DownloadClusterResourceBundle(ctx, bundleKey, c.config.minDownloadTTL)
if err != nil {
return fmt.Errorf("get download url: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
if err != nil {
return fmt.Errorf("new request: %w", err)
}
options.updateRequest(req)
resp, err := c.config.httpClient.Do(req)
if err != nil {
return fmt.Errorf("do request: %w", err)
}
defer resp.Body.Close()
err = options.updateFromResponse(resp)
if err != nil {
return fmt.Errorf("response: %w", err)
}
if resp.StatusCode == http.StatusNotModified {
return nil
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected response: %d/%s", resp.StatusCode, resp.Status)
}
_, err = io.Copy(dst, resp.Body)
if err != nil {
return fmt.Errorf("write file: %w", err)
}
return nil
}
func updateBundleCacheEntryFromResponse(dst *BundleCacheEntry, headers http.Header) error {
txt := headers.Get("Last-Modified")
if txt == "" {
return fmt.Errorf("missing last-modified header")
}
lastModified, err := time.Parse(http.TimeFormat, txt)
if err != nil {
return fmt.Errorf("parse last modified: %w", err)
}
etag := headers.Get("ETag")
if etag == "" {
return fmt.Errorf("missing etag header")
}
dst.LastModified = lastModified
dst.ETag = etag
return nil
}

View file

@ -3,7 +3,19 @@ package reconciler
import (
"context"
"errors"
"time"
"fmt"
"github.com/hashicorp/go-multierror"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
zero_sdk "github.com/pomerium/zero-sdk"
)
// BundleCacheEntry is a cache entry for a bundle
@ -16,29 +28,139 @@ import (
// also it works in case of multiple instances, as it uses
// the databroker database as a shared cache.
type BundleCacheEntry struct {
ETag string
LastModified time.Time
RecordTypes []string
zero_sdk.DownloadConditional
RecordTypes []string
}
const (
bundleCacheEntryRecordType = "pomerium.io/BundleCacheEntry"
)
var (
// ErrBundleCacheEntryNotFound is returned when a bundle cache entry is not found
ErrBundleCacheEntryNotFound = errors.New("bundle cache entry not found")
)
// Equals returns true if the two cache entries are equal
func (c *BundleCacheEntry) Equals(other BundleCacheEntry) bool {
return c.ETag == other.ETag && c.LastModified.Equal(other.LastModified)
}
// GetBundleCacheEntry gets a bundle cache entry from the databroker
func (c *service) GetBundleCacheEntry(_ context.Context, _ string, _ *BundleCacheEntry) error {
// TODO: implement
return ErrBundleCacheEntryNotFound
func (c *service) GetBundleCacheEntry(ctx context.Context, id string) (*BundleCacheEntry, error) {
record, err := c.config.databrokerClient.Get(ctx, &databroker.GetRequest{
Type: bundleCacheEntryRecordType,
Id: id,
})
if err != nil && status.Code(err) == codes.NotFound {
return nil, ErrBundleCacheEntryNotFound
} else if err != nil {
return nil, fmt.Errorf("get bundle cache entry: %w", err)
}
var dst BundleCacheEntry
data := record.GetRecord().GetData()
err = dst.FromAny(data)
if err != nil {
log.Ctx(ctx).Error().Err(err).
Str("bundle-id", id).
Str("data", protojson.Format(data)).
Msg("unmarshal bundle cache entry")
// we would allow it to be overwritten by the update process
return nil, ErrBundleCacheEntryNotFound
}
return &dst, nil
}
// SetBundleCacheEntry sets a bundle cache entry in the databroker
func (c *service) SetBundleCacheEntry(_ context.Context, _ string, _ BundleCacheEntry) error {
// TODO: implement
func (c *service) SetBundleCacheEntry(ctx context.Context, id string, src BundleCacheEntry) error {
val, err := src.ToAny()
if err != nil {
return fmt.Errorf("marshal bundle cache entry: %w", err)
}
resp, err := c.config.databrokerClient.Put(ctx, &databroker.PutRequest{
Records: []*databroker.Record{
{
Type: bundleCacheEntryRecordType,
Id: id,
Data: val,
},
},
})
if err != nil {
return fmt.Errorf("set bundle cache entry: %w", err)
}
log.Ctx(ctx).Info().
Str("bundle-id", id).
Str("sent", protojson.Format(val)).
Str("got", protojson.Format(resp.GetRecord().GetData())).
Msg("set bundle cache entry")
return nil
}
// ToAny marshals a BundleCacheEntry into an anypb.Any
func (r *BundleCacheEntry) ToAny() (*anypb.Any, error) {
err := r.Validate()
if err != nil {
return nil, fmt.Errorf("validate: %w", err)
}
types := make([]*structpb.Value, 0, len(r.RecordTypes))
for _, t := range r.RecordTypes {
types = append(types, structpb.NewStringValue(t))
}
return protoutil.NewAny(&structpb.Struct{
Fields: map[string]*structpb.Value{
"etag": structpb.NewStringValue(r.ETag),
"last_modified": structpb.NewStringValue(r.LastModified),
"record_types": structpb.NewListValue(&structpb.ListValue{Values: types}),
},
}), nil
}
// FromAny unmarshals an anypb.Any into a BundleCacheEntry
func (r *BundleCacheEntry) FromAny(any *anypb.Any) error {
var s structpb.Struct
err := any.UnmarshalTo(&s)
if err != nil {
return fmt.Errorf("unmarshal struct: %w", err)
}
fmt.Println("****", protojson.Format(any), protojson.Format(&s))
r.ETag = s.GetFields()["etag"].GetStringValue()
r.LastModified = s.GetFields()["last_modified"].GetStringValue()
for _, v := range s.GetFields()["record_types"].GetListValue().GetValues() {
r.RecordTypes = append(r.RecordTypes, v.GetStringValue())
}
err = r.Validate()
if err != nil {
return fmt.Errorf("validate: %w", err)
}
return nil
}
// Validate validates a BundleCacheEntry
func (r *BundleCacheEntry) Validate() error {
var errs *multierror.Error
if len(r.RecordTypes) == 0 {
errs = multierror.Append(errs, errors.New("record_types is required"))
}
if err := r.DownloadConditional.Validate(); err != nil {
errs = multierror.Append(errs, err)
}
return errs.ErrorOrNil()
}
// GetDownloadConditional returns conditional download information
func (r *BundleCacheEntry) GetDownloadConditional() *zero_sdk.DownloadConditional {
if r == nil {
return nil
}
cond := r.DownloadConditional
return &cond
}
// Equals returns true if the two cache entries are equal
func (r *BundleCacheEntry) Equals(other *BundleCacheEntry) bool {
return r != nil && other != nil &&
r.ETag == other.ETag && r.LastModified == other.LastModified
}

View file

@ -0,0 +1,29 @@
package reconciler_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/zero/reconciler"
zero_sdk "github.com/pomerium/zero-sdk"
)
func TestCacheEntryProto(t *testing.T) {
t.Parallel()
original := reconciler.BundleCacheEntry{
DownloadConditional: zero_sdk.DownloadConditional{
ETag: "etag value",
LastModified: "2009-02-13 18:31:30 -0500 EST",
},
RecordTypes: []string{"one", "two"},
}
originalProto, err := original.ToAny()
require.NoError(t, err)
var unmarshaled reconciler.BundleCacheEntry
err = unmarshaled.FromAny(originalProto)
require.NoError(t, err)
assert.True(t, original.Equals(&unmarshaled))
}

View file

@ -52,7 +52,7 @@ func Run(ctx context.Context, opts ...Option) error {
// it may be later optimized by splitting between download and reconciliation process,
// as we would get more resource bundles beyond the config.
func (c *service) watchUpdates(ctx context.Context) error {
return c.config.connectMux.Watch(ctx,
return c.config.api.Watch(ctx,
connect_mux.WithOnConnected(func(ctx context.Context) {
c.triggerFullUpdate(true)
}),

View file

@ -16,7 +16,6 @@ import (
"errors"
"fmt"
"io"
"os"
"time"
"github.com/pomerium/pomerium/internal/log"
@ -29,22 +28,27 @@ func (c *service) SyncLoop(ctx context.Context) error {
defer ticker.Stop()
for {
ticker.Reset(c.periodicUpdateInterval.Load())
dur := c.periodicUpdateInterval.Load()
ticker.Reset(dur)
log.Ctx(ctx).Info().Str("duration", dur.String()).Msg("*** next sync cycle ***")
select {
case <-ctx.Done():
return ctx.Err()
case <-c.bundleSyncRequest:
log.Ctx(ctx).Info().Msg("bundle sync triggered")
err := c.syncBundles(ctx)
if err != nil {
return fmt.Errorf("reconciler: sync bundles: %w", err)
}
case <-c.fullSyncRequest:
log.Ctx(ctx).Info().Msg("full sync triggered")
err := c.syncAll(ctx)
if err != nil {
return fmt.Errorf("reconciler: sync all: %w", err)
}
case <-ticker.C:
log.Ctx(ctx).Info().Msg("periodic sync triggered")
err := c.syncAll(ctx)
if err != nil {
return fmt.Errorf("reconciler: sync all: %w", err)
@ -72,7 +76,7 @@ func (c *service) syncBundleList(ctx context.Context) error {
// refresh bundle list,
// ignoring other signals while we're retrying
return retry.Retry(ctx,
"refresh bundle list", c.RefreshBundleList,
"refresh bundle list", c.refreshBundleList,
retry.WithWatch("refresh bundle list", c.fullSyncRequest, nil),
retry.WithWatch("bundle update", c.bundleSyncRequest, nil),
)
@ -83,7 +87,7 @@ func (c *service) syncBundleList(ctx context.Context) error {
func (c *service) syncBundles(ctx context.Context) error {
return retry.Retry(ctx,
"sync bundles", c.trySyncBundles,
retry.WithWatch("refresh bundle list", c.fullSyncRequest, c.RefreshBundleList),
retry.WithWatch("refresh bundle list", c.fullSyncRequest, c.refreshBundleList),
retry.WithWatch("bundle update", c.bundleSyncRequest, nil),
)
}
@ -110,38 +114,41 @@ func (c *service) trySyncBundles(ctx context.Context) error {
// This is only persisted in the databroker after all records are successfully synced.
// That allows us to ignore any changes based on the same bundle state, without need to re-check all records between bundle and databroker.
func (c *service) syncBundle(ctx context.Context, key string) error {
var cached, changed BundleCacheEntry
opts := []DownloadOption{
WithUpdateCacheEntry(&changed),
}
err := c.GetBundleCacheEntry(ctx, key, &cached)
if err == nil {
opts = append(opts, WithCacheEntry(cached))
} else if err != nil && !errors.Is(err, ErrBundleCacheEntryNotFound) {
cached, err := c.GetBundleCacheEntry(ctx, key)
if err != nil && !errors.Is(err, ErrBundleCacheEntryNotFound) {
return fmt.Errorf("get bundle cache entry: %w", err)
}
// download is much faster compared to databroker sync,
// so we don't use pipe but rather download to a temp file and then sync it to databroker
fd, err := os.CreateTemp(c.config.tmpDir, fmt.Sprintf("pomerium-bundle-%s", key))
fd, err := c.GetTmpFile(key)
if err != nil {
return fmt.Errorf("create temp file: %w", err)
return fmt.Errorf("get tmp file: %w", err)
}
defer fd.Close()
defer os.Remove(fd.Name())
defer func() {
if err := fd.Close(); err != nil {
log.Ctx(ctx).Error().Err(err).Msg("close tmp file")
}
}()
err = c.DownloadBundleIfChanged(ctx, fd, key, opts...)
conditional := cached.GetDownloadConditional()
log.Ctx(ctx).Info().Str("id", key).Any("conditional", conditional).Msg("downloading bundle")
result, err := c.config.api.DownloadClusterResourceBundle(ctx, fd, key, conditional)
if err != nil {
return fmt.Errorf("download bundle: %w", err)
}
if changed.Equals(cached) {
if result.NotModified {
log.Ctx(ctx).Info().Str("bundle", key).Msg("bundle not changed")
return nil
}
log.Ctx(ctx).Info().Str("bundle", key).
Interface("cached-entry", cached).
Interface("current-entry", result.DownloadConditional).
Msg("bundle changed")
_, err = fd.Seek(0, io.SeekStart)
if err != nil {
return fmt.Errorf("seek to start: %w", err)
@ -151,16 +158,19 @@ func (c *service) syncBundle(ctx context.Context, key string) error {
if err != nil {
return fmt.Errorf("apply bundle to databroker: %w", err)
}
changed.RecordTypes = bundleRecordTypes
current := BundleCacheEntry{
DownloadConditional: *result.DownloadConditional,
RecordTypes: bundleRecordTypes,
}
log.Ctx(ctx).Info().
Str("bundle", key).
Strs("record_types", bundleRecordTypes).
Str("etag", changed.ETag).
Time("last_modified", changed.LastModified).
Str("etag", current.ETag).
Str("last_modified", current.LastModified).
Msg("bundle synced")
err = c.SetBundleCacheEntry(ctx, key, changed)
err = c.SetBundleCacheEntry(ctx, key, current)
if err != nil {
return fmt.Errorf("set bundle cache entry: %w", err)
}
@ -198,3 +208,18 @@ func (c *service) syncBundleToDatabroker(ctx context.Context, src io.Reader) ([]
return bundleRecords.RecordTypes(), nil
}
func (c *service) refreshBundleList(ctx context.Context) error {
resp, err := c.config.api.GetClusterResourceBundles(ctx)
if err != nil {
return fmt.Errorf("get bundles: %w", err)
}
ids := make([]string, 0, len(resp.Bundles))
for _, v := range resp.Bundles {
ids = append(ids, v.Id)
}
c.bundles.Set(ids)
return nil
}

View file

@ -0,0 +1,40 @@
package reconciler
import (
"fmt"
"io"
"os"
"github.com/hashicorp/go-multierror"
)
// ReadWriteSeekCloser is a file that can be read, written, seeked, and closed.
type ReadWriteSeekCloser interface {
io.ReadWriteSeeker
io.Closer
}
// GetTmpFile returns a temporary file for the reconciler to use.
// TODO: encrypt contents to ensure encryption at rest
func (c *service) GetTmpFile(key string) (ReadWriteSeekCloser, error) {
fd, err := os.CreateTemp(c.config.tmpDir, fmt.Sprintf("pomerium-bundle-%s", key))
if err != nil {
return nil, fmt.Errorf("create temp file: %w", err)
}
return &tmpFile{File: fd}, nil
}
type tmpFile struct {
*os.File
}
func (f *tmpFile) Close() error {
var errs *multierror.Error
if err := f.File.Close(); err != nil {
errs = multierror.Append(errs, err)
}
if err := os.Remove(f.File.Name()); err != nil {
errs = multierror.Append(errs, err)
}
return errs.ErrorOrNil()
}