// Mirror of https://github.com/pomerium/pomerium.git
// Synced 2025-04-30 02:46:30 +02:00
// 138 lines, 3.1 KiB, Go

package clusterping
|
|
|
|
import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"mime"
	"net"
	"net/http"
	"net/url"

	"github.com/go-jose/go-jose/v3"

	"github.com/pomerium/pomerium/internal/version"
)
|
|
|
|
// JWKSPath is the URL path at which a Pomerium instance publishes its JSON
// Web Key Set.
const JWKSPath = "/.well-known/pomerium/jwks.json"
|
|
|
|
type CheckErrorCode int
|
|
|
|
const (
|
|
ErrInvalidCert CheckErrorCode = iota
|
|
ErrDNSError
|
|
ErrConnectionError
|
|
ErrKeyNotFound
|
|
ErrUnexpectedResponse
|
|
)
|
|
|
|
type CheckError struct {
|
|
Code CheckErrorCode
|
|
Err error
|
|
}
|
|
|
|
func NewCheckError(code CheckErrorCode, err error) *CheckError {
|
|
return &CheckError{
|
|
Code: code,
|
|
Err: err,
|
|
}
|
|
}
|
|
|
|
var errorCodeToString = map[CheckErrorCode]string{
|
|
ErrInvalidCert: "invalid certificate",
|
|
ErrDNSError: "DNS error",
|
|
ErrConnectionError: "connection error",
|
|
ErrKeyNotFound: "key not found",
|
|
ErrUnexpectedResponse: "unexpected response",
|
|
}
|
|
|
|
func (e *CheckError) Error() string {
|
|
return fmt.Sprintf("%s: %v", errorCodeToString[e.Code], e.Err)
|
|
}
|
|
|
|
func (e *CheckError) Unwrap() error {
|
|
return e.Err
|
|
}
|
|
|
|
// GetJWKSURL builds the https URL of the JWKS document served by host at
// JWKSPath. host is used verbatim as the URL authority, so it may carry a port.
func GetJWKSURL(host string) string {
	u := url.URL{Scheme: "https", Host: host, Path: JWKSPath}
	return u.String()
}
|
|
|
|
// CheckKey downloads the key set published at jwksURL using client and
// verifies that it contains a key whose KeyID matches key.KeyID.
//
// It returns nil when the key is present. When it is absent, it returns a
// *CheckError with code ErrKeyNotFound. Any failure while fetching or decoding
// the key set is returned unchanged.
func CheckKey(
	ctx context.Context,
	jwksURL string,
	key jose.JSONWebKey,
	client *http.Client,
) error {
	published, fetchErr := fetchKeys(ctx, client, jwksURL)
	switch {
	case fetchErr != nil:
		return fetchErr
	case containsKey(published, key):
		return nil
	default:
		return NewCheckError(ErrKeyNotFound, fmt.Errorf("key %s not found in JWKS", key.KeyID))
	}
}
|
|
|
|
func containsKey(keys []jose.JSONWebKey, key jose.JSONWebKey) bool {
|
|
for _, k := range keys {
|
|
if k.KeyID == key.KeyID {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// fetchKeys retrieves the JSON Web Key Set published at jwksURL.
//
// The GET request is bound to ctx and sent with client.
//   - Transport failures are classified by convertRequestError.
//   - A non-200 status yields a *CheckError with code ErrUnexpectedResponse.
//   - A Content-Type whose media type is not application/json yields the same.
//   - A body that cannot be decoded as {"keys": [...]} yields the same.
func fetchKeys(ctx context.Context, client *http.Client, jwksURL string) ([]jose.JSONWebKey, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, jwksURL, nil)
	if err != nil {
		return nil, fmt.Errorf("error creating request: %w", err)
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("User-Agent", version.UserAgent())

	resp, err := client.Do(req)
	if err != nil {
		return nil, convertRequestError(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, NewCheckError(ErrUnexpectedResponse, fmt.Errorf("unexpected status code %d", resp.StatusCode))
	}

	// Compare only the media type. Servers commonly append parameters such as
	// "; charset=utf-8", which an exact string comparison would reject.
	// ParseMediaType also lower-cases the type.
	contentType := resp.Header.Get("Content-Type")
	if mediaType, _, perr := mime.ParseMediaType(contentType); perr != nil || mediaType != "application/json" {
		return nil, NewCheckError(ErrUnexpectedResponse, fmt.Errorf("unexpected content type %s", contentType))
	}

	var jwks struct {
		Keys []jose.JSONWebKey `json:"keys"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
		return nil, NewCheckError(ErrUnexpectedResponse, fmt.Errorf("error decoding response: %w", err))
	}

	return jwks.Keys, nil
}
|
|
|
|
func convertRequestError(err error) error {
|
|
if tlsErr := new(tls.CertificateVerificationError); errors.As(err, &tlsErr) {
|
|
return NewCheckError(ErrInvalidCert, err)
|
|
}
|
|
if dnsErr := new(net.DNSError); errors.As(err, &dnsErr) {
|
|
return NewCheckError(ErrDNSError, err)
|
|
}
|
|
if netErr := new(net.Error); errors.As(err, netErr) {
|
|
return NewCheckError(ErrConnectionError, err)
|
|
}
|
|
|
|
return fmt.Errorf("error making request: %w", err)
|
|
}
|