mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-04 01:09:36 +02:00
health-checks: add route reachability (#5093)
* health-checks: add route reachability * rm tls check bypass
This commit is contained in:
parent
a95423b310
commit
614048ae9c
5 changed files with 324 additions and 0 deletions
138
pkg/zero/ping/ping.go
Normal file
138
pkg/zero/ping/ping.go
Normal file
|
@ -0,0 +1,138 @@
|
|||
package clusterping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
)
|
||||
|
||||
const (
|
||||
JWKSPath = "/.well-known/pomerium/jwks.json"
|
||||
)
|
||||
|
||||
type CheckErrorCode int
|
||||
|
||||
const (
|
||||
ErrInvalidCert CheckErrorCode = iota
|
||||
ErrDNSError
|
||||
ErrConnectionError
|
||||
ErrKeyNotFound
|
||||
ErrUnexpectedResponse
|
||||
)
|
||||
|
||||
type CheckError struct {
|
||||
Code CheckErrorCode
|
||||
Err error
|
||||
}
|
||||
|
||||
func NewCheckError(code CheckErrorCode, err error) *CheckError {
|
||||
return &CheckError{
|
||||
Code: code,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
var errorCodeToString = map[CheckErrorCode]string{
|
||||
ErrInvalidCert: "invalid certificate",
|
||||
ErrDNSError: "DNS error",
|
||||
ErrConnectionError: "connection error",
|
||||
ErrKeyNotFound: "key not found",
|
||||
ErrUnexpectedResponse: "unexpected response",
|
||||
}
|
||||
|
||||
func (e *CheckError) Error() string {
|
||||
return fmt.Sprintf("%s: %v", errorCodeToString[e.Code], e.Err)
|
||||
}
|
||||
|
||||
func (e *CheckError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
func GetJWKSURL(host string) string {
|
||||
return (&url.URL{
|
||||
Scheme: "https",
|
||||
Host: host,
|
||||
Path: JWKSPath,
|
||||
}).String()
|
||||
}
|
||||
|
||||
func CheckKey(
|
||||
ctx context.Context,
|
||||
jwksURL string,
|
||||
key jose.JSONWebKey,
|
||||
client *http.Client,
|
||||
) error {
|
||||
keys, err := fetchKeys(ctx, client, jwksURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !containsKey(keys, key) {
|
||||
return NewCheckError(ErrKeyNotFound, fmt.Errorf("key %s not found in JWKS", key.KeyID))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func containsKey(keys []jose.JSONWebKey, key jose.JSONWebKey) bool {
|
||||
for _, k := range keys {
|
||||
if k.KeyID == key.KeyID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func fetchKeys(ctx context.Context, client *http.Client, jwksURL string) ([]jose.JSONWebKey, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", version.UserAgent())
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, convertRequestError(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, NewCheckError(ErrUnexpectedResponse, fmt.Errorf("unexpected status code %d", resp.StatusCode))
|
||||
}
|
||||
|
||||
if resp.Header.Get("Content-Type") != "application/json" {
|
||||
return nil, NewCheckError(ErrUnexpectedResponse, fmt.Errorf("unexpected content type %s", resp.Header.Get("Content-Type")))
|
||||
}
|
||||
|
||||
var jwks struct {
|
||||
Keys []jose.JSONWebKey `json:"keys"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
|
||||
return nil, NewCheckError(ErrUnexpectedResponse, fmt.Errorf("error decoding response: %w", err))
|
||||
}
|
||||
|
||||
return jwks.Keys, nil
|
||||
}
|
||||
|
||||
func convertRequestError(err error) error {
|
||||
if tlsErr := new(tls.CertificateVerificationError); errors.As(err, &tlsErr) {
|
||||
return NewCheckError(ErrInvalidCert, err)
|
||||
}
|
||||
if dnsErr := new(net.DNSError); errors.As(err, &dnsErr) {
|
||||
return NewCheckError(ErrDNSError, err)
|
||||
}
|
||||
if netErr := new(net.Error); errors.As(err, netErr) {
|
||||
return NewCheckError(ErrConnectionError, err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("error making request: %w", err)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue