authenticateflow: add stateful flow (#4822)

Add a new Stateful type implementing the stateful authentication flow
from Pomerium v0.20 and earlier.

This consists mainly of logic from authenticate/handlers.go prior to
commits 57217af and 539fd51.

One significant change is to set the default IdP ID when an IdP ID is
not provided in the request URL (e.g. when signing in directly at the
authenticate service domain). Otherwise, if session state is stored with
an empty IdP ID, it won't be valid for any route.
This commit is contained in:
Kenneth Jenkins 2023-12-07 09:54:42 -08:00 committed by GitHub
parent 0e9a07eac9
commit c01d0e045d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 637 additions and 0 deletions

View file

@ -13,6 +13,11 @@ type signatureVerifier struct {
sharedKey []byte
}
// VerifySignature checks that the provided request has a valid signature.
func (v signatureVerifier) VerifySignature(r *http.Request) error {
return middleware.ValidateRequestURL(r, v.sharedKey)
}
// VerifyAuthenticateSignature checks that the provided request has a valid
// signature (for the authenticate service).
func (v signatureVerifier) VerifyAuthenticateSignature(r *http.Request) error {

View file

@ -0,0 +1,361 @@
package authenticateflow
import (
"context"
"crypto/cipher"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"time"
"golang.org/x/oauth2"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/identity/manager"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
)
// Stateful implements the stateful authentication flow. In this flow, the
// authenticate service has direct access to the databroker.
type Stateful struct {
signatureVerifier
// sharedEncoder is the encoder to use to serialize data to be consumed
// by other services
sharedEncoder encoding.MarshalUnmarshaler
// sharedKey is the secret to encrypt and authenticate data shared between services
sharedKey []byte
// sharedCipher is the cipher to use to encrypt/decrypt data shared between services
sharedCipher cipher.AEAD
// sessionDuration is the maximum Pomerium session duration
sessionDuration time.Duration
// sessionStore is the session store used to persist a user's session
sessionStore sessions.SessionStore
defaultIdentityProviderID string
authenticateURL *url.URL
dataBrokerClient databroker.DataBrokerServiceClient
}
// NewStateful initializes the authentication flow for the given configuration
// and session store.
func NewStateful(cfg *config.Config, sessionStore sessions.SessionStore) (*Stateful, error) {
s := &Stateful{
sessionDuration: cfg.Options.CookieExpire,
sessionStore: sessionStore,
}
var err error
s.authenticateURL, err = cfg.Options.GetAuthenticateURL()
if err != nil {
return nil, err
}
// shared cipher to encrypt data before passing data between services
s.sharedKey, err = cfg.Options.GetSharedKey()
if err != nil {
return nil, err
}
s.sharedCipher, err = cryptutil.NewAEADCipher(s.sharedKey)
if err != nil {
return nil, err
}
// shared state encoder setup
s.sharedEncoder, err = jws.NewHS256Signer(s.sharedKey)
if err != nil {
return nil, err
}
s.signatureVerifier = signatureVerifier{cfg.Options, s.sharedKey}
idp, err := cfg.Options.GetIdentityProviderForPolicy(nil)
if err == nil {
s.defaultIdentityProviderID = idp.GetId()
}
dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(),
&grpc.OutboundOptions{
OutboundPort: cfg.OutboundPort,
InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services,
SignedJWTKey: s.sharedKey,
})
if err != nil {
return nil, err
}
s.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn)
return s, nil
}
// SignIn redirects to a route callback URL, if the provided request and
// session state are valid.
func (s *Stateful) SignIn(
w http.ResponseWriter,
r *http.Request,
sessionState *sessions.State,
) error {
if err := s.VerifyAuthenticateSignature(r); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
// start over if this is a different identity provider
if sessionState == nil || sessionState.IdentityProviderID != idpID {
sessionState = sessions.NewState(idpID)
}
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
jwtAudience := []string{s.authenticateURL.Host, redirectURL.Host}
// if the callback is explicitly set, set it and add an additional audience
if callbackStr := r.FormValue(urlutil.QueryCallbackURI); callbackStr != "" {
callbackURL, err := urlutil.ParseAndValidateURL(callbackStr)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
jwtAudience = append(jwtAudience, callbackURL.Host)
}
newSession := sessionState.WithNewIssuer(s.authenticateURL.Host, jwtAudience)
// re-persist the session, useful when session was evicted from session store
if err := s.sessionStore.SaveSession(w, r, sessionState); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
// sign the route session, as a JWT
signedJWT, err := s.sharedEncoder.Marshal(newSession)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
// encrypt our route-scoped JWT to avoid accidental logging of queryparams
encryptedJWT := cryptutil.Encrypt(s.sharedCipher, signedJWT, nil)
// base64 our encrypted payload for URL-friendlyness
encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT)
callbackURL, err := urlutil.GetCallbackURL(r, encodedJWT)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
// build our hmac-d redirect URL with our session, pointing back to the
// proxy's callback URL which is responsible for setting our new route-session
uri := urlutil.NewSignedURL(s.sharedKey, callbackURL)
httputil.Redirect(w, r, uri.String(), http.StatusFound)
return nil
}
// PersistSession stores session and user data in the databroker.
func (s *Stateful) PersistSession(
ctx context.Context,
_ http.ResponseWriter,
sessionState *sessions.State,
claims identity.SessionClaims,
accessToken *oauth2.Token,
) error {
sessionExpiry := timestamppb.New(time.Now().Add(s.sessionDuration))
idTokenIssuedAt := timestamppb.New(sessionState.IssuedAt.Time())
sess := &session.Session{
Id: sessionState.ID,
UserId: sessionState.UserID(),
IssuedAt: timestamppb.Now(),
AccessedAt: timestamppb.Now(),
ExpiresAt: sessionExpiry,
IdToken: &session.IDToken{
Issuer: sessionState.Issuer, // todo(bdd): the issuer is not authN but the downstream IdP from the claims
Subject: sessionState.Subject,
ExpiresAt: sessionExpiry,
IssuedAt: idTokenIssuedAt,
},
OauthToken: manager.ToOAuthToken(accessToken),
Audience: sessionState.Audience,
}
sess.SetRawIDToken(claims.RawIDToken)
sess.AddClaims(claims.Flatten())
var managerUser manager.User
managerUser.User, _ = user.Get(ctx, s.dataBrokerClient, sess.GetUserId())
if managerUser.User == nil {
// if no user exists yet, create a new one
managerUser.User = &user.User{
Id: sess.GetUserId(),
}
}
populateUserFromClaims(managerUser.User, claims.Claims)
_, err := databroker.Put(ctx, s.dataBrokerClient, managerUser.User)
if err != nil {
return fmt.Errorf("authenticate: error saving user: %w", err)
}
res, err := session.Put(ctx, s.dataBrokerClient, sess)
if err != nil {
return fmt.Errorf("authenticate: error saving session: %w", err)
}
sessionState.DatabrokerServerVersion = res.GetServerVersion()
sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion()
return nil
}
// GetUserInfoData returns user info data associated with the given request (if
// any).
func (s *Stateful) GetUserInfoData(
r *http.Request, sessionState *sessions.State,
) handlers.UserInfoData {
var isImpersonated bool
pbSession, err := session.Get(r.Context(), s.dataBrokerClient, sessionState.ID)
if sid := pbSession.GetImpersonateSessionId(); sid != "" {
pbSession, err = session.Get(r.Context(), s.dataBrokerClient, sid)
isImpersonated = true
}
if err != nil {
pbSession = &session.Session{
Id: sessionState.ID,
}
}
pbUser, err := user.Get(r.Context(), s.dataBrokerClient, pbSession.GetUserId())
if err != nil {
pbUser = &user.User{
Id: pbSession.GetUserId(),
}
}
return handlers.UserInfoData{
IsImpersonated: isImpersonated,
Session: pbSession,
User: pbUser,
}
}
// RevokeSession revokes the session associated with the provided request,
// returning the ID token from the revoked session.
func (s *Stateful) RevokeSession(
ctx context.Context,
_ *http.Request,
authenticator identity.Authenticator,
sessionState *sessions.State,
) string {
if sessionState == nil {
return ""
}
var rawIDToken string
sess, _ := session.Get(ctx, s.dataBrokerClient, sessionState.ID)
if sess != nil && sess.OauthToken != nil {
rawIDToken = sess.GetIdToken().GetRaw()
if err := authenticator.Revoke(ctx, manager.FromOAuthToken(sess.OauthToken)); err != nil {
log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token")
}
}
if err := session.Delete(ctx, s.dataBrokerClient, sessionState.ID); err != nil {
log.Ctx(ctx).Warn().Err(err).
Msg("authenticate: failed to delete session from session store")
}
return rawIDToken
}
// VerifySession checks that an existing session is still valid.
func (s *Stateful) VerifySession(
ctx context.Context, _ *http.Request, sessionState *sessions.State,
) error {
sess, err := session.Get(ctx, s.dataBrokerClient, sessionState.ID)
if err != nil {
return fmt.Errorf("session not found in databroker: %w", err)
}
return sess.Validate()
}
// LogAuthenticateEvent is a no-op for the stateful authentication flow.
func (s *Stateful) LogAuthenticateEvent(*http.Request) {}
// AuthenticateSignInURL returns a URL to redirect the user to the authenticate
// domain.
func (s *Stateful) AuthenticateSignInURL(
_ context.Context, queryParams url.Values, redirectURL *url.URL, idpID string,
) (string, error) {
signinURL := s.authenticateURL.ResolveReference(&url.URL{
Path: "/.pomerium/sign_in",
})
if queryParams == nil {
queryParams = url.Values{}
}
queryParams.Set(urlutil.QueryRedirectURI, redirectURL.String())
queryParams.Set(urlutil.QueryIdentityProviderID, idpID)
signinURL.RawQuery = queryParams.Encode()
redirectTo := urlutil.NewSignedURL(s.sharedKey, signinURL).String()
return redirectTo, nil
}
// GetIdentityProviderIDForURLValues returns the identity provider ID
// associated with the given URL values.
func (s *Stateful) GetIdentityProviderIDForURLValues(vs url.Values) string {
if id := vs.Get(urlutil.QueryIdentityProviderID); id != "" {
return id
}
return s.defaultIdentityProviderID
}
// Callback handles a redirect to a route domain once signed in.
func (s *Stateful) Callback(w http.ResponseWriter, r *http.Request) error {
if err := s.VerifySignature(r); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
redirectURLString := r.FormValue(urlutil.QueryRedirectURI)
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
encryptedJWT, err := base64.URLEncoding.DecodeString(encryptedSession)
if err != nil {
return fmt.Errorf("proxy: malfromed callback token: %w", err)
}
rawJWT, err := cryptutil.Decrypt(s.sharedCipher, encryptedJWT, nil)
if err != nil {
return fmt.Errorf("proxy: callback token decrypt error: %w", err)
}
// save the session state
if err = s.sessionStore.SaveSession(w, r, rawJWT); err != nil {
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving session state: %w", err))
}
// if programmatic, encode the session jwt as a query param
if isProgrammatic := r.FormValue(urlutil.QueryIsProgrammatic); isProgrammatic == "true" {
q := redirectURL.Query()
q.Set(urlutil.QueryPomeriumJWT, string(rawJWT))
redirectURL.RawQuery = q.Encode()
}
// redirect
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
return nil
}

View file

@ -0,0 +1,271 @@
package authenticateflow
import (
"encoding/base64"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/sessions"
mstore "github.com/pomerium/pomerium/internal/sessions/mock"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
func TestStatefulSignIn(t *testing.T) {
opts := config.NewDefaultOptions()
tests := []struct {
name string
host string
qp map[string]string
validSignature bool
session *sessions.State
encoder encoding.MarshalUnmarshaler
saveError error
wantErrorMsg string
wantRedirectBaseURL string
}{
{"good", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
{"good alternate port", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
{"invalid signature", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, false, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
{"bad redirect uri query", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, true, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
{"bad marshal", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{MarshalError: errors.New("error")}, nil, "Bad Request: error", ""},
{"good with different programmatic redirect", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://some.example"},
{"encrypted encoder error", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, true, &sessions.State{}, &mock.Encoder{MarshalError: errors.New("error")}, nil, "Bad Request: error", ""},
{"good with callback uri set", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://some.example/"},
{"bad callback uri set", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""},
{"good programmatic request", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
sessionStore := &mstore.Store{SaveError: tt.saveError}
flow, err := NewStateful(&config.Config{Options: opts}, sessionStore)
if err != nil {
t.Fatal(err)
}
flow.sharedEncoder = tt.encoder
uri := &url.URL{Scheme: "https", Host: tt.host}
queryString := uri.Query()
for k, v := range tt.qp {
queryString.Set(k, v)
}
uri.RawQuery = queryString.Encode()
if tt.validSignature {
sharedKey, _ := opts.GetSharedKey()
uri = urlutil.NewSignedURL(sharedKey, uri).Sign()
}
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
err = flow.SignIn(w, r, tt.session)
result := w.Result()
if tt.wantErrorMsg == "" {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
expectedStatus := "302 Found"
if result.Status != expectedStatus {
t.Errorf("wrong status code: got %v, want %v", result.Status, expectedStatus)
}
loc, err := url.Parse(result.Header.Get("Location"))
if err != nil {
t.Fatalf("couldn't parse redirect URL: %v", err)
}
loc.RawQuery = "" // ignore the query parameters
if loc.String() != tt.wantRedirectBaseURL {
t.Errorf("wrong redirect base URL: got %q, want %q",
loc.String(), tt.wantRedirectBaseURL)
}
} else {
if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) {
t.Errorf("expected error containing %q; got %v", tt.wantErrorMsg, err)
}
}
})
}
}
func TestStatefulAuthenticateSignInURL(t *testing.T) {
opts := config.NewDefaultOptions()
opts.AuthenticateURLString = "https://authenticate.example.com"
key := cryptutil.NewKey()
opts.SharedKey = base64.StdEncoding.EncodeToString(key)
flow, err := NewStateful(&config.Config{Options: opts}, nil)
require.NoError(t, err)
t.Run("NilQueryParams", func(t *testing.T) {
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
u, err := flow.AuthenticateSignInURL(nil, nil, redirectURL, "fake-idp-id")
assert.NoError(t, err)
parsed, _ := url.Parse(u)
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
assert.Equal(t, "https", parsed.Scheme)
assert.Equal(t, "authenticate.example.com", parsed.Host)
assert.Equal(t, "/.pomerium/sign_in", parsed.Path)
q := parsed.Query()
assert.Equal(t, "https://example.com", parsed.Query().Get("pomerium_redirect_uri"))
assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id"))
})
t.Run("ExtraQueryParams", func(t *testing.T) {
redirectURL := &url.URL{Scheme: "https", Host: "example.com"}
q := url.Values{}
q.Set("foo", "bar")
u, err := flow.AuthenticateSignInURL(nil, q, redirectURL, "fake-idp-id")
assert.NoError(t, err)
parsed, _ := url.Parse(u)
assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate())
assert.Equal(t, "https", parsed.Scheme)
assert.Equal(t, "authenticate.example.com", parsed.Host)
assert.Equal(t, "/.pomerium/sign_in", parsed.Path)
q = parsed.Query()
assert.Equal(t, "https://example.com", q.Get("pomerium_redirect_uri"))
assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id"))
assert.Equal(t, "bar", q.Get("foo"))
})
}
func TestStatefulGetIdentityProviderIDForURLValues(t *testing.T) {
flow := Stateful{defaultIdentityProviderID: "default-id"}
assert.Equal(t, "default-id", flow.GetIdentityProviderIDForURLValues(nil))
q := url.Values{"pomerium_idp_id": []string{"idp-id"}}
assert.Equal(t, "idp-id", flow.GetIdentityProviderIDForURLValues(q))
}
const goodEncryptionString = "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="
func TestStatefulCallback(t *testing.T) {
opts := config.NewDefaultOptions()
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
tests := []struct {
name string
qp map[string]string
validSignature bool
cipher encoding.MarshalUnmarshaler
sessionStore sessions.SessionStore
wantErrorMsg string
}{
{
"good",
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
true,
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
"",
},
{
"good programmatic",
map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
true,
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
"",
},
{
"invalid signature",
map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString},
false,
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
"Bad Request:",
},
{
"bad decrypt",
map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="},
true,
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
"proxy: callback token decrypt error:",
},
{
"bad save session",
map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString},
true,
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{SaveError: errors.New("hi")},
"Internal Server Error: proxy: error saving session state:",
},
{
"bad base64",
map[string]string{urlutil.QuerySessionEncrypted: "^"},
true,
&mock.Encoder{MarshalResponse: []byte("x")},
&mstore.Store{Session: &sessions.State{}},
"proxy: malfromed callback token:",
},
{
"malformed redirect",
nil,
true,
&mock.Encoder{},
&mstore.Store{Session: &sessions.State{}},
"Bad Request:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
flow, err := NewStateful(&config.Config{Options: opts}, tt.sessionStore)
if err != nil {
t.Fatal(err)
}
flow.sharedEncoder = tt.cipher
redirectURI := &url.URL{Scheme: "http", Host: "example.com", Path: "/"}
queryString := redirectURI.Query()
for k, v := range tt.qp {
queryString.Set(k, v)
}
redirectURI.RawQuery = queryString.Encode()
uri := &url.URL{Scheme: "https", Host: "example.com", Path: "/"}
if tt.qp != nil {
qu := uri.Query()
for k, v := range tt.qp {
qu.Set(k, v)
}
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
uri.RawQuery = qu.Encode()
}
if tt.validSignature {
sharedKey, _ := opts.GetSharedKey()
uri = urlutil.NewSignedURL(sharedKey, uri).Sign()
}
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
//fmt.Println(uri.String())
r.Host = r.URL.Host
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
err = flow.Callback(w, r)
if tt.wantErrorMsg == "" {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
} else {
if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) {
t.Errorf("expected error containing %q; got %v", tt.wantErrorMsg, err)
}
}
// XXX: assert redirect URL
})
}
}