mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-26 21:19:31 +02:00
use centralied api
This commit is contained in:
parent
4e5993488e
commit
01c17844df
10 changed files with 259 additions and 217 deletions
2
go.mod
2
go.mod
|
@ -52,7 +52,7 @@ require (
|
||||||
github.com/pomerium/csrf v1.7.0
|
github.com/pomerium/csrf v1.7.0
|
||||||
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524
|
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524
|
||||||
github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b
|
github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b
|
||||||
github.com/pomerium/zero-sdk v0.0.0-20230813022804-3bf1f871ab31
|
github.com/pomerium/zero-sdk v0.0.0-20230816000855-af1b8165df05
|
||||||
github.com/prometheus/client_golang v1.16.0
|
github.com/prometheus/client_golang v1.16.0
|
||||||
github.com/prometheus/client_model v0.4.0
|
github.com/prometheus/client_model v0.4.0
|
||||||
github.com/prometheus/common v0.44.0
|
github.com/prometheus/common v0.44.0
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -657,8 +657,8 @@ github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524 h1:3YQY1sb5
|
||||||
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524/go.mod h1:7fGbUYJnU8RcxZJvUvhukOIBv1G7LWDAHMfDxAf5+Y0=
|
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524/go.mod h1:7fGbUYJnU8RcxZJvUvhukOIBv1G7LWDAHMfDxAf5+Y0=
|
||||||
github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b h1:oll/aOfJudnqFAwCvoXK9+WN2zVjTzHVPLXCggHQmHk=
|
github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b h1:oll/aOfJudnqFAwCvoXK9+WN2zVjTzHVPLXCggHQmHk=
|
||||||
github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b/go.mod h1:KswTenBBh4y1pmhU2dpm8VgJQCgSErCg7OOFTeebrNc=
|
github.com/pomerium/webauthn v0.0.0-20221118023040-00a9c430578b/go.mod h1:KswTenBBh4y1pmhU2dpm8VgJQCgSErCg7OOFTeebrNc=
|
||||||
github.com/pomerium/zero-sdk v0.0.0-20230813022804-3bf1f871ab31 h1:FUoy3dpWd0ECmFHmYVUCJd7lVhQjglZsnrfjvRsVj4w=
|
github.com/pomerium/zero-sdk v0.0.0-20230816000855-af1b8165df05 h1:Rl2df8q+DAd3SsJn9MpXrbo7JRNCDHVaohOyUZ2IJik=
|
||||||
github.com/pomerium/zero-sdk v0.0.0-20230813022804-3bf1f871ab31/go.mod h1:cAyfEGM8blUzchYhOWrufuj/6lOF277meB4c/TjMS28=
|
github.com/pomerium/zero-sdk v0.0.0-20230816000855-af1b8165df05/go.mod h1:cAyfEGM8blUzchYhOWrufuj/6lOF277meB4c/TjMS28=
|
||||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||||
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
|
||||||
|
|
|
@ -1,22 +0,0 @@
|
||||||
package reconciler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GetBundles returns the list of bundles that have to be present in the cluster.
|
|
||||||
func (c *service) RefreshBundleList(ctx context.Context) error {
|
|
||||||
resp, err := c.config.api.GetClusterResourceBundles(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get bundles: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ids := make([]string, 0, len(resp.Bundles))
|
|
||||||
for _, v := range resp.Bundles {
|
|
||||||
ids = append(ids, v.Id)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.bundles.Set(ids)
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -8,21 +8,17 @@ import (
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
sdk "github.com/pomerium/zero-sdk"
|
sdk "github.com/pomerium/zero-sdk"
|
||||||
connect_mux "github.com/pomerium/zero-sdk/connect-mux"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// reconcilerConfig contains the configuration for the resource bundles reconciler.
|
// reconcilerConfig contains the configuration for the resource bundles reconciler.
|
||||||
type reconcilerConfig struct {
|
type reconcilerConfig struct {
|
||||||
api *sdk.API
|
api *sdk.API
|
||||||
connectMux *connect_mux.Mux
|
|
||||||
|
|
||||||
databrokerClient databroker.DataBrokerServiceClient
|
databrokerClient databroker.DataBrokerServiceClient
|
||||||
databrokerRPS int
|
databrokerRPS int
|
||||||
|
|
||||||
tmpDir string
|
tmpDir string
|
||||||
|
|
||||||
minDownloadTTL time.Duration
|
|
||||||
|
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
|
|
||||||
checkForUpdateIntervalWhenDisconnected time.Duration
|
checkForUpdateIntervalWhenDisconnected time.Duration
|
||||||
|
@ -49,13 +45,6 @@ func WithAPI(client *sdk.API) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithConnectMux configures the connect mux.
|
|
||||||
func WithConnectMux(client *connect_mux.Mux) Option {
|
|
||||||
return func(cfg *reconcilerConfig) {
|
|
||||||
cfg.connectMux = client
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithDataBrokerClient configures the databroker client.
|
// WithDataBrokerClient configures the databroker client.
|
||||||
func WithDataBrokerClient(client databroker.DataBrokerServiceClient) Option {
|
func WithDataBrokerClient(client databroker.DataBrokerServiceClient) Option {
|
||||||
return func(cfg *reconcilerConfig) {
|
return func(cfg *reconcilerConfig) {
|
||||||
|
@ -63,14 +52,6 @@ func WithDataBrokerClient(client databroker.DataBrokerServiceClient) Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithMinDownloadTTL configures the download URL validity in cache,
|
|
||||||
// before it would be requested again from the cloud (that also sets it's own TTL to the signed URL).
|
|
||||||
func WithMinDownloadTTL(ttl time.Duration) Option {
|
|
||||||
return func(cfg *reconcilerConfig) {
|
|
||||||
cfg.minDownloadTTL = ttl
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithDownloadHTTPClient configures the http client used for downloading files.
|
// WithDownloadHTTPClient configures the http client used for downloading files.
|
||||||
func WithDownloadHTTPClient(client *http.Client) Option {
|
func WithDownloadHTTPClient(client *http.Client) Option {
|
||||||
return func(cfg *reconcilerConfig) {
|
return func(cfg *reconcilerConfig) {
|
||||||
|
@ -112,7 +93,6 @@ func newConfig(opts ...Option) *reconcilerConfig {
|
||||||
cfg := &reconcilerConfig{}
|
cfg := &reconcilerConfig{}
|
||||||
for _, opt := range []Option{
|
for _, opt := range []Option{
|
||||||
WithTemporaryDirectory(os.TempDir()),
|
WithTemporaryDirectory(os.TempDir()),
|
||||||
WithMinDownloadTTL(5 * time.Minute),
|
|
||||||
WithDownloadHTTPClient(http.DefaultClient),
|
WithDownloadHTTPClient(http.DefaultClient),
|
||||||
WithDatabrokerRPSLimit(1_000),
|
WithDatabrokerRPSLimit(1_000),
|
||||||
WithCheckForUpdateIntervalWhenDisconnected(time.Minute * 5),
|
WithCheckForUpdateIntervalWhenDisconnected(time.Minute * 5),
|
||||||
|
|
|
@ -1,132 +0,0 @@
|
||||||
package reconciler
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type downloadOptions struct {
|
|
||||||
requestCacheEntry *BundleCacheEntry
|
|
||||||
responseCacheEntry *BundleCacheEntry
|
|
||||||
}
|
|
||||||
|
|
||||||
// DownloadOption is an option for downloading a bundle
|
|
||||||
type DownloadOption func(*downloadOptions)
|
|
||||||
|
|
||||||
// WithCacheEntry sets the cache entry to use for the request.
|
|
||||||
func WithCacheEntry(entry BundleCacheEntry) DownloadOption {
|
|
||||||
return func(opts *downloadOptions) {
|
|
||||||
opts.requestCacheEntry = &entry
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithUpdateCacheEntry updates the cache entry with the values from the response.
|
|
||||||
func WithUpdateCacheEntry(dst *BundleCacheEntry) DownloadOption {
|
|
||||||
return func(opts *downloadOptions) {
|
|
||||||
opts.responseCacheEntry = dst
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getDownloadOptions(opts ...DownloadOption) downloadOptions {
|
|
||||||
var options downloadOptions
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(&options)
|
|
||||||
}
|
|
||||||
return options
|
|
||||||
}
|
|
||||||
|
|
||||||
func (opt *downloadOptions) updateRequest(req *http.Request) {
|
|
||||||
if opt.requestCacheEntry != nil {
|
|
||||||
req.Header.Set("If-None-Match", opt.requestCacheEntry.ETag)
|
|
||||||
req.Header.Set("If-Modified-Since", opt.requestCacheEntry.LastModified.Format(http.TimeFormat))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (opt *downloadOptions) updateFromResponse(resp *http.Response) error {
|
|
||||||
if opt.responseCacheEntry == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusNotModified && opt.requestCacheEntry != nil {
|
|
||||||
*opt.responseCacheEntry = *opt.requestCacheEntry
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return updateBundleCacheEntryFromResponse(opt.responseCacheEntry, resp.Header)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DownloadBundleIfChanged downloads the bundle if it has changed.
|
|
||||||
func (c *service) DownloadBundleIfChanged(
|
|
||||||
ctx context.Context,
|
|
||||||
dst io.Writer,
|
|
||||||
bundleKey string,
|
|
||||||
opts ...DownloadOption,
|
|
||||||
) error {
|
|
||||||
options := getDownloadOptions(opts...)
|
|
||||||
|
|
||||||
url, err := c.config.api.DownloadClusterResourceBundle(ctx, bundleKey, c.config.minDownloadTTL)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get download url: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("new request: %w", err)
|
|
||||||
}
|
|
||||||
options.updateRequest(req)
|
|
||||||
|
|
||||||
resp, err := c.config.httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("do request: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
err = options.updateFromResponse(resp)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusNotModified {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return fmt.Errorf("unexpected response: %d/%s", resp.StatusCode, resp.Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = io.Copy(dst, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("write file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateBundleCacheEntryFromResponse(dst *BundleCacheEntry, headers http.Header) error {
|
|
||||||
txt := headers.Get("Last-Modified")
|
|
||||||
if txt == "" {
|
|
||||||
return fmt.Errorf("missing last-modified header")
|
|
||||||
}
|
|
||||||
|
|
||||||
lastModified, err := time.Parse(http.TimeFormat, txt)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse last modified: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
etag := headers.Get("ETag")
|
|
||||||
if etag == "" {
|
|
||||||
return fmt.Errorf("missing etag header")
|
|
||||||
}
|
|
||||||
|
|
||||||
dst.LastModified = lastModified
|
|
||||||
dst.ETag = etag
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -3,7 +3,19 @@ package reconciler
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"time"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
zero_sdk "github.com/pomerium/zero-sdk"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BundleCacheEntry is a cache entry for a bundle
|
// BundleCacheEntry is a cache entry for a bundle
|
||||||
|
@ -16,29 +28,139 @@ import (
|
||||||
// also it works in case of multiple instances, as it uses
|
// also it works in case of multiple instances, as it uses
|
||||||
// the databroker database as a shared cache.
|
// the databroker database as a shared cache.
|
||||||
type BundleCacheEntry struct {
|
type BundleCacheEntry struct {
|
||||||
ETag string
|
zero_sdk.DownloadConditional
|
||||||
LastModified time.Time
|
RecordTypes []string
|
||||||
RecordTypes []string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
bundleCacheEntryRecordType = "pomerium.io/BundleCacheEntry"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrBundleCacheEntryNotFound is returned when a bundle cache entry is not found
|
// ErrBundleCacheEntryNotFound is returned when a bundle cache entry is not found
|
||||||
ErrBundleCacheEntryNotFound = errors.New("bundle cache entry not found")
|
ErrBundleCacheEntryNotFound = errors.New("bundle cache entry not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
// Equals returns true if the two cache entries are equal
|
|
||||||
func (c *BundleCacheEntry) Equals(other BundleCacheEntry) bool {
|
|
||||||
return c.ETag == other.ETag && c.LastModified.Equal(other.LastModified)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetBundleCacheEntry gets a bundle cache entry from the databroker
|
// GetBundleCacheEntry gets a bundle cache entry from the databroker
|
||||||
func (c *service) GetBundleCacheEntry(_ context.Context, _ string, _ *BundleCacheEntry) error {
|
func (c *service) GetBundleCacheEntry(ctx context.Context, id string) (*BundleCacheEntry, error) {
|
||||||
// TODO: implement
|
record, err := c.config.databrokerClient.Get(ctx, &databroker.GetRequest{
|
||||||
return ErrBundleCacheEntryNotFound
|
Type: bundleCacheEntryRecordType,
|
||||||
|
Id: id,
|
||||||
|
})
|
||||||
|
if err != nil && status.Code(err) == codes.NotFound {
|
||||||
|
return nil, ErrBundleCacheEntryNotFound
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, fmt.Errorf("get bundle cache entry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dst BundleCacheEntry
|
||||||
|
data := record.GetRecord().GetData()
|
||||||
|
err = dst.FromAny(data)
|
||||||
|
if err != nil {
|
||||||
|
log.Ctx(ctx).Error().Err(err).
|
||||||
|
Str("bundle-id", id).
|
||||||
|
Str("data", protojson.Format(data)).
|
||||||
|
Msg("unmarshal bundle cache entry")
|
||||||
|
// we would allow it to be overwritten by the update process
|
||||||
|
return nil, ErrBundleCacheEntryNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return &dst, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetBundleCacheEntry sets a bundle cache entry in the databroker
|
// SetBundleCacheEntry sets a bundle cache entry in the databroker
|
||||||
func (c *service) SetBundleCacheEntry(_ context.Context, _ string, _ BundleCacheEntry) error {
|
func (c *service) SetBundleCacheEntry(ctx context.Context, id string, src BundleCacheEntry) error {
|
||||||
// TODO: implement
|
val, err := src.ToAny()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("marshal bundle cache entry: %w", err)
|
||||||
|
}
|
||||||
|
resp, err := c.config.databrokerClient.Put(ctx, &databroker.PutRequest{
|
||||||
|
Records: []*databroker.Record{
|
||||||
|
{
|
||||||
|
Type: bundleCacheEntryRecordType,
|
||||||
|
Id: id,
|
||||||
|
Data: val,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("set bundle cache entry: %w", err)
|
||||||
|
}
|
||||||
|
log.Ctx(ctx).Info().
|
||||||
|
Str("bundle-id", id).
|
||||||
|
Str("sent", protojson.Format(val)).
|
||||||
|
Str("got", protojson.Format(resp.GetRecord().GetData())).
|
||||||
|
Msg("set bundle cache entry")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToAny marshals a BundleCacheEntry into an anypb.Any
|
||||||
|
func (r *BundleCacheEntry) ToAny() (*anypb.Any, error) {
|
||||||
|
err := r.Validate()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("validate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
types := make([]*structpb.Value, 0, len(r.RecordTypes))
|
||||||
|
for _, t := range r.RecordTypes {
|
||||||
|
types = append(types, structpb.NewStringValue(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
return protoutil.NewAny(&structpb.Struct{
|
||||||
|
Fields: map[string]*structpb.Value{
|
||||||
|
"etag": structpb.NewStringValue(r.ETag),
|
||||||
|
"last_modified": structpb.NewStringValue(r.LastModified),
|
||||||
|
"record_types": structpb.NewListValue(&structpb.ListValue{Values: types}),
|
||||||
|
},
|
||||||
|
}), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromAny unmarshals an anypb.Any into a BundleCacheEntry
|
||||||
|
func (r *BundleCacheEntry) FromAny(any *anypb.Any) error {
|
||||||
|
var s structpb.Struct
|
||||||
|
err := any.UnmarshalTo(&s)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unmarshal struct: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Println("****", protojson.Format(any), protojson.Format(&s))
|
||||||
|
|
||||||
|
r.ETag = s.GetFields()["etag"].GetStringValue()
|
||||||
|
r.LastModified = s.GetFields()["last_modified"].GetStringValue()
|
||||||
|
|
||||||
|
for _, v := range s.GetFields()["record_types"].GetListValue().GetValues() {
|
||||||
|
r.RecordTypes = append(r.RecordTypes, v.GetStringValue())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = r.Validate()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("validate: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates a BundleCacheEntry
|
||||||
|
func (r *BundleCacheEntry) Validate() error {
|
||||||
|
var errs *multierror.Error
|
||||||
|
if len(r.RecordTypes) == 0 {
|
||||||
|
errs = multierror.Append(errs, errors.New("record_types is required"))
|
||||||
|
}
|
||||||
|
if err := r.DownloadConditional.Validate(); err != nil {
|
||||||
|
errs = multierror.Append(errs, err)
|
||||||
|
}
|
||||||
|
return errs.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDownloadConditional returns conditional download information
|
||||||
|
func (r *BundleCacheEntry) GetDownloadConditional() *zero_sdk.DownloadConditional {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cond := r.DownloadConditional
|
||||||
|
return &cond
|
||||||
|
}
|
||||||
|
|
||||||
|
// Equals returns true if the two cache entries are equal
|
||||||
|
func (r *BundleCacheEntry) Equals(other *BundleCacheEntry) bool {
|
||||||
|
return r != nil && other != nil &&
|
||||||
|
r.ETag == other.ETag && r.LastModified == other.LastModified
|
||||||
|
}
|
||||||
|
|
29
internal/zero/reconciler/download_cache_test.go
Normal file
29
internal/zero/reconciler/download_cache_test.go
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
package reconciler_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/zero/reconciler"
|
||||||
|
zero_sdk "github.com/pomerium/zero-sdk"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCacheEntryProto(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
original := reconciler.BundleCacheEntry{
|
||||||
|
DownloadConditional: zero_sdk.DownloadConditional{
|
||||||
|
ETag: "etag value",
|
||||||
|
LastModified: "2009-02-13 18:31:30 -0500 EST",
|
||||||
|
},
|
||||||
|
RecordTypes: []string{"one", "two"},
|
||||||
|
}
|
||||||
|
originalProto, err := original.ToAny()
|
||||||
|
require.NoError(t, err)
|
||||||
|
var unmarshaled reconciler.BundleCacheEntry
|
||||||
|
err = unmarshaled.FromAny(originalProto)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, original.Equals(&unmarshaled))
|
||||||
|
}
|
|
@ -52,7 +52,7 @@ func Run(ctx context.Context, opts ...Option) error {
|
||||||
// it may be later optimized by splitting between download and reconciliation process,
|
// it may be later optimized by splitting between download and reconciliation process,
|
||||||
// as we would get more resource bundles beyond the config.
|
// as we would get more resource bundles beyond the config.
|
||||||
func (c *service) watchUpdates(ctx context.Context) error {
|
func (c *service) watchUpdates(ctx context.Context) error {
|
||||||
return c.config.connectMux.Watch(ctx,
|
return c.config.api.Watch(ctx,
|
||||||
connect_mux.WithOnConnected(func(ctx context.Context) {
|
connect_mux.WithOnConnected(func(ctx context.Context) {
|
||||||
c.triggerFullUpdate(true)
|
c.triggerFullUpdate(true)
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -16,7 +16,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
@ -29,22 +28,27 @@ func (c *service) SyncLoop(ctx context.Context) error {
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
ticker.Reset(c.periodicUpdateInterval.Load())
|
dur := c.periodicUpdateInterval.Load()
|
||||||
|
ticker.Reset(dur)
|
||||||
|
log.Ctx(ctx).Info().Str("duration", dur.String()).Msg("*** next sync cycle ***")
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case <-c.bundleSyncRequest:
|
case <-c.bundleSyncRequest:
|
||||||
|
log.Ctx(ctx).Info().Msg("bundle sync triggered")
|
||||||
err := c.syncBundles(ctx)
|
err := c.syncBundles(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("reconciler: sync bundles: %w", err)
|
return fmt.Errorf("reconciler: sync bundles: %w", err)
|
||||||
}
|
}
|
||||||
case <-c.fullSyncRequest:
|
case <-c.fullSyncRequest:
|
||||||
|
log.Ctx(ctx).Info().Msg("full sync triggered")
|
||||||
err := c.syncAll(ctx)
|
err := c.syncAll(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("reconciler: sync all: %w", err)
|
return fmt.Errorf("reconciler: sync all: %w", err)
|
||||||
}
|
}
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
|
log.Ctx(ctx).Info().Msg("periodic sync triggered")
|
||||||
err := c.syncAll(ctx)
|
err := c.syncAll(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("reconciler: sync all: %w", err)
|
return fmt.Errorf("reconciler: sync all: %w", err)
|
||||||
|
@ -72,7 +76,7 @@ func (c *service) syncBundleList(ctx context.Context) error {
|
||||||
// refresh bundle list,
|
// refresh bundle list,
|
||||||
// ignoring other signals while we're retrying
|
// ignoring other signals while we're retrying
|
||||||
return retry.Retry(ctx,
|
return retry.Retry(ctx,
|
||||||
"refresh bundle list", c.RefreshBundleList,
|
"refresh bundle list", c.refreshBundleList,
|
||||||
retry.WithWatch("refresh bundle list", c.fullSyncRequest, nil),
|
retry.WithWatch("refresh bundle list", c.fullSyncRequest, nil),
|
||||||
retry.WithWatch("bundle update", c.bundleSyncRequest, nil),
|
retry.WithWatch("bundle update", c.bundleSyncRequest, nil),
|
||||||
)
|
)
|
||||||
|
@ -83,7 +87,7 @@ func (c *service) syncBundleList(ctx context.Context) error {
|
||||||
func (c *service) syncBundles(ctx context.Context) error {
|
func (c *service) syncBundles(ctx context.Context) error {
|
||||||
return retry.Retry(ctx,
|
return retry.Retry(ctx,
|
||||||
"sync bundles", c.trySyncBundles,
|
"sync bundles", c.trySyncBundles,
|
||||||
retry.WithWatch("refresh bundle list", c.fullSyncRequest, c.RefreshBundleList),
|
retry.WithWatch("refresh bundle list", c.fullSyncRequest, c.refreshBundleList),
|
||||||
retry.WithWatch("bundle update", c.bundleSyncRequest, nil),
|
retry.WithWatch("bundle update", c.bundleSyncRequest, nil),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -110,38 +114,41 @@ func (c *service) trySyncBundles(ctx context.Context) error {
|
||||||
// This is only persisted in the databroker after all records are successfully synced.
|
// This is only persisted in the databroker after all records are successfully synced.
|
||||||
// That allows us to ignore any changes based on the same bundle state, without need to re-check all records between bundle and databroker.
|
// That allows us to ignore any changes based on the same bundle state, without need to re-check all records between bundle and databroker.
|
||||||
func (c *service) syncBundle(ctx context.Context, key string) error {
|
func (c *service) syncBundle(ctx context.Context, key string) error {
|
||||||
var cached, changed BundleCacheEntry
|
cached, err := c.GetBundleCacheEntry(ctx, key)
|
||||||
opts := []DownloadOption{
|
if err != nil && !errors.Is(err, ErrBundleCacheEntryNotFound) {
|
||||||
WithUpdateCacheEntry(&changed),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.GetBundleCacheEntry(ctx, key, &cached)
|
|
||||||
if err == nil {
|
|
||||||
opts = append(opts, WithCacheEntry(cached))
|
|
||||||
} else if err != nil && !errors.Is(err, ErrBundleCacheEntryNotFound) {
|
|
||||||
return fmt.Errorf("get bundle cache entry: %w", err)
|
return fmt.Errorf("get bundle cache entry: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// download is much faster compared to databroker sync,
|
// download is much faster compared to databroker sync,
|
||||||
// so we don't use pipe but rather download to a temp file and then sync it to databroker
|
// so we don't use pipe but rather download to a temp file and then sync it to databroker
|
||||||
|
fd, err := c.GetTmpFile(key)
|
||||||
fd, err := os.CreateTemp(c.config.tmpDir, fmt.Sprintf("pomerium-bundle-%s", key))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create temp file: %w", err)
|
return fmt.Errorf("get tmp file: %w", err)
|
||||||
}
|
}
|
||||||
defer fd.Close()
|
defer func() {
|
||||||
defer os.Remove(fd.Name())
|
if err := fd.Close(); err != nil {
|
||||||
|
log.Ctx(ctx).Error().Err(err).Msg("close tmp file")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
err = c.DownloadBundleIfChanged(ctx, fd, key, opts...)
|
conditional := cached.GetDownloadConditional()
|
||||||
|
log.Ctx(ctx).Info().Str("id", key).Any("conditional", conditional).Msg("downloading bundle")
|
||||||
|
|
||||||
|
result, err := c.config.api.DownloadClusterResourceBundle(ctx, fd, key, conditional)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("download bundle: %w", err)
|
return fmt.Errorf("download bundle: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if changed.Equals(cached) {
|
if result.NotModified {
|
||||||
log.Ctx(ctx).Info().Str("bundle", key).Msg("bundle not changed")
|
log.Ctx(ctx).Info().Str("bundle", key).Msg("bundle not changed")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Ctx(ctx).Info().Str("bundle", key).
|
||||||
|
Interface("cached-entry", cached).
|
||||||
|
Interface("current-entry", result.DownloadConditional).
|
||||||
|
Msg("bundle changed")
|
||||||
|
|
||||||
_, err = fd.Seek(0, io.SeekStart)
|
_, err = fd.Seek(0, io.SeekStart)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("seek to start: %w", err)
|
return fmt.Errorf("seek to start: %w", err)
|
||||||
|
@ -151,16 +158,19 @@ func (c *service) syncBundle(ctx context.Context, key string) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("apply bundle to databroker: %w", err)
|
return fmt.Errorf("apply bundle to databroker: %w", err)
|
||||||
}
|
}
|
||||||
changed.RecordTypes = bundleRecordTypes
|
current := BundleCacheEntry{
|
||||||
|
DownloadConditional: *result.DownloadConditional,
|
||||||
|
RecordTypes: bundleRecordTypes,
|
||||||
|
}
|
||||||
|
|
||||||
log.Ctx(ctx).Info().
|
log.Ctx(ctx).Info().
|
||||||
Str("bundle", key).
|
Str("bundle", key).
|
||||||
Strs("record_types", bundleRecordTypes).
|
Strs("record_types", bundleRecordTypes).
|
||||||
Str("etag", changed.ETag).
|
Str("etag", current.ETag).
|
||||||
Time("last_modified", changed.LastModified).
|
Str("last_modified", current.LastModified).
|
||||||
Msg("bundle synced")
|
Msg("bundle synced")
|
||||||
|
|
||||||
err = c.SetBundleCacheEntry(ctx, key, changed)
|
err = c.SetBundleCacheEntry(ctx, key, current)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("set bundle cache entry: %w", err)
|
return fmt.Errorf("set bundle cache entry: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -198,3 +208,18 @@ func (c *service) syncBundleToDatabroker(ctx context.Context, src io.Reader) ([]
|
||||||
|
|
||||||
return bundleRecords.RecordTypes(), nil
|
return bundleRecords.RecordTypes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *service) refreshBundleList(ctx context.Context) error {
|
||||||
|
resp, err := c.config.api.GetClusterResourceBundles(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get bundles: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := make([]string, 0, len(resp.Bundles))
|
||||||
|
for _, v := range resp.Bundles {
|
||||||
|
ids = append(ids, v.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.bundles.Set(ids)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
40
internal/zero/reconciler/tmpfile.go
Normal file
40
internal/zero/reconciler/tmpfile.go
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
package reconciler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReadWriteSeekCloser is a file that can be read, written, seeked, and closed.
|
||||||
|
type ReadWriteSeekCloser interface {
|
||||||
|
io.ReadWriteSeeker
|
||||||
|
io.Closer
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTmpFile returns a temporary file for the reconciler to use.
|
||||||
|
// TODO: encrypt contents to ensure encryption at rest
|
||||||
|
func (c *service) GetTmpFile(key string) (ReadWriteSeekCloser, error) {
|
||||||
|
fd, err := os.CreateTemp(c.config.tmpDir, fmt.Sprintf("pomerium-bundle-%s", key))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create temp file: %w", err)
|
||||||
|
}
|
||||||
|
return &tmpFile{File: fd}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type tmpFile struct {
|
||||||
|
*os.File
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *tmpFile) Close() error {
|
||||||
|
var errs *multierror.Error
|
||||||
|
if err := f.File.Close(); err != nil {
|
||||||
|
errs = multierror.Append(errs, err)
|
||||||
|
}
|
||||||
|
if err := os.Remove(f.File.Name()); err != nil {
|
||||||
|
errs = multierror.Append(errs, err)
|
||||||
|
}
|
||||||
|
return errs.ErrorOrNil()
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue