diff --git a/internal/zero/controller/controller.go b/internal/zero/controller/controller.go
index 0275a6624..527644a24 100644
--- a/internal/zero/controller/controller.go
+++ b/internal/zero/controller/controller.go
@@ -15,6 +15,7 @@ import (
 	"github.com/pomerium/pomerium/internal/zero/analytics"
 	sdk "github.com/pomerium/pomerium/internal/zero/api"
 	"github.com/pomerium/pomerium/internal/zero/bootstrap"
+	"github.com/pomerium/pomerium/internal/zero/healthcheck"
 	"github.com/pomerium/pomerium/internal/zero/leaser"
 	"github.com/pomerium/pomerium/internal/zero/reconciler"
 	"github.com/pomerium/pomerium/internal/zero/reporter"
@@ -108,6 +109,7 @@ func (c *controller) runZeroControlLoop(ctx context.Context) error {
 		c.runReconcilerLeased,
 		c.runAnalyticsLeased,
 		c.runMetricsReporterLeased,
+		c.runHealthChecksLeased,
 	)
 }
 
@@ -147,6 +149,15 @@ func (c *controller) runMetricsReporterLeased(ctx context.Context, client databr
 	)
 }
 
+// runHealthChecksLeased runs the periodic health checks with the "zero-health-checks" service name attached to the log context.
+func (c *controller) runHealthChecksLeased(ctx context.Context, client databroker.DataBrokerServiceClient) error {
+	ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
+		return c.Str("service", "zero-health-checks")
+	})
+
+	return healthcheck.RunChecks(ctx, c.bootstrapConfig, client)
+}
+
 func (c *controller) runHealthCheckReporter(ctx context.Context) error {
 	ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
 		return c.Str("service", "zero-health-check-reporter")
diff --git a/internal/zero/healthcheck/check_routes.go b/internal/zero/healthcheck/check_routes.go
new file mode 100644
index 000000000..18d92934e
--- /dev/null
+++ b/internal/zero/healthcheck/check_routes.go
@@ -0,0 +1,155 @@
+package healthcheck
+
+import (
+	"context"
+	"encoding/base64"
+	"errors"
+	"fmt"
+	"net"
+	"net/http"
+	"slices"
+	"time"
+
+	"github.com/go-jose/go-jose/v3"
+	"golang.org/x/exp/maps"
+
+	"github.com/pomerium/pomerium/config"
+	"github.com/pomerium/pomerium/internal/log"
+	"github.com/pomerium/pomerium/internal/urlutil"
+	"github.com/pomerium/pomerium/pkg/cryptutil"
+	configpb "github.com/pomerium/pomerium/pkg/grpc/config"
+	"github.com/pomerium/pomerium/pkg/grpc/databroker"
+	"github.com/pomerium/pomerium/pkg/health"
+	"github.com/pomerium/pomerium/pkg/protoutil"
+	clusterping "github.com/pomerium/pomerium/pkg/zero/ping"
+)
+
+// CheckRoutes checks whether all routes referenced by this Pomerium instance's
+// configuration are reachable. For every route host it requests the Pomerium JWKS
+// endpoint: we should hit ourselves and observe the same public key that we have in
+// our configuration, otherwise something is misconfigured on the DNS level.
+// If the check itself cannot be carried out, a warning is logged.
+func (c *checker) CheckRoutes(ctx context.Context) {
+	err := checkRoutesReachable(ctx, c.bootstrap.GetConfig(), c.databrokerClient)
+	if err != nil {
+		log.Warn(ctx).Err(err).Msg("routes reachability check failed")
+	}
+}
+
+const (
+	// connectionTimeout limits both the dial and the complete HTTP request of one probe.
+	connectionTimeout = time.Second * 30
+)
+
+// getPingHTTPClient returns a new HTTP client used to probe the route hosts.
+// Keep-alives are disabled: the client is created for one check run and then
+// dropped, and a custom http.Transport has no idle connection timeout, so pooled
+// connections would otherwise stay open until the peer closes them.
+func getPingHTTPClient() *http.Client {
+	return &http.Client{
+		Timeout: connectionTimeout,
+		Transport: &http.Transport{
+			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+				return (&net.Dialer{
+					Timeout: connectionTimeout,
+				}).DialContext(ctx, network, addr)
+			},
+			DisableKeepAlives: true,
+		},
+	}
+}
+
+// checkRoutesReachable probes the JWKS endpoint of every route host and reports the
+// combined result through the health.RoutesReachable check.
+// Unreachable routes are reported via that health check only. The returned error is
+// non-nil when the probe could not be completed: the cluster public key or the route
+// hosts could not be obtained, or ctx was done before all hosts were probed.
+func checkRoutesReachable(
+	ctx context.Context,
+	cfg *config.Config,
+	databrokerClient databroker.DataBrokerServiceClient,
+) error {
+	key, err := getClusterPublicKey(cfg)
+	if err != nil {
+		return fmt.Errorf("error getting cluster public key: %w", err)
+	}
+
+	hosts, err := getRouteHosts(ctx, databrokerClient)
+	if err != nil {
+		return fmt.Errorf("error getting route hosts: %w", err)
+	}
+	// the hosts come from map keys; sort them for a stable probe order and error message
+	slices.Sort(hosts)
+
+	client := getPingHTTPClient()
+	var errs []error
+	for _, host := range hosts {
+		err = clusterping.CheckKey(ctx, clusterping.GetJWKSURL(host), *key, client)
+		// once ctx is done, this and all remaining probes fail regardless of the route
+		// state: stop without reporting instead of flagging the routes as unreachable
+		if ctxErr := ctx.Err(); ctxErr != nil {
+			return fmt.Errorf("routes reachability check interrupted: %w", ctxErr)
+		}
+		if err != nil {
+			errs = append(errs, fmt.Errorf("%s: %w", host, err))
+		}
+	}
+
+	if len(errs) == 0 {
+		health.ReportOK(health.RoutesReachable)
+	} else {
+		health.ReportError(health.RoutesReachable, errors.Join(errs...))
+	}
+
+	return nil
+}
+
+// getClusterPublicKey decodes the base64-encoded signing key of cfg and converts
+// it to a public JWK using cryptutil.PublicJWKFromBytes.
+func getClusterPublicKey(cfg *config.Config) (*jose.JSONWebKey, error) {
+	data, err := base64.StdEncoding.DecodeString(cfg.Options.SigningKey)
+	if err != nil {
+		return nil, fmt.Errorf("error decoding signing key: %w", err)
+	}
+
+	key, err := cryptutil.PublicJWKFromBytes(data)
+	if err != nil {
+		return nil, fmt.Errorf("error creating public jwk from bytes: %w", err)
+	}
+
+	return key, nil
+}
+
+// getRouteHosts returns the distinct hosts of the "from" URLs of all routes in the
+// Config records obtained from the databroker, in no particular order.
+func getRouteHosts(ctx context.Context, databrokerClient databroker.DataBrokerServiceClient) ([]string, error) {
+	records, _, _, err := databroker.InitialSync(ctx, databrokerClient, &databroker.SyncLatestRequest{
+		Type: protoutil.GetTypeURL(new(configpb.Config)),
+	})
+	if err != nil {
+		return nil, fmt.Errorf("error during initial sync: %w", err)
+	}
+
+	hosts := make(map[string]struct{})
+	for _, record := range records {
+		var cfg configpb.Config
+		if err := record.Data.UnmarshalTo(&cfg); err != nil {
+			return nil, fmt.Errorf("error unmarshalling config: %w", err)
+		}
+
+		for _, route := range cfg.GetRoutes() {
+			// NOTE(review): routes that set a custom CA are not probed; the reason is not visible here - confirm
+			if route.GetTlsCustomCa() != "" {
+				continue
+			}
+			u, err := urlutil.ParseAndValidateURL(route.GetFrom())
+			if err != nil {
+				// a route whose "from" URL is invalid cannot be probed; skip it silently
+				continue
+			}
+			hosts[u.Host] = struct{}{}
+		}
+	}
+
+	return maps.Keys(hosts), nil
+}
diff --git a/internal/zero/healthcheck/checks.go b/internal/zero/healthcheck/checks.go
new file mode 100644
index 000000000..353329be9
--- /dev/null
+++ b/internal/zero/healthcheck/checks.go
@@ -0,0 +1,50 @@
+package healthcheck
+
+import (
+	"context"
+	"time"
+
+	"github.com/pomerium/pomerium/config"
+	"github.com/pomerium/pomerium/pkg/grpc/databroker"
+)
+
+// RunChecks runs the health checks periodically until ctx is done,
+// at which point it returns nil.
+func RunChecks(
+	ctx context.Context,
+	bootstrap config.Source,
+	databrokerClient databroker.DataBrokerServiceClient,
+) error {
+	c := &checker{
+		bootstrap:        bootstrap,
+		databrokerClient: databrokerClient,
+	}
+	return c.run(ctx)
+}
+
+// checker holds the dependencies used by the health checks.
+type checker struct {
+	bootstrap        config.Source
+	databrokerClient databroker.DataBrokerServiceClient
+}
+
+// runHealthChecksInterval is the delay before the first run and between consecutive runs.
+const runHealthChecksInterval = time.Minute * 30
+
+// run executes the checks each time the timer fires and returns nil once ctx is done.
+// The first run happens one full interval after the start.
+func (c *checker) run(ctx context.Context) error {
+	tm := time.NewTimer(runHealthChecksInterval)
+	defer tm.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			return nil
+		case <-tm.C:
+			c.CheckRoutes(ctx)
+			// re-arm after the check has finished, so the interval does not include its duration
+			tm.Reset(runHealthChecksInterval)
+		}
+	}
+}
diff --git a/pkg/health/check.go b/pkg/health/check.go
index 0eaf96fb2..229a9e40f 100644
--- a/pkg/health/check.go
+++ b/pkg/health/check.go
@@ -19,6 +19,8 @@ const (
 	ZeroBootstrapConfigSave = Check("zero.bootstrap-config.save")
 	// ZeroConnect checks whether the Zero Connect service is connected
 	ZeroConnect = Check("zero.connect")
+	// RoutesReachable checks whether all referenced routes can be resolved to this instance
+	RoutesReachable = Check("routes.reachable")
 )
 
 // ZeroResourceBundle checks whether the Zero resource bundle was applied
diff --git a/pkg/zero/ping/ping.go b/pkg/zero/ping/ping.go
new file mode 100644
index 000000000..23b9284da
--- /dev/null
+++ b/pkg/zero/ping/ping.go
@@ -0,0 +1,172 @@
+// Package clusterping verifies that a JWKS endpoint serves an expected key.
+package clusterping
+
+import (
+	"context"
+	"crypto/tls"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"mime"
+	"net"
+	"net/http"
+	"net/url"
+
+	"github.com/go-jose/go-jose/v3"
+
+	"github.com/pomerium/pomerium/internal/version"
+)
+
+const (
+	// JWKSPath is the URL path of the Pomerium JWKS endpoint.
+	JWKSPath = "/.well-known/pomerium/jwks.json"
+)
+
+// CheckErrorCode classifies why a check failed.
+type CheckErrorCode int
+
+// Codes carried by a CheckError.
+const (
+	// ErrInvalidCert means the TLS certificate could not be verified.
+	ErrInvalidCert CheckErrorCode = iota
+	// ErrDNSError means the request failed with a DNS error.
+	ErrDNSError
+	// ErrConnectionError means the request failed with another network error.
+	ErrConnectionError
+	// ErrKeyNotFound means the JWKS does not contain the expected key.
+	ErrKeyNotFound
+	// ErrUnexpectedResponse means the status code, content type or body was not as expected.
+	ErrUnexpectedResponse
+)
+
+// CheckError is an error annotated with a CheckErrorCode.
+type CheckError struct {
+	Code CheckErrorCode
+	Err  error
+}
+
+// NewCheckError returns a CheckError with the given code that wraps err.
+func NewCheckError(code CheckErrorCode, err error) *CheckError {
+	return &CheckError{
+		Code: code,
+		Err:  err,
+	}
+}
+
+var errorCodeToString = map[CheckErrorCode]string{
+	ErrInvalidCert:        "invalid certificate",
+	ErrDNSError:           "DNS error",
+	ErrConnectionError:    "connection error",
+	ErrKeyNotFound:        "key not found",
+	ErrUnexpectedResponse: "unexpected response",
+}
+
+// Error implements the error interface.
+func (e *CheckError) Error() string {
+	desc, ok := errorCodeToString[e.Code]
+	if !ok {
+		// a code missing from the table would otherwise yield a message starting with ": "
+		desc = fmt.Sprintf("unknown error code %d", int(e.Code))
+	}
+	return fmt.Sprintf("%s: %v", desc, e.Err)
+}
+
+// Unwrap returns the wrapped error.
+func (e *CheckError) Unwrap() error {
+	return e.Err
+}
+
+// GetJWKSURL returns the https URL of the JWKS endpoint on host.
+func GetJWKSURL(host string) string {
+	return (&url.URL{
+		Scheme: "https",
+		Host:   host,
+		Path:   JWKSPath,
+	}).String()
+}
+
+// CheckKey fetches the JWKS from jwksURL using client and checks that it contains
+// a key with the same key ID as key. Failures of the request and unexpected
+// responses are returned as a *CheckError where they can be classified.
+func CheckKey(
+	ctx context.Context,
+	jwksURL string,
+	key jose.JSONWebKey,
+	client *http.Client,
+) error {
+	keys, err := fetchKeys(ctx, client, jwksURL)
+	if err != nil {
+		return err
+	}
+
+	if !containsKey(keys, key) {
+		return NewCheckError(ErrKeyNotFound, fmt.Errorf("key %s not found in JWKS", key.KeyID))
+	}
+
+	return nil
+}
+
+// containsKey reports whether keys has an entry with the same key ID as key.
+func containsKey(keys []jose.JSONWebKey, key jose.JSONWebKey) bool {
+	for _, k := range keys {
+		if k.KeyID == key.KeyID {
+			return true
+		}
+	}
+	return false
+}
+
+// fetchKeys requests jwksURL with client and decodes the JSON response as a JWKS.
+func fetchKeys(ctx context.Context, client *http.Client, jwksURL string) ([]jose.JSONWebKey, error) {
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, nil)
+	if err != nil {
+		return nil, fmt.Errorf("error creating request: %w", err)
+	}
+	req.Header.Set("Accept", "application/json")
+	req.Header.Set("User-Agent", version.UserAgent())
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, convertRequestError(err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return nil, NewCheckError(ErrUnexpectedResponse, fmt.Errorf("unexpected status code %d", resp.StatusCode))
+	}
+
+	// compare the media type only, so that parameters such as "; charset=utf-8" are accepted
+	contentType := resp.Header.Get("Content-Type")
+	if mediaType, _, err := mime.ParseMediaType(contentType); err != nil || mediaType != "application/json" {
+		return nil, NewCheckError(ErrUnexpectedResponse, fmt.Errorf("unexpected content type %s", contentType))
+	}
+
+	var jwks struct {
+		Keys []jose.JSONWebKey `json:"keys"`
+	}
+	if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
+		return nil, NewCheckError(ErrUnexpectedResponse, fmt.Errorf("error decoding response: %w", err))
+	}
+
+	return jwks.Keys, nil
+}
+
+// convertRequestError classifies an error returned by http.Client.Do: certificate
+// verification, DNS and other network errors become a *CheckError with the matching
+// code, anything else is wrapped with a generic message.
+func convertRequestError(err error) error {
+	var (
+		tlsErr *tls.CertificateVerificationError
+		dnsErr *net.DNSError
+		netErr net.Error
+	)
+	switch {
+	case errors.As(err, &tlsErr):
+		return NewCheckError(ErrInvalidCert, err)
+	case errors.As(err, &dnsErr):
+		return NewCheckError(ErrDNSError, err)
+	case errors.As(err, &netErr):
+		return NewCheckError(ErrConnectionError, err)
+	default:
+		return fmt.Errorf("error making request: %w", err)
+	}
+}