mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-06 04:42:56 +02:00
zero: only leave public packages in pkg/zero (#4854)
This commit is contained in:
parent
a6ae9d3f2d
commit
b66634d1e6
24 changed files with 22 additions and 22 deletions
123
internal/zero/api/api.go
Normal file
123
internal/zero/api/api.go
Normal file
|
@ -0,0 +1,123 @@
|
|||
// Package zero contains the pomerium zero configuration API client
|
||||
package zero
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/zero/apierror"
|
||||
connect_mux "github.com/pomerium/pomerium/internal/zero/connect-mux"
|
||||
token_api "github.com/pomerium/pomerium/internal/zero/token"
|
||||
"github.com/pomerium/pomerium/pkg/fanout"
|
||||
cluster_api "github.com/pomerium/pomerium/pkg/zero/cluster"
|
||||
connect_api "github.com/pomerium/pomerium/pkg/zero/connect"
|
||||
)
|
||||
|
||||
// API is a Pomerium Zero Cluster API client
|
||||
type API struct {
|
||||
cfg *config
|
||||
cluster cluster_api.ClientWithResponsesInterface
|
||||
mux *connect_mux.Mux
|
||||
downloadURLCache *cluster_api.URLCache
|
||||
}
|
||||
|
||||
// WatchOption defines which events to watch for
|
||||
type WatchOption = connect_mux.WatchOption
|
||||
|
||||
// NewAPI creates a new API client
|
||||
func NewAPI(ctx context.Context, opts ...Option) (*API, error) {
|
||||
cfg, err := newConfig(opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fetcher, err := cluster_api.NewTokenFetcher(cfg.clusterAPIEndpoint,
|
||||
cluster_api.WithHTTPClient(cfg.httpClient),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating token fetcher: %w", err)
|
||||
}
|
||||
|
||||
tokenCache := token_api.NewCache(fetcher, cfg.apiToken)
|
||||
|
||||
clusterClient, err := cluster_api.NewAuthorizedClient(cfg.clusterAPIEndpoint, tokenCache.GetToken, cfg.httpClient)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating cluster client: %w", err)
|
||||
}
|
||||
|
||||
connectClient, err := connect_api.NewAuthorizedConnectClient(ctx, cfg.connectAPIEndpoint, tokenCache.GetToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating connect client: %w", err)
|
||||
}
|
||||
|
||||
return &API{
|
||||
cfg: cfg,
|
||||
cluster: clusterClient,
|
||||
mux: connect_mux.New(connectClient),
|
||||
downloadURLCache: cluster_api.NewURLCache(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Connect connects to the connect API and allows watching for changes
|
||||
func (api *API) Connect(ctx context.Context, opts ...fanout.Option) error {
|
||||
return api.mux.Run(ctx, opts...)
|
||||
}
|
||||
|
||||
// Watch dispatches API updates
|
||||
func (api *API) Watch(ctx context.Context, opts ...WatchOption) error {
|
||||
return api.mux.Watch(ctx, opts...)
|
||||
}
|
||||
|
||||
// GetClusterBootstrapConfig fetches the bootstrap configuration from the cluster API
|
||||
func (api *API) GetClusterBootstrapConfig(ctx context.Context) (*cluster_api.BootstrapConfig, error) {
|
||||
return apierror.CheckResponse[cluster_api.BootstrapConfig](
|
||||
api.cluster.GetClusterBootstrapConfigWithResponse(ctx),
|
||||
)
|
||||
}
|
||||
|
||||
// GetClusterResourceBundles fetches the resource bundles from the cluster API
|
||||
func (api *API) GetClusterResourceBundles(ctx context.Context) (*cluster_api.GetBundlesResponse, error) {
|
||||
return apierror.CheckResponse[cluster_api.GetBundlesResponse](
|
||||
api.cluster.GetClusterResourceBundlesWithResponse(ctx),
|
||||
)
|
||||
}
|
||||
|
||||
// ReportBundleAppliedSuccess reports a successful bundle application
|
||||
func (api *API) ReportBundleAppliedSuccess(ctx context.Context, bundleID string, metadata map[string]string) error {
|
||||
status := cluster_api.BundleStatus{
|
||||
Success: &cluster_api.BundleStatusSuccess{
|
||||
Metadata: metadata,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := apierror.CheckResponse[cluster_api.EmptyResponse](
|
||||
api.cluster.ReportClusterResourceBundleStatusWithResponse(ctx, bundleID, status),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reporting bundle status: %w", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// ReportBundleAppliedFailure reports a failed bundle application
|
||||
func (api *API) ReportBundleAppliedFailure(
|
||||
ctx context.Context,
|
||||
bundleID string,
|
||||
source cluster_api.BundleStatusFailureSource,
|
||||
err error,
|
||||
) error {
|
||||
status := cluster_api.BundleStatus{
|
||||
Failure: &cluster_api.BundleStatusFailure{
|
||||
Message: err.Error(),
|
||||
Source: source,
|
||||
},
|
||||
}
|
||||
|
||||
_, err = apierror.CheckResponse[cluster_api.EmptyResponse](
|
||||
api.cluster.ReportClusterResourceBundleStatusWithResponse(ctx, bundleID, status),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reporting bundle status: %w", err)
|
||||
}
|
||||
return err
|
||||
}
|
87
internal/zero/api/config.go
Normal file
87
internal/zero/api/config.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
package zero
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Option is a functional option for the SDK
|
||||
type Option func(*config)
|
||||
|
||||
type config struct {
|
||||
clusterAPIEndpoint string
|
||||
connectAPIEndpoint string
|
||||
apiToken string
|
||||
httpClient *http.Client
|
||||
downloadURLCacheTTL time.Duration
|
||||
}
|
||||
|
||||
// WithClusterAPIEndpoint sets the cluster API endpoint
|
||||
func WithClusterAPIEndpoint(endpoint string) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.clusterAPIEndpoint = endpoint
|
||||
}
|
||||
}
|
||||
|
||||
// WithConnectAPIEndpoint sets the connect API endpoint
|
||||
func WithConnectAPIEndpoint(endpoint string) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.connectAPIEndpoint = endpoint
|
||||
}
|
||||
}
|
||||
|
||||
// WithAPIToken sets the API token
|
||||
func WithAPIToken(token string) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.apiToken = token
|
||||
}
|
||||
}
|
||||
|
||||
// WithHTTPClient sets the HTTP client
|
||||
func WithHTTPClient(client *http.Client) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.httpClient = client
|
||||
}
|
||||
}
|
||||
|
||||
// WithDownloadURLCacheTTL sets the minimum TTL for download URL cache entries
|
||||
func WithDownloadURLCacheTTL(ttl time.Duration) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.downloadURLCacheTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(opts ...Option) (*config, error) {
|
||||
cfg := new(config)
|
||||
for _, opt := range []Option{
|
||||
WithHTTPClient(http.DefaultClient),
|
||||
WithDownloadURLCacheTTL(15 * time.Minute),
|
||||
} {
|
||||
opt(cfg)
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
if err := cfg.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (c *config) validate() error {
|
||||
if c.clusterAPIEndpoint == "" {
|
||||
return fmt.Errorf("cluster API endpoint is required")
|
||||
}
|
||||
if c.connectAPIEndpoint == "" {
|
||||
return fmt.Errorf("connect API endpoint is required")
|
||||
}
|
||||
if c.apiToken == "" {
|
||||
return fmt.Errorf("API token is required")
|
||||
}
|
||||
if c.httpClient == nil {
|
||||
return fmt.Errorf("HTTP client is required")
|
||||
}
|
||||
return nil
|
||||
}
|
251
internal/zero/api/download.go
Normal file
251
internal/zero/api/download.go
Normal file
|
@ -0,0 +1,251 @@
|
|||
package zero
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/zero/apierror"
|
||||
cluster_api "github.com/pomerium/pomerium/pkg/zero/cluster"
|
||||
)
|
||||
|
||||
const (
|
||||
maxErrorResponseBodySize = 2 << 14 // 32kb
|
||||
maxUncompressedBlobSize = 2 << 30 // 1gb
|
||||
)
|
||||
|
||||
// DownloadClusterResourceBundle downloads given cluster resource bundle to given writer.
|
||||
func (api *API) DownloadClusterResourceBundle(
|
||||
ctx context.Context,
|
||||
dst io.Writer,
|
||||
id string,
|
||||
current *DownloadConditional,
|
||||
) (*DownloadResult, error) {
|
||||
req, err := api.getDownloadRequest(ctx, id, current)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get download request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := api.cfg.httpClient.Do(req.Request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotModified {
|
||||
return &DownloadResult{NotModified: true}, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, httpDownloadError(ctx, resp)
|
||||
}
|
||||
|
||||
var r io.Reader = resp.Body
|
||||
if resp.Header.Get("Content-Encoding") == "gzip" {
|
||||
zr, err := gzip.NewReader(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gzip reader: %w", err)
|
||||
}
|
||||
defer zr.Close()
|
||||
|
||||
r = io.LimitReader(zr, maxUncompressedBlobSize)
|
||||
}
|
||||
|
||||
_, err = io.Copy(dst, r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("write body: %w", err)
|
||||
}
|
||||
|
||||
updated, err := newConditionalFromResponse(resp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot obtain cache conditions from response: %w", err)
|
||||
}
|
||||
|
||||
return &DownloadResult{
|
||||
DownloadConditional: updated,
|
||||
Metadata: extractMetadata(resp.Header, req.CaptureHeaders),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type downloadRequest struct {
|
||||
*http.Request
|
||||
cluster_api.DownloadCacheEntry
|
||||
}
|
||||
|
||||
func (api *API) getDownloadRequest(ctx context.Context, id string, current *DownloadConditional) (*downloadRequest, error) {
|
||||
params, err := api.getDownloadParams(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get download URL: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, params.URL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
|
||||
err = current.SetHeaders(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("set conditional download headers: %w", err)
|
||||
}
|
||||
|
||||
return &downloadRequest{
|
||||
Request: req,
|
||||
DownloadCacheEntry: *params,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (api *API) getDownloadParams(ctx context.Context, id string) (*cluster_api.DownloadCacheEntry, error) {
|
||||
param, ok := api.downloadURLCache.Get(id, api.cfg.downloadURLCacheTTL)
|
||||
if ok {
|
||||
return param, nil
|
||||
}
|
||||
|
||||
return api.updateBundleDownloadParams(ctx, id)
|
||||
}
|
||||
|
||||
func (api *API) updateBundleDownloadParams(ctx context.Context, id string) (*cluster_api.DownloadCacheEntry, error) {
|
||||
now := time.Now()
|
||||
|
||||
resp, err := apierror.CheckResponse[cluster_api.DownloadBundleResponse](
|
||||
api.cluster.DownloadClusterResourceBundleWithResponse(ctx, id),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get bundle download URL: %w", err)
|
||||
}
|
||||
|
||||
expiresSeconds, err := strconv.ParseInt(resp.ExpiresInSeconds, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse expiration: %w", err)
|
||||
}
|
||||
|
||||
u, err := url.Parse(resp.Url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse url: %w", err)
|
||||
}
|
||||
|
||||
param := cluster_api.DownloadCacheEntry{
|
||||
URL: *u,
|
||||
ExpiresAt: now.Add(time.Duration(expiresSeconds) * time.Second),
|
||||
CaptureHeaders: resp.CaptureMetadataHeaders,
|
||||
}
|
||||
api.downloadURLCache.Set(id, param)
|
||||
return ¶m, nil
|
||||
}
|
||||
|
||||
// DownloadResult contains the result of a download operation
|
||||
type DownloadResult struct {
|
||||
// NotModified is true if the bundle has not been modified
|
||||
NotModified bool
|
||||
// DownloadConditional contains the new conditional
|
||||
*DownloadConditional
|
||||
// Metadata contains the metadata of the downloaded bundle
|
||||
Metadata map[string]string
|
||||
}
|
||||
|
||||
// DownloadConditional contains the conditional headers for a download operation
|
||||
type DownloadConditional struct {
|
||||
ETag string
|
||||
LastModified string
|
||||
}
|
||||
|
||||
// Validate validates the conditional headers
|
||||
func (c *DownloadConditional) Validate() error {
|
||||
if c.ETag == "" && c.LastModified == "" {
|
||||
return fmt.Errorf("either ETag or LastModified must be set")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetHeaders sets the conditional headers on the given request
|
||||
func (c *DownloadConditional) SetHeaders(req *http.Request) error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if err := c.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("If-None-Match", c.ETag)
|
||||
req.Header.Set("If-Modified-Since", c.LastModified)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newConditionalFromResponse(resp *http.Response) (*DownloadConditional, error) {
|
||||
c := &DownloadConditional{
|
||||
ETag: resp.Header.Get("ETag"),
|
||||
LastModified: resp.Header.Get("Last-Modified"),
|
||||
}
|
||||
if err := c.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
type xmlError struct {
|
||||
XMLName xml.Name `xml:"Error"`
|
||||
Code string `xml:"Code"`
|
||||
Message string `xml:"Message"`
|
||||
Details string `xml:"Details"`
|
||||
}
|
||||
|
||||
func (e xmlError) Error() string {
|
||||
return fmt.Sprintf("%s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
func tryXMLError(body []byte) (bool, error) {
|
||||
var xmlErr xmlError
|
||||
err := xml.Unmarshal(body, &xmlErr)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("unmarshal xml error: %w", err)
|
||||
}
|
||||
|
||||
return true, xmlErr
|
||||
}
|
||||
|
||||
func httpDownloadError(ctx context.Context, resp *http.Response) error {
|
||||
var buf bytes.Buffer
|
||||
_, err := io.Copy(&buf, io.LimitReader(resp.Body, maxErrorResponseBodySize))
|
||||
|
||||
if isXML(resp.Header.Get("Content-Type")) {
|
||||
ok, err := tryXMLError(buf.Bytes())
|
||||
if ok {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
log.Ctx(ctx).Debug().Err(err).
|
||||
Str("error", resp.Status).
|
||||
Str("body", buf.String()).Msg("bundle download error")
|
||||
|
||||
return fmt.Errorf("download error: %s", resp.Status)
|
||||
}
|
||||
|
||||
// isXML parses content-type for application/xml
|
||||
func isXML(ct string) bool {
|
||||
mediaType, _, err := mime.ParseMediaType(ct)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return mediaType == "application/xml"
|
||||
}
|
||||
|
||||
func extractMetadata(header http.Header, keys []string) map[string]string {
|
||||
m := make(map[string]string)
|
||||
for _, k := range keys {
|
||||
v := header.Get(k)
|
||||
if v != "" {
|
||||
m[k] = v
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue