pomerium/internal/zero/healthcheck/check_routes.go
2024-05-01 15:15:38 -04:00

131 lines
3.5 KiB
Go

package healthcheck
import (
"context"
"encoding/base64"
"errors"
"fmt"
"net"
"net/http"
"slices"
"time"
"github.com/go-jose/go-jose/v3"
"golang.org/x/exp/maps"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/health"
"github.com/pomerium/pomerium/pkg/protoutil"
clusterping "github.com/pomerium/pomerium/pkg/zero/ping"
)
// CheckRoutes verifies that the routes referenced by this pomerium instance's
// configuration are reachable.
//
// The intent is that each route host is resolved via DNS and its pomerium JWKS
// endpoint is fetched: the request should land on this very instance, so the
// public key served must match the one in our configuration. A mismatch (or
// an unreachable host) points at a DNS-level misconfiguration.
//
// If the check itself cannot be carried out, the error is logged as a warning
// and not propagated.
func (c *checker) CheckRoutes(ctx context.Context) {
	if err := checkRoutesReachable(ctx, c.bootstrap.GetConfig(), c.databrokerClient); err != nil {
		log.Warn(ctx).Err(err).Msg("routes reachability check failed")
	}
}
// connectionTimeout bounds both the whole HTTP request and the connection dial
// performed by the ping HTTP client.
const connectionTimeout = 30 * time.Second
// getPingHTTPClient builds a new HTTP client used to probe route hosts.
// The overall request and the connection dial are each limited to
// connectionTimeout. A dedicated transport is created on every call.
func getPingHTTPClient() *http.Client {
	dialer := &net.Dialer{Timeout: connectionTimeout}
	transport := &http.Transport{DialContext: dialer.DialContext}

	return &http.Client{
		Timeout:   connectionTimeout,
		Transport: transport,
	}
}
// checkRoutesReachable probes the JWKS URL of every route host returned by
// getRouteHosts and checks it against the cluster public key derived from cfg.
// The outcome is published through the health.RoutesReachable check:
// health.ReportOK when every host passes, health.ReportError with the joined
// per-host errors otherwise.
//
// The returned error is non-nil only when the check could not be performed at
// all (the public key could not be derived or the route hosts could not be
// fetched). Per-host failures are reported via the health check and are NOT
// returned.
func checkRoutesReachable(
	ctx context.Context,
	cfg *config.Config,
	databrokerClient databroker.DataBrokerServiceClient,
) error {
	key, err := getClusterPublicKey(cfg)
	if err != nil {
		return fmt.Errorf("error getting cluster public key: %w", err)
	}

	hosts, err := getRouteHosts(ctx, databrokerClient)
	if err != nil {
		return fmt.Errorf("error getting route hosts: %w", err)
	}
	// Hosts come from map keys, so sort them to probe (and report errors) in a
	// stable order.
	slices.Sort(hosts)

	client := getPingHTTPClient()
	// The client owns a transport created just for this invocation. Without
	// an explicit close its keep-alive connections would stay open after the
	// client is discarded, accumulating across repeated health checks.
	defer client.CloseIdleConnections()

	var errs []error
	for _, host := range hosts {
		if err := clusterping.CheckKey(ctx, clusterping.GetJWKSURL(host), *key, client); err != nil {
			errs = append(errs, fmt.Errorf("%s: %w", host, err))
		}
	}

	if len(errs) > 0 {
		health.ReportError(health.RoutesReachable, errors.Join(errs...))
		return nil
	}
	health.ReportOK(health.RoutesReachable)
	return nil
}
// getClusterPublicKey base64-decodes the signing key found in cfg.Options and
// converts the resulting bytes into a JSON Web Key using
// cryptutil.PublicJWKFromBytes. Either step failing yields a wrapped error and
// a nil key.
func getClusterPublicKey(cfg *config.Config) (*jose.JSONWebKey, error) {
	raw, err := base64.StdEncoding.DecodeString(cfg.Options.SigningKey)
	if err != nil {
		return nil, fmt.Errorf("error decoding signing key: %w", err)
	}

	jwk, err := cryptutil.PublicJWKFromBytes(raw)
	if err != nil {
		return nil, fmt.Errorf("error creating public jwk from bytes: %w", err)
	}

	return jwk, nil
}
// getRouteHosts syncs all Config records from the databroker and returns the
// de-duplicated set of hosts taken from each route's "from" URL. The order of
// the returned slice is unspecified.
//
// Routes that set a custom TLS CA are skipped, as are routes whose "from" URL
// fails to parse or validate. A record that cannot be unmarshalled aborts the
// whole operation with an error.
func getRouteHosts(ctx context.Context, databrokerClient databroker.DataBrokerServiceClient) ([]string, error) {
	req := &databroker.SyncLatestRequest{
		Type: protoutil.GetTypeURL(new(configpb.Config)),
	}
	records, _, _, err := databroker.InitialSync(ctx, databrokerClient, req)
	if err != nil {
		return nil, fmt.Errorf("error during initial sync: %w", err)
	}

	// Used as a set so that a host shared by several routes is returned once.
	seen := make(map[string]struct{})
	for _, rec := range records {
		var pbCfg configpb.Config
		if err := rec.Data.UnmarshalTo(&pbCfg); err != nil {
			return nil, fmt.Errorf("error unmarshalling config: %w", err)
		}

		for _, r := range pbCfg.GetRoutes() {
			// NOTE(review): presumably skipped because the ping client does
			// not trust per-route custom CAs — confirm.
			if r.GetTlsCustomCa() != "" {
				continue
			}
			// Best effort: routes with an invalid "from" URL are ignored.
			if from, err := urlutil.ParseAndValidateURL(r.GetFrom()); err == nil {
				seen[from.Host] = struct{}{}
			}
		}
	}

	return maps.Keys(seen), nil
}