mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-23 05:57:19 +02:00
implement session creation
This commit is contained in:
parent
24b35e26a5
commit
b95ad4dbc3
15 changed files with 646 additions and 148 deletions
|
@ -6,25 +6,11 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/authenticateapi"
|
||||
)
|
||||
|
||||
type VerifyAccessTokenRequest struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
IdentityProviderID string `json:"identityProviderId,omitempty"`
|
||||
}
|
||||
|
||||
type VerifyIdentityTokenRequest struct {
|
||||
IdentityToken string `json:"identityToken"`
|
||||
IdentityProviderID string `json:"identityProviderId,omitempty"`
|
||||
}
|
||||
|
||||
type VerifyTokenResponse struct {
|
||||
Valid bool `json:"valid"`
|
||||
Claims map[string]any `json:"claims,omitempty"`
|
||||
}
|
||||
|
||||
func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request) error {
|
||||
var req VerifyAccessTokenRequest
|
||||
var req authenticateapi.VerifyAccessTokenRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
|
@ -35,7 +21,7 @@ func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request)
|
|||
return err
|
||||
}
|
||||
|
||||
var res VerifyTokenResponse
|
||||
var res authenticateapi.VerifyTokenResponse
|
||||
claims, err := authenticator.VerifyAccessToken(r.Context(), req.AccessToken)
|
||||
if err == nil {
|
||||
res.Valid = true
|
||||
|
@ -57,7 +43,7 @@ func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
|
||||
func (a *Authenticate) verifyIdentityToken(w http.ResponseWriter, r *http.Request) error {
|
||||
var req VerifyIdentityTokenRequest
|
||||
var req authenticateapi.VerifyIdentityTokenRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
|
@ -68,7 +54,7 @@ func (a *Authenticate) verifyIdentityToken(w http.ResponseWriter, r *http.Reques
|
|||
return err
|
||||
}
|
||||
|
||||
var res VerifyTokenResponse
|
||||
var res authenticateapi.VerifyTokenResponse
|
||||
claims, err := authenticator.VerifyIdentityToken(r.Context(), req.IdentityToken)
|
||||
if err == nil {
|
||||
res.Valid = true
|
||||
|
|
|
@ -29,7 +29,7 @@ import (
|
|||
type Authorize struct {
|
||||
state *atomicutil.Value[*authorizeState]
|
||||
store *store.Store
|
||||
currentOptions *atomicutil.Value[*config.Options]
|
||||
currentConfig *atomicutil.Value[*config.Config]
|
||||
accessTracker *AccessTracker
|
||||
globalCache storage.Cache
|
||||
groupsCacheWarmer *cacheWarmer
|
||||
|
@ -43,7 +43,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) {
|
|||
tracerProvider := trace.NewTracerProvider(ctx, "Authorize")
|
||||
tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer)
|
||||
a := &Authorize{
|
||||
currentOptions: config.NewAtomicOptions(),
|
||||
currentConfig: atomicutil.NewValue(&config.Config{Options: new(config.Options)}),
|
||||
store: store.New(),
|
||||
globalCache: storage.NewGlobalCache(time.Minute),
|
||||
tracerProvider: tracerProvider,
|
||||
|
@ -155,7 +155,7 @@ func newPolicyEvaluator(
|
|||
// OnConfigChange updates internal structures based on config.Options
|
||||
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||
currentState := a.state.Load()
|
||||
a.currentOptions.Store(cfg.Options)
|
||||
a.currentConfig.Store(cfg)
|
||||
if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state")
|
||||
} else {
|
||||
|
|
|
@ -186,7 +186,7 @@ func (a *Authorize) deniedResponse(
|
|||
Err: errors.New(reason),
|
||||
DebugURL: debugEndpoint,
|
||||
RequestID: requestid.FromContext(ctx),
|
||||
BrandingOptions: a.currentOptions.Load().BrandingOptions,
|
||||
BrandingOptions: a.currentConfig.Load().Options.BrandingOptions,
|
||||
}
|
||||
httpErr.ErrorResponse(ctx, w, r)
|
||||
|
||||
|
@ -213,7 +213,7 @@ func (a *Authorize) requireLoginResponse(
|
|||
in *envoy_service_auth_v3.CheckRequest,
|
||||
request *evaluator.Request,
|
||||
) (*envoy_service_auth_v3.CheckResponse, error) {
|
||||
options := a.currentOptions.Load()
|
||||
options := a.currentConfig.Load().Options
|
||||
state := a.state.Load()
|
||||
|
||||
if !a.shouldRedirect(in) {
|
||||
|
@ -251,7 +251,7 @@ func (a *Authorize) requireWebAuthnResponse(
|
|||
request *evaluator.Request,
|
||||
result *evaluator.Result,
|
||||
) (*envoy_service_auth_v3.CheckResponse, error) {
|
||||
opts := a.currentOptions.Load()
|
||||
opts := a.currentConfig.Load().Options
|
||||
state := a.state.Load()
|
||||
|
||||
// always assume https scheme
|
||||
|
@ -327,7 +327,7 @@ func toEnvoyHeaders(headers http.Header) []*envoy_config_core_v3.HeaderValueOpti
|
|||
// userInfoEndpointURL returns the user info endpoint url which can be used to debug the user's
|
||||
// session that lives on the authenticate service.
|
||||
func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) (*url.URL, error) {
|
||||
opts := a.currentOptions.Load()
|
||||
opts := a.currentConfig.Load().Options
|
||||
authenticateURL, err := opts.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -127,8 +127,9 @@ func TestAuthorize_okResponse(t *testing.T) {
|
|||
}},
|
||||
JWTClaimsHeaders: config.NewJWTClaimHeaders("email"),
|
||||
}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentOptions.Store(opt)
|
||||
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
|
||||
Options: opt,
|
||||
}), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.store = store.New()
|
||||
pe, err := newPolicyEvaluator(context.Background(), opt, a.store, nil)
|
||||
require.NoError(t, err)
|
||||
|
@ -183,15 +184,16 @@ func TestAuthorize_okResponse(t *testing.T) {
|
|||
func TestAuthorize_deniedResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentOptions.Store(&config.Options{
|
||||
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
|
||||
Options: &config.Options{
|
||||
Policies: []config.Policy{{
|
||||
From: "https://example.com",
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
},
|
||||
}), state: atomicutil.NewValue(new(authorizeState))}
|
||||
|
||||
t.Run("json", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
|
@ -3,15 +3,17 @@ package authorize
|
|||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type sessionOrServiceAccount interface {
|
||||
GetId() string
|
||||
GetUserId() string
|
||||
Validate() error
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package authorize
|
|||
import (
|
||||
"context"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -21,6 +22,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/contextutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/pkg/telemetry/requestid"
|
||||
|
@ -44,31 +46,25 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
requestID := requestid.FromHTTPHeader(hreq.Header)
|
||||
ctx = requestid.WithValue(ctx, requestID)
|
||||
|
||||
sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq)
|
||||
|
||||
var s sessionOrServiceAccount
|
||||
var u *user.User
|
||||
var err error
|
||||
if sessionState != nil {
|
||||
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
|
||||
if status.Code(err) == codes.Unavailable {
|
||||
log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable")
|
||||
return nil, err
|
||||
} else if err != nil {
|
||||
log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("clearing session due to missing or invalid session or service account")
|
||||
sessionState = nil
|
||||
}
|
||||
}
|
||||
if sessionState != nil && s != nil {
|
||||
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
|
||||
}
|
||||
|
||||
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in, sessionState)
|
||||
req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error building evaluator request")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// load the session
|
||||
s, err := a.loadSession(ctx, hreq, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if there's a session or service account, load the user
|
||||
var u *user.User
|
||||
if s != nil {
|
||||
req.Session.ID = s.GetId()
|
||||
u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error
|
||||
}
|
||||
|
||||
res, err := state.evaluator.Evaluate(ctx, req)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation")
|
||||
|
@ -88,10 +84,65 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
return resp, err
|
||||
}
|
||||
|
||||
func (a *Authorize) loadSession(
|
||||
ctx context.Context,
|
||||
hreq *http.Request,
|
||||
req *evaluator.Request,
|
||||
) (s sessionOrServiceAccount, err error) {
|
||||
requestID := requestid.FromHTTPHeader(hreq.Header)
|
||||
|
||||
// attempt to create a session from an incoming idp token
|
||||
s, err = config.NewIncomingIDPTokenSessionCreator(
|
||||
func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) {
|
||||
return getDataBrokerRecord(ctx, recordType, recordID, 0)
|
||||
},
|
||||
func(ctx context.Context, records []*databroker.Record) error {
|
||||
_, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{
|
||||
Records: records,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// invalidate cache
|
||||
for _, record := range records {
|
||||
storage.GetQuerier(ctx).InvalidateCache(ctx, &databroker.QueryRequest{
|
||||
Type: record.GetType(),
|
||||
Query: record.GetId(),
|
||||
Limit: 1,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
},
|
||||
).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq)
|
||||
if err == nil {
|
||||
return s, nil
|
||||
} else if !errors.Is(err, sessions.ErrNoSessionFound) {
|
||||
log.Ctx(ctx).Info().
|
||||
Str("request-id", requestID).
|
||||
Err(err).
|
||||
Msg("error creating session for incoming idp token")
|
||||
}
|
||||
|
||||
sessionState, _ := a.state.Load().sessionStore.LoadSessionStateAndCheckIDP(hreq)
|
||||
if sessionState == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion)
|
||||
if status.Code(err) == codes.Unavailable {
|
||||
log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable")
|
||||
return nil, err
|
||||
} else if err != nil {
|
||||
log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("clearing session due to missing or invalid session or service account")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (a *Authorize) getEvaluatorRequestFromCheckRequest(
|
||||
ctx context.Context,
|
||||
in *envoy_service_auth_v3.CheckRequest,
|
||||
sessionState *sessions.State,
|
||||
) (*evaluator.Request, error) {
|
||||
requestURL := getCheckRequestURL(in)
|
||||
attrs := in.GetAttributes()
|
||||
|
@ -106,17 +157,12 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
|
|||
attrs.GetSource().GetAddress().GetSocketAddress().GetAddress(),
|
||||
),
|
||||
}
|
||||
if sessionState != nil {
|
||||
req.Session = evaluator.RequestSession{
|
||||
ID: sessionState.ID,
|
||||
}
|
||||
}
|
||||
req.Policy = a.getMatchingPolicy(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions()))
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy {
|
||||
options := a.currentOptions.Load()
|
||||
options := a.currentConfig.Load().Options
|
||||
|
||||
for p := range options.GetAllPolicies() {
|
||||
id, _ := p.RouteID()
|
||||
|
|
|
@ -18,7 +18,6 @@ import (
|
|||
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
|
@ -49,15 +48,16 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
|
|||
-----END CERTIFICATE-----`
|
||||
|
||||
func Test_getEvaluatorRequest(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentOptions.Store(&config.Options{
|
||||
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
|
||||
Options: &config.Options{
|
||||
Policies: []config.Policy{{
|
||||
From: "https://example.com",
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
},
|
||||
}), state: atomicutil.NewValue(new(authorizeState))}
|
||||
|
||||
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
|
||||
&envoy_service_auth_v3.CheckRequest{
|
||||
|
@ -88,13 +88,10 @@ func Test_getEvaluatorRequest(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
&sessions.State{
|
||||
ID: "SESSION_ID",
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
expect := &evaluator.Request{
|
||||
Policy: &a.currentOptions.Load().Policies[0],
|
||||
Policy: &a.currentConfig.Load().Options.Policies[0],
|
||||
Session: evaluator.RequestSession{
|
||||
ID: "SESSION_ID",
|
||||
},
|
||||
|
@ -117,15 +114,16 @@ func Test_getEvaluatorRequest(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentOptions.Store(&config.Options{
|
||||
a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{
|
||||
Options: &config.Options{
|
||||
Policies: []config.Policy{{
|
||||
From: "https://example.com",
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
},
|
||||
}), state: atomicutil.NewValue(new(authorizeState))}
|
||||
|
||||
actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(),
|
||||
&envoy_service_auth_v3.CheckRequest{
|
||||
|
@ -145,10 +143,10 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
expect := &evaluator.Request{
|
||||
Policy: &a.currentOptions.Load().Policies[0],
|
||||
Policy: &a.currentConfig.Load().Options.Policies[0],
|
||||
Session: evaluator.RequestSession{},
|
||||
HTTP: evaluator.NewRequestHTTP(
|
||||
http.MethodGet,
|
||||
|
|
|
@ -31,7 +31,7 @@ func (a *Authorize) logAuthorizeCheck(
|
|||
impersonateDetails := a.getImpersonateDetails(ctx, s)
|
||||
|
||||
evt := log.Ctx(ctx).Info().Str("service", "authorize")
|
||||
fields := a.currentOptions.Load().GetAuthorizeLogFields()
|
||||
fields := a.currentConfig.Load().Options.GetAuthorizeLogFields()
|
||||
for _, field := range fields {
|
||||
evt = populateLogEvent(ctx, field, evt, in, s, u, hdrs, impersonateDetails, res)
|
||||
}
|
||||
|
|
|
@ -67,7 +67,7 @@ func (bearerTokenFormat *BearerTokenFormat) ToPB() *configpb.BearerTokenFormat {
|
|||
case BearerTokenFormatIDPIdentityToken:
|
||||
return configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_IDP_IDENTITY_TOKEN.Enum()
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown bearer token format: %s", bearerTokenFormat))
|
||||
panic(fmt.Sprintf("unknown bearer token format: %v", bearerTokenFormat))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -4,9 +4,10 @@ import (
|
|||
"errors"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/pomerium/pomerium/config/otelconfig"
|
||||
"github.com/spf13/viper"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/pomerium/pomerium/config/otelconfig"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -37,6 +38,7 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
|
|||
DecodePolicyBase64Hook(),
|
||||
decodeNullBoolHookFunc(),
|
||||
decodeJWTClaimHeadersHookFunc(),
|
||||
decodeBearerTokenFormatHookFunc(),
|
||||
decodeCodecTypeHookFunc(),
|
||||
decodePPLPolicyHookFunc(),
|
||||
decodeSANMatcherHookFunc(),
|
||||
|
|
|
@ -5,15 +5,28 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
"github.com/pomerium/pomerium/internal/sessions/header"
|
||||
"github.com/pomerium/pomerium/internal/sessions/queryparam"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/authenticateapi"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
"github.com/pomerium/pomerium/pkg/grpcutil"
|
||||
"github.com/pomerium/pomerium/pkg/identity"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
// A SessionStore saves and loads sessions based on the options.
|
||||
|
@ -116,82 +129,253 @@ func (store *SessionStore) SaveSession(w http.ResponseWriter, r *http.Request, v
|
|||
return store.store.SaveSession(w, r, v)
|
||||
}
|
||||
|
||||
// An IDPTokenSessionHandler handles incoming idp access and identity tokens.
|
||||
type IDPTokenSessionHandler struct {
|
||||
options *Options
|
||||
getSession func(ctx context.Context, id string) (*session.Session, error)
|
||||
putSession func(ctx context.Context, s *session.Session) error
|
||||
var (
|
||||
accessTokenUUIDNamespace = uuid.MustParse("0194f6f8-e760-76a0-8917-e28ac927a34d")
|
||||
identityTokenUUIDNamespace = uuid.MustParse("0194f6f9-aec0-704e-bb4a-51054f17ad17")
|
||||
)
|
||||
|
||||
type IncomingIDPTokenSessionCreator interface {
|
||||
CreateSession(ctx context.Context, cfg *Config, policy *Policy, r *http.Request) (*session.Session, error)
|
||||
}
|
||||
|
||||
// NewIDPTokenSessionHandler creates a new IDPTokenSessionHandler.
|
||||
func NewIDPTokenSessionHandler(
|
||||
options *Options,
|
||||
getSession func(ctx context.Context, id string) (*session.Session, error),
|
||||
putSession func(ctx context.Context, s *session.Session) error,
|
||||
) *IDPTokenSessionHandler {
|
||||
return &IDPTokenSessionHandler{
|
||||
options: options,
|
||||
getSession: getSession,
|
||||
putSession: putSession,
|
||||
type incomingIDPTokenSessionCreator struct {
|
||||
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error)
|
||||
putRecords func(ctx context.Context, records []*databroker.Record) error
|
||||
}
|
||||
|
||||
func NewIncomingIDPTokenSessionCreator(
|
||||
getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error),
|
||||
putRecords func(ctx context.Context, records []*databroker.Record) error,
|
||||
) IncomingIDPTokenSessionCreator {
|
||||
return &incomingIDPTokenSessionCreator{getRecord: getRecord, putRecords: putRecords}
|
||||
}
|
||||
|
||||
// CreateSession attempts to create a session for incoming idp access and
|
||||
// identity tokens. If no access or identity token is passed ErrNoSessionFound will be returned.
|
||||
// If the tokens are not valid an error will be returned.
|
||||
func (c *incomingIDPTokenSessionCreator) CreateSession(
|
||||
ctx context.Context,
|
||||
cfg *Config,
|
||||
policy *Policy,
|
||||
r *http.Request,
|
||||
) (session *session.Session, err error) {
|
||||
if rawAccessToken, ok := cfg.GetIncomingIDPAccessTokenForPolicy(policy, r); ok {
|
||||
return c.createSessionAccessToken(ctx, cfg, policy, rawAccessToken)
|
||||
}
|
||||
|
||||
if rawIdentityToken, ok := cfg.GetIncomingIDPIdentityTokenForPolicy(policy, r); ok {
|
||||
return c.createSessionForIdentityToken(ctx, cfg, policy, rawIdentityToken)
|
||||
}
|
||||
|
||||
return nil, sessions.ErrNoSessionFound
|
||||
}
|
||||
|
||||
// // CreateSessionForIncomingIDPToken creates a session from an incoming idp access or identity token.
|
||||
// // If no such tokens are found or they are invalid ErrNoSessionFound will be returned.
|
||||
// func (h *IDPTokenSessionHandler) CreateSessionForIncomingIDPToken(r *http.Request) (*session.Session, error) {
|
||||
// idp, err := h.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String())
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
func (c *incomingIDPTokenSessionCreator) createSessionAccessToken(
|
||||
ctx context.Context,
|
||||
cfg *Config,
|
||||
policy *Policy,
|
||||
rawAccessToken string,
|
||||
) (*session.Session, error) {
|
||||
sessionID := uuid.NewSHA1(accessTokenUUIDNamespace, []byte(rawAccessToken)).String()
|
||||
s, err := c.getSession(ctx, sessionID)
|
||||
if err == nil {
|
||||
return s, nil
|
||||
} else if !storage.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// return nil, sessions.ErrNoSessionFound
|
||||
// }
|
||||
idp, err := cfg.Options.GetIdentityProviderForPolicy(policy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting identity provider to verify access token: %w", err)
|
||||
}
|
||||
|
||||
// func (h *IDPTokenSessionHandler) getIncomingIDPAccessToken(r *http.Request) (rawAccessToken string, ok bool) {
|
||||
// if h.options.
|
||||
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err)
|
||||
}
|
||||
|
||||
// return "", false
|
||||
// }
|
||||
res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{
|
||||
AccessToken: rawAccessToken,
|
||||
IdentityProviderID: idp.GetId(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error verifying access token: %w", err)
|
||||
} else if !res.Valid {
|
||||
return nil, fmt.Errorf("invalid access token")
|
||||
}
|
||||
|
||||
// func (h *IDPTokenSessionHandler) getIncomingIDPIdentityToken(r *http.Request) (rawIdentityToken string, ok bool) {
|
||||
// return "", false
|
||||
// }
|
||||
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
||||
s.OauthToken = &session.OAuthToken{
|
||||
TokenType: "Bearer",
|
||||
AccessToken: rawAccessToken,
|
||||
ExpiresAt: s.ExpiresAt,
|
||||
}
|
||||
u := c.newUserFromIDPClaims(res.Claims)
|
||||
err = c.putSessionAndUser(ctx, s, u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error saving session and user: %w", err)
|
||||
}
|
||||
|
||||
// func CreateSessionForIncomingIDPToken(
|
||||
// r *http.Request,
|
||||
// options *Options,
|
||||
// policy *Policy,
|
||||
// getSession func(ctx context.Context, id string) (*session.Session, error),
|
||||
// putSession func(ctx context.Context, s *session.Session) error)(*session.Session, error) {
|
||||
// }
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken(
|
||||
ctx context.Context,
|
||||
cfg *Config,
|
||||
policy *Policy,
|
||||
rawIdentityToken string,
|
||||
) (*session.Session, error) {
|
||||
sessionID := uuid.NewSHA1(identityTokenUUIDNamespace, []byte(rawIdentityToken)).String()
|
||||
s, err := c.getSession(ctx, sessionID)
|
||||
if err == nil {
|
||||
return s, nil
|
||||
} else if !storage.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idp, err := cfg.Options.GetIdentityProviderForPolicy(policy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting identity provider to verify identity token: %w", err)
|
||||
}
|
||||
|
||||
authenticateURL, transport, err := cfg.resolveAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err)
|
||||
}
|
||||
|
||||
res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{
|
||||
IdentityToken: rawIdentityToken,
|
||||
IdentityProviderID: idp.GetId(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error verifying identity token: %w", err)
|
||||
} else if !res.Valid {
|
||||
return nil, fmt.Errorf("invalid identity token")
|
||||
}
|
||||
|
||||
s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims)
|
||||
s.SetRawIDToken(rawIdentityToken)
|
||||
u := c.newUserFromIDPClaims(res.Claims)
|
||||
err = c.putSessionAndUser(ctx, s, u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error saving session and user: %w", err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims(
|
||||
cfg *Config,
|
||||
sessionID string,
|
||||
claims jwtutil.Claims,
|
||||
) *session.Session {
|
||||
now := time.Now()
|
||||
s := new(session.Session)
|
||||
s.Id = sessionID
|
||||
if userID, ok := claims.GetUserID(); ok {
|
||||
s.UserId = userID
|
||||
}
|
||||
if issuedAt, ok := claims.GetIssuedAt(); ok {
|
||||
s.IssuedAt = timestamppb.New(issuedAt)
|
||||
} else {
|
||||
s.IssuedAt = timestamppb.New(now)
|
||||
}
|
||||
if expiresAt, ok := claims.GetExpirationTime(); ok {
|
||||
s.ExpiresAt = timestamppb.New(expiresAt)
|
||||
} else {
|
||||
s.ExpiresAt = timestamppb.New(now.Add(cfg.Options.CookieExpire))
|
||||
}
|
||||
s.AccessedAt = timestamppb.New(now)
|
||||
s.AddClaims(identity.Claims(claims).Flatten())
|
||||
if aud, ok := claims.GetAudience(); ok {
|
||||
s.Audience = aud
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) newUserFromIDPClaims(
|
||||
claims jwtutil.Claims,
|
||||
) *user.User {
|
||||
u := new(user.User)
|
||||
if userID, ok := claims.GetUserID(); ok {
|
||||
u.Id = userID
|
||||
}
|
||||
if name, ok := claims.GetString("name"); ok {
|
||||
u.Name = name
|
||||
}
|
||||
if email, ok := claims.GetString("email"); ok {
|
||||
u.Email = email
|
||||
}
|
||||
u.Claims = identity.Claims(claims).Flatten().ToPB()
|
||||
return u
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, sessionID string) (*session.Session, error) {
|
||||
record, err := c.getRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := record.GetData().UnmarshalNew()
|
||||
if err != nil {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
s, ok := msg.(*session.Session)
|
||||
if !ok {
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (c *incomingIDPTokenSessionCreator) putSessionAndUser(ctx context.Context, s *session.Session, u *user.User) error {
|
||||
var records []*databroker.Record
|
||||
if id := s.GetId(); id != "" {
|
||||
records = append(records, &databroker.Record{
|
||||
Type: grpcutil.GetTypeURL(s),
|
||||
Id: id,
|
||||
Data: protoutil.NewAny(s),
|
||||
})
|
||||
}
|
||||
if id := u.GetId(); id != "" {
|
||||
records = append(records, &databroker.Record{
|
||||
Type: grpcutil.GetTypeURL(u),
|
||||
Id: id,
|
||||
Data: protoutil.NewAny(u),
|
||||
})
|
||||
}
|
||||
return c.putRecords(ctx, records)
|
||||
}
|
||||
|
||||
// GetIncomingIDPAccessTokenForPolicy returns the raw idp access token from a request if there is one.
|
||||
func (options *Options) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *http.Request) (rawAccessToken string, ok bool) {
|
||||
func (cfg *Config) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *http.Request) (rawAccessToken string, ok bool) {
|
||||
bearerTokenFormat := BearerTokenFormatDefault
|
||||
if options != nil && options.BearerTokenFormat != nil {
|
||||
bearerTokenFormat = *options.BearerTokenFormat
|
||||
if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil {
|
||||
bearerTokenFormat = *cfg.Options.BearerTokenFormat
|
||||
}
|
||||
if policy != nil && policy.BearerTokenFormat != nil {
|
||||
bearerTokenFormat = *policy.BearerTokenFormat
|
||||
}
|
||||
|
||||
if token := r.Header.Get("X-Pomerium-IDP-Access-Token"); token != "" {
|
||||
if token := r.Header.Get(httputil.HeaderPomeriumIDPAccessToken); token != "" {
|
||||
return token, true
|
||||
}
|
||||
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
prefix := "Pomerium-IDP-Access-Token "
|
||||
if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" {
|
||||
prefix := httputil.AuthorizationTypePomeriumIDPAccessToken + " "
|
||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
||||
return strings.TrimPrefix(auth, prefix), true
|
||||
}
|
||||
|
||||
prefix = "Bearer Pomerium-IDP-Access-Token-"
|
||||
prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPAccessToken + "-"
|
||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
||||
return strings.TrimPrefix(auth, prefix), true
|
||||
}
|
||||
|
||||
prefix = "Bearer "
|
||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) && bearerTokenFormat == BearerTokenFormatIDPAccessToken {
|
||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) &&
|
||||
bearerTokenFormat == BearerTokenFormatIDPAccessToken {
|
||||
return strings.TrimPrefix(auth, prefix), true
|
||||
}
|
||||
}
|
||||
|
@ -200,32 +384,33 @@ func (options *Options) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *ht
|
|||
}
|
||||
|
||||
// GetIncomingIDPAccessTokenForPolicy returns the raw idp identity token from a request if there is one.
|
||||
func (options *Options) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http.Request) (rawIdentityToken string, ok bool) {
|
||||
func (cfg *Config) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http.Request) (rawIdentityToken string, ok bool) {
|
||||
bearerTokenFormat := BearerTokenFormatDefault
|
||||
if options != nil && options.BearerTokenFormat != nil {
|
||||
bearerTokenFormat = *options.BearerTokenFormat
|
||||
if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil {
|
||||
bearerTokenFormat = *cfg.Options.BearerTokenFormat
|
||||
}
|
||||
if policy != nil && policy.BearerTokenFormat != nil {
|
||||
bearerTokenFormat = *policy.BearerTokenFormat
|
||||
}
|
||||
|
||||
if token := r.Header.Get("X-Pomerium-IDP-Identity-Token"); token != "" {
|
||||
if token := r.Header.Get(httputil.HeaderPomeriumIDPIdentityToken); token != "" {
|
||||
return token, true
|
||||
}
|
||||
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
prefix := "Pomerium-IDP-Identity-Token "
|
||||
if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" {
|
||||
prefix := httputil.AuthorizationTypePomeriumIDPIdentityToken + " "
|
||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
||||
return strings.TrimPrefix(auth, prefix), true
|
||||
}
|
||||
|
||||
prefix = "Bearer Pomerium-IDP-Identity-Token-"
|
||||
prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPIdentityToken + "-"
|
||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) {
|
||||
return strings.TrimPrefix(auth, prefix), true
|
||||
}
|
||||
|
||||
prefix = "Bearer "
|
||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) && bearerTokenFormat == BearerTokenFormatIDPIdentityToken {
|
||||
if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) &&
|
||||
bearerTokenFormat == BearerTokenFormatIDPIdentityToken {
|
||||
return strings.TrimPrefix(auth, prefix), true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
package httputil
|
||||
|
||||
// AuthorizationTypePomerium is for Authorization: Pomerium JWT... headers
|
||||
const AuthorizationTypePomerium = "Pomerium"
|
||||
// Pomerium authorization types
|
||||
const (
|
||||
// AuthorizationTypePomerium is for Authorization: Pomerium JWT... headers
|
||||
AuthorizationTypePomerium = "Pomerium"
|
||||
AuthorizationTypePomeriumIDPAccessToken = "Pomerium-IDP-Access-Token" //nolint: gosec
|
||||
AuthorizationTypePomeriumIDPIdentityToken = "Pomerium-IDP-Identity-Token" //nolint: gosec
|
||||
)
|
||||
|
||||
// Standard headers
|
||||
const (
|
||||
|
@ -17,6 +22,8 @@ const (
|
|||
// can be used in place of the standard authorization header if that header is being
|
||||
// used by upstream applications.
|
||||
HeaderPomeriumAuthorization = "x-pomerium-authorization"
|
||||
HeaderPomeriumIDPAccessToken = "x-pomerium-idp-access-token" //nolint: gosec
|
||||
HeaderPomeriumIDPIdentityToken = "x-pomerium-idp-identity-token" //nolint: gosec
|
||||
// HeaderPomeriumResponse is set when pomerium itself creates a response,
|
||||
// as opposed to the upstream application and can be used to distinguish
|
||||
// between an application error, and a pomerium related error when debugging.
|
||||
|
|
160
internal/jwtutil/jwtutil.go
Normal file
160
internal/jwtutil/jwtutil.go
Normal file
|
@ -0,0 +1,160 @@
|
|||
// Package jwtutil contains functions for working with JWTs.
|
||||
package jwtutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Claims represent claims in a JWT.
|
||||
type Claims map[string]any
|
||||
|
||||
// UnmarshalJSON implements a custom unmarshaller for claims data.
|
||||
func (claims *Claims) UnmarshalJSON(raw []byte) error {
|
||||
dst := map[string]any{}
|
||||
dec := json.NewDecoder(bytes.NewReader(raw))
|
||||
dec.UseNumber()
|
||||
err := dec.Decode(&dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*claims = Claims(dst)
|
||||
return nil
|
||||
}
|
||||
|
||||
// registered claims
|
||||
|
||||
// GetIssuer gets the iss claim.
|
||||
func (claims Claims) GetIssuer() (issuer string, ok bool) {
|
||||
return claims.GetString("iss")
|
||||
}
|
||||
|
||||
// GetSubject gets the sub claim.
|
||||
func (claims Claims) GetSubject() (subject string, ok bool) {
|
||||
return claims.GetString("sub")
|
||||
}
|
||||
|
||||
// GetAudience gets the aud claim.
|
||||
func (claims Claims) GetAudience() (audiences []string, ok bool) {
|
||||
return claims.GetStringSlice("aud")
|
||||
}
|
||||
|
||||
// GetExpirationTime gets the exp claim.
|
||||
func (claims Claims) GetExpirationTime() (expirationTime time.Time, ok bool) {
|
||||
return claims.GetNumericDate("exp")
|
||||
}
|
||||
|
||||
// GetNotBefore gets the nbf claim.
|
||||
func (claims Claims) GetNotBefore() (notBefore time.Time, ok bool) {
|
||||
return claims.GetNumericDate("nbf")
|
||||
}
|
||||
|
||||
// GetIssuedAt gets the iat claim.
|
||||
func (claims Claims) GetIssuedAt() (issuedAt time.Time, ok bool) {
|
||||
return claims.GetNumericDate("iat")
|
||||
}
|
||||
|
||||
// GetJWTID gets the jti claim.
|
||||
func (claims Claims) GetJWTID() (jwtID string, ok bool) {
|
||||
return claims.GetString("jti")
|
||||
}
|
||||
|
||||
// custom claims
|
||||
|
||||
// GetUserID returns the oid or sub claim.
|
||||
func (claims Claims) GetUserID() (userID string, ok bool) {
|
||||
if oid, ok := claims.GetString("oid"); ok {
|
||||
return oid, true
|
||||
}
|
||||
|
||||
if sub, ok := claims.GetSubject(); ok {
|
||||
return sub, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// GetNumericDate returns the claim as a numeric date.
|
||||
func (claims Claims) GetNumericDate(name string) (tm time.Time, ok bool) {
|
||||
if claims == nil {
|
||||
return tm, false
|
||||
}
|
||||
|
||||
raw, ok := claims[name]
|
||||
if !ok {
|
||||
return tm, false
|
||||
}
|
||||
|
||||
switch v := raw.(type) {
|
||||
case float64:
|
||||
return time.Unix(int64(v), 0), true
|
||||
case int64:
|
||||
return time.Unix(v, 0), true
|
||||
case json.Number:
|
||||
i, err := v.Int64()
|
||||
if err != nil {
|
||||
if f, err := v.Float64(); err == nil {
|
||||
i = int64(f)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return tm, false
|
||||
}
|
||||
return time.Unix(i, 0), true
|
||||
}
|
||||
|
||||
return tm, false
|
||||
}
|
||||
|
||||
// GetString returns the claim as a string.
|
||||
func (claims Claims) GetString(name string) (value string, ok bool) {
|
||||
if claims == nil {
|
||||
return value, false
|
||||
}
|
||||
|
||||
raw, ok := claims[name]
|
||||
if !ok {
|
||||
return value, false
|
||||
}
|
||||
|
||||
return toString(raw), true
|
||||
}
|
||||
|
||||
// GetStringSlice returns the claim as a slice of strings.
|
||||
func (claims Claims) GetStringSlice(name string) (values []string, ok bool) {
|
||||
if claims == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
raw, ok := claims[name]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return toStringSlice(raw), true
|
||||
}
|
||||
|
||||
func toString(data any) string {
|
||||
switch v := data.(type) {
|
||||
case string:
|
||||
return v
|
||||
}
|
||||
return fmt.Sprint(data)
|
||||
}
|
||||
|
||||
func toStringSlice(obj any) []string {
|
||||
v := reflect.ValueOf(obj)
|
||||
switch v.Kind() {
|
||||
case reflect.Slice:
|
||||
vs := make([]string, v.Len())
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
vs[i] = toString(v.Index(i).Interface())
|
||||
}
|
||||
return vs
|
||||
}
|
||||
|
||||
return []string{toString(obj)}
|
||||
}
|
109
pkg/authenticateapi/authenticateapi.go
Normal file
109
pkg/authenticateapi/authenticateapi.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
// Package authenticateapi has the types and methods for the authenticate api.
|
||||
package authenticateapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||
)
|
||||
|
||||
// VerifyAccessTokenRequest is used to verify access tokens.
|
||||
type VerifyAccessTokenRequest struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
IdentityProviderID string `json:"identityProviderId,omitempty"`
|
||||
}
|
||||
|
||||
// VerifyIdentityTokenRequest is used to verify identity tokens.
|
||||
type VerifyIdentityTokenRequest struct {
|
||||
IdentityToken string `json:"identityToken"`
|
||||
IdentityProviderID string `json:"identityProviderId,omitempty"`
|
||||
}
|
||||
|
||||
// VerifyTokenResponse is the result of verifying an access or identity token.
|
||||
type VerifyTokenResponse struct {
|
||||
Valid bool `json:"valid"`
|
||||
Claims jwtutil.Claims `json:"claims,omitempty"`
|
||||
}
|
||||
|
||||
// An API is an api client for the authenticate service.
|
||||
type API struct {
|
||||
authenticateURL *url.URL
|
||||
transport http.RoundTripper
|
||||
}
|
||||
|
||||
// New creates a new API client.
|
||||
func New(
|
||||
authenticateURL *url.URL,
|
||||
transport http.RoundTripper,
|
||||
) *API {
|
||||
return &API{
|
||||
authenticateURL: authenticateURL,
|
||||
transport: transport,
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (api *API) VerifyAccessToken(ctx context.Context, request *VerifyAccessTokenRequest) (*VerifyTokenResponse, error) {
|
||||
var response VerifyTokenResponse
|
||||
err := api.call(ctx, "verify-access-token", request, &response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// VerifyIdentityToken verifies an identity token.
|
||||
func (api *API) VerifyIdentityToken(ctx context.Context, request *VerifyIdentityTokenRequest) (*VerifyTokenResponse, error) {
|
||||
var response VerifyTokenResponse
|
||||
err := api.call(ctx, "verify-identity-token", request, &response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
func (api *API) call(
|
||||
ctx context.Context,
|
||||
endpoint string,
|
||||
request, response any,
|
||||
) error {
|
||||
u := api.authenticateURL.ResolveReference(&url.URL{
|
||||
Path: "/.pomerium/" + endpoint,
|
||||
})
|
||||
|
||||
body, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling %s http request: %w", endpoint, err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating %s http request: %w", endpoint, err)
|
||||
}
|
||||
|
||||
res, err := (&http.Client{
|
||||
Transport: api.transport,
|
||||
}).Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error executing %s http request: %w", endpoint, err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
body, err = io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading %s http response: %w", endpoint, err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading %s http response (body=%s): %w", endpoint, body, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
pom_oidc "github.com/pomerium/pomerium/pkg/identity/oidc"
|
||||
)
|
||||
|
@ -95,7 +96,7 @@ func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string)
|
|||
return nil, fmt.Errorf("error verifying access token: %w", err)
|
||||
}
|
||||
|
||||
claims = map[string]any{}
|
||||
claims = jwtutil.Claims(map[string]any{})
|
||||
err = token.Claims(&claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue