diff --git a/internal/authenticateflow/request.go b/internal/authenticateflow/request.go index 09eb8f154..813be9f26 100644 --- a/internal/authenticateflow/request.go +++ b/internal/authenticateflow/request.go @@ -13,6 +13,11 @@ type signatureVerifier struct { sharedKey []byte } +// VerifySignature checks that the provided request has a valid signature. +func (v signatureVerifier) VerifySignature(r *http.Request) error { + return middleware.ValidateRequestURL(r, v.sharedKey) +} + // VerifyAuthenticateSignature checks that the provided request has a valid // signature (for the authenticate service). func (v signatureVerifier) VerifyAuthenticateSignature(r *http.Request) error { diff --git a/internal/authenticateflow/stateful.go b/internal/authenticateflow/stateful.go new file mode 100644 index 000000000..51e0fa512 --- /dev/null +++ b/internal/authenticateflow/stateful.go @@ -0,0 +1,361 @@ +package authenticateflow + +import ( + "context" + "crypto/cipher" + "encoding/base64" + "fmt" + "net/http" + "net/url" + "time" + + "golang.org/x/oauth2" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/encoding" + "github.com/pomerium/pomerium/internal/encoding/jws" + "github.com/pomerium/pomerium/internal/handlers" + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/identity" + "github.com/pomerium/pomerium/internal/identity/manager" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" +) + +// Stateful implements the stateful authentication flow. In this flow, the +// authenticate service has direct access to the databroker. +type Stateful struct { + signatureVerifier + + // sharedEncoder is the encoder to use to serialize data to be consumed + // by other services + sharedEncoder encoding.MarshalUnmarshaler + // sharedKey is the secret to encrypt and authenticate data shared between services + sharedKey []byte + // sharedCipher is the cipher to use to encrypt/decrypt data shared between services + sharedCipher cipher.AEAD + // sessionDuration is the maximum Pomerium session duration + sessionDuration time.Duration + // sessionStore is the session store used to persist a user's session + sessionStore sessions.SessionStore + + defaultIdentityProviderID string + + authenticateURL *url.URL + + dataBrokerClient databroker.DataBrokerServiceClient +} + +// NewStateful initializes the authentication flow for the given configuration +// and session store. +func NewStateful(cfg *config.Config, sessionStore sessions.SessionStore) (*Stateful, error) { + s := &Stateful{ + sessionDuration: cfg.Options.CookieExpire, + sessionStore: sessionStore, + } + + var err error + s.authenticateURL, err = cfg.Options.GetAuthenticateURL() + if err != nil { + return nil, err + } + // shared cipher to encrypt data before passing data between services + s.sharedKey, err = cfg.Options.GetSharedKey() + if err != nil { + return nil, err + } + s.sharedCipher, err = cryptutil.NewAEADCipher(s.sharedKey) + if err != nil { + return nil, err + } + // shared state encoder setup + s.sharedEncoder, err = jws.NewHS256Signer(s.sharedKey) + if err != nil { + return nil, err + } + s.signatureVerifier = signatureVerifier{cfg.Options, s.sharedKey} + + idp, err := cfg.Options.GetIdentityProviderForPolicy(nil) + if err == nil { + s.defaultIdentityProviderID = idp.GetId() + } + + dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), + &grpc.OutboundOptions{ + OutboundPort: cfg.OutboundPort, + InstallationID: cfg.Options.InstallationID, + ServiceName: cfg.Options.Services, + SignedJWTKey: s.sharedKey, + }) + if err != nil { + return nil, err + } + + s.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn) + return s, nil +} + +// SignIn redirects to a route callback URL, if the provided request and +// session state are valid. +func (s *Stateful) SignIn( + w http.ResponseWriter, + r *http.Request, + sessionState *sessions.State, +) error { + if err := s.VerifyAuthenticateSignature(r); err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + + idpID := r.FormValue(urlutil.QueryIdentityProviderID) + + // start over if this is a different identity provider + if sessionState == nil || sessionState.IdentityProviderID != idpID { + sessionState = sessions.NewState(idpID) + } + + redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) + if err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + + jwtAudience := []string{s.authenticateURL.Host, redirectURL.Host} + + // if the callback is explicitly set, set it and add an additional audience + if callbackStr := r.FormValue(urlutil.QueryCallbackURI); callbackStr != "" { + callbackURL, err := urlutil.ParseAndValidateURL(callbackStr) + if err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + jwtAudience = append(jwtAudience, callbackURL.Host) + } + + newSession := sessionState.WithNewIssuer(s.authenticateURL.Host, jwtAudience) + + // re-persist the session, useful when session was evicted from session store + if err := s.sessionStore.SaveSession(w, r, sessionState); err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + + // sign the route session, as a JWT + signedJWT, err := s.sharedEncoder.Marshal(newSession) + if err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + + // encrypt our route-scoped JWT to avoid accidental logging of queryparams + encryptedJWT := cryptutil.Encrypt(s.sharedCipher, signedJWT, nil) + // base64 our encrypted payload for URL-friendlyness + encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT) + + callbackURL, err := urlutil.GetCallbackURL(r, encodedJWT) + if err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + + // build our hmac-d redirect URL with our session, pointing back to the + // proxy's callback URL which is responsible for setting our new route-session + uri := urlutil.NewSignedURL(s.sharedKey, callbackURL) + httputil.Redirect(w, r, uri.String(), http.StatusFound) + return nil +} + +// PersistSession stores session and user data in the databroker. +func (s *Stateful) PersistSession( + ctx context.Context, + _ http.ResponseWriter, + sessionState *sessions.State, + claims identity.SessionClaims, + accessToken *oauth2.Token, +) error { + sessionExpiry := timestamppb.New(time.Now().Add(s.sessionDuration)) + idTokenIssuedAt := timestamppb.New(sessionState.IssuedAt.Time()) + + sess := &session.Session{ + Id: sessionState.ID, + UserId: sessionState.UserID(), + IssuedAt: timestamppb.Now(), + AccessedAt: timestamppb.Now(), + ExpiresAt: sessionExpiry, + IdToken: &session.IDToken{ + Issuer: sessionState.Issuer, // todo(bdd): the issuer is not authN but the downstream IdP from the claims + Subject: sessionState.Subject, + ExpiresAt: sessionExpiry, + IssuedAt: idTokenIssuedAt, + }, + OauthToken: manager.ToOAuthToken(accessToken), + Audience: sessionState.Audience, + } + sess.SetRawIDToken(claims.RawIDToken) + sess.AddClaims(claims.Flatten()) + + var managerUser manager.User + managerUser.User, _ = user.Get(ctx, s.dataBrokerClient, sess.GetUserId()) + if managerUser.User == nil { + // if no user exists yet, create a new one + managerUser.User = &user.User{ + Id: sess.GetUserId(), + } + } + populateUserFromClaims(managerUser.User, claims.Claims) + _, err := databroker.Put(ctx, s.dataBrokerClient, managerUser.User) + if err != nil { + return fmt.Errorf("authenticate: error saving user: %w", err) + } + + res, err := session.Put(ctx, s.dataBrokerClient, sess) + if err != nil { + return fmt.Errorf("authenticate: error saving session: %w", err) + } + sessionState.DatabrokerServerVersion = res.GetServerVersion() + sessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion() + + return nil +} + +// GetUserInfoData returns user info data associated with the given request (if +// any). +func (s *Stateful) GetUserInfoData( + r *http.Request, sessionState *sessions.State, +) handlers.UserInfoData { + var isImpersonated bool + pbSession, err := session.Get(r.Context(), s.dataBrokerClient, sessionState.ID) + if sid := pbSession.GetImpersonateSessionId(); sid != "" { + pbSession, err = session.Get(r.Context(), s.dataBrokerClient, sid) + isImpersonated = true + } + if err != nil { + pbSession = &session.Session{ + Id: sessionState.ID, + } + } + + pbUser, err := user.Get(r.Context(), s.dataBrokerClient, pbSession.GetUserId()) + if err != nil { + pbUser = &user.User{ + Id: pbSession.GetUserId(), + } + } + return handlers.UserInfoData{ + IsImpersonated: isImpersonated, + Session: pbSession, + User: pbUser, + } +} + +// RevokeSession revokes the session associated with the provided request, +// returning the ID token from the revoked session. +func (s *Stateful) RevokeSession( + ctx context.Context, + _ *http.Request, + authenticator identity.Authenticator, + sessionState *sessions.State, +) string { + if sessionState == nil { + return "" + } + + var rawIDToken string + sess, _ := session.Get(ctx, s.dataBrokerClient, sessionState.ID) + if sess != nil && sess.OauthToken != nil { + rawIDToken = sess.GetIdToken().GetRaw() + if err := authenticator.Revoke(ctx, manager.FromOAuthToken(sess.OauthToken)); err != nil { + log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token") + } + } + if err := session.Delete(ctx, s.dataBrokerClient, sessionState.ID); err != nil { + log.Ctx(ctx).Warn().Err(err). + Msg("authenticate: failed to delete session from session store") + } + return rawIDToken +} + +// VerifySession checks that an existing session is still valid. +func (s *Stateful) VerifySession( + ctx context.Context, _ *http.Request, sessionState *sessions.State, +) error { + sess, err := session.Get(ctx, s.dataBrokerClient, sessionState.ID) + if err != nil { + return fmt.Errorf("session not found in databroker: %w", err) + } + return sess.Validate() +} + +// LogAuthenticateEvent is a no-op for the stateful authentication flow. +func (s *Stateful) LogAuthenticateEvent(*http.Request) {} + +// AuthenticateSignInURL returns a URL to redirect the user to the authenticate +// domain. +func (s *Stateful) AuthenticateSignInURL( + _ context.Context, queryParams url.Values, redirectURL *url.URL, idpID string, +) (string, error) { + signinURL := s.authenticateURL.ResolveReference(&url.URL{ + Path: "/.pomerium/sign_in", + }) + + if queryParams == nil { + queryParams = url.Values{} + } + queryParams.Set(urlutil.QueryRedirectURI, redirectURL.String()) + queryParams.Set(urlutil.QueryIdentityProviderID, idpID) + signinURL.RawQuery = queryParams.Encode() + redirectTo := urlutil.NewSignedURL(s.sharedKey, signinURL).String() + + return redirectTo, nil +} + +// GetIdentityProviderIDForURLValues returns the identity provider ID +// associated with the given URL values. +func (s *Stateful) GetIdentityProviderIDForURLValues(vs url.Values) string { + if id := vs.Get(urlutil.QueryIdentityProviderID); id != "" { + return id + } + return s.defaultIdentityProviderID +} + +// Callback handles a redirect to a route domain once signed in. +func (s *Stateful) Callback(w http.ResponseWriter, r *http.Request) error { + if err := s.VerifySignature(r); err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + + redirectURLString := r.FormValue(urlutil.QueryRedirectURI) + encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted) + + redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString) + if err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + + encryptedJWT, err := base64.URLEncoding.DecodeString(encryptedSession) + if err != nil { + return fmt.Errorf("proxy: malfromed callback token: %w", err) + } + + rawJWT, err := cryptutil.Decrypt(s.sharedCipher, encryptedJWT, nil) + if err != nil { + return fmt.Errorf("proxy: callback token decrypt error: %w", err) + } + + // save the session state + if err = s.sessionStore.SaveSession(w, r, rawJWT); err != nil { + return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving session state: %w", err)) + } + + // if programmatic, encode the session jwt as a query param + if isProgrammatic := r.FormValue(urlutil.QueryIsProgrammatic); isProgrammatic == "true" { + q := redirectURL.Query() + q.Set(urlutil.QueryPomeriumJWT, string(rawJWT)) + redirectURL.RawQuery = q.Encode() + } + + // redirect + httputil.Redirect(w, r, redirectURL.String(), http.StatusFound) + return nil +} diff --git a/internal/authenticateflow/stateful_test.go b/internal/authenticateflow/stateful_test.go new file mode 100644 index 000000000..a5f0e83ef --- /dev/null +++ b/internal/authenticateflow/stateful_test.go @@ -0,0 +1,271 @@ +package authenticateflow + +import ( + "encoding/base64" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/encoding" + "github.com/pomerium/pomerium/internal/encoding/mock" + "github.com/pomerium/pomerium/internal/sessions" + mstore "github.com/pomerium/pomerium/internal/sessions/mock" + "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/cryptutil" +) + +func TestStatefulSignIn(t *testing.T) { + opts := config.NewDefaultOptions() + tests := []struct { + name string + + host string + qp map[string]string + validSignature bool + + session *sessions.State + encoder encoding.MarshalUnmarshaler + saveError error + + wantErrorMsg string + wantRedirectBaseURL string + }{ + {"good", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"}, + {"good alternate port", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"}, + {"invalid signature", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, false, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""}, + {"bad redirect uri query", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, true, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""}, + {"bad marshal", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{MarshalError: errors.New("error")}, nil, "Bad Request: error", ""}, + {"good with different programmatic redirect", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://some.example"}, + {"encrypted encoder error", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, true, &sessions.State{}, &mock.Encoder{MarshalError: errors.New("error")}, nil, "Bad Request: error", ""}, + {"good with callback uri set", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://some.example/"}, + {"bad callback uri set", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "Bad Request:", ""}, + {"good programmatic request", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, true, &sessions.State{}, &mock.Encoder{}, nil, "", "https://dst.some.example/.pomerium/callback/"}, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + sessionStore := &mstore.Store{SaveError: tt.saveError} + flow, err := NewStateful(&config.Config{Options: opts}, sessionStore) + if err != nil { + t.Fatal(err) + } + flow.sharedEncoder = tt.encoder + + uri := &url.URL{Scheme: "https", Host: tt.host} + queryString := uri.Query() + for k, v := range tt.qp { + queryString.Set(k, v) + } + uri.RawQuery = queryString.Encode() + if tt.validSignature { + sharedKey, _ := opts.GetSharedKey() + uri = urlutil.NewSignedURL(sharedKey, uri).Sign() + } + + r := httptest.NewRequest(http.MethodGet, uri.String(), nil) + r.Header.Set("Accept", "application/json") + + w := httptest.NewRecorder() + err = flow.SignIn(w, r, tt.session) + result := w.Result() + if tt.wantErrorMsg == "" { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + expectedStatus := "302 Found" + if result.Status != expectedStatus { + t.Errorf("wrong status code: got %v, want %v", result.Status, expectedStatus) + } + loc, err := url.Parse(result.Header.Get("Location")) + if err != nil { + t.Fatalf("couldn't parse redirect URL: %v", err) + } + loc.RawQuery = "" // ignore the query parameters + if loc.String() != tt.wantRedirectBaseURL { + t.Errorf("wrong redirect base URL: got %q, want %q", + loc.String(), tt.wantRedirectBaseURL) + } + } else { + if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) { + t.Errorf("expected error containing %q; got %v", tt.wantErrorMsg, err) + } + } + }) + } +} + +func TestStatefulAuthenticateSignInURL(t *testing.T) { + opts := config.NewDefaultOptions() + opts.AuthenticateURLString = "https://authenticate.example.com" + key := cryptutil.NewKey() + opts.SharedKey = base64.StdEncoding.EncodeToString(key) + flow, err := NewStateful(&config.Config{Options: opts}, nil) + require.NoError(t, err) + + t.Run("NilQueryParams", func(t *testing.T) { + redirectURL := &url.URL{Scheme: "https", Host: "example.com"} + u, err := flow.AuthenticateSignInURL(nil, nil, redirectURL, "fake-idp-id") + assert.NoError(t, err) + parsed, _ := url.Parse(u) + assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate()) + assert.Equal(t, "https", parsed.Scheme) + assert.Equal(t, "authenticate.example.com", parsed.Host) + assert.Equal(t, "/.pomerium/sign_in", parsed.Path) + q := parsed.Query() + assert.Equal(t, "https://example.com", parsed.Query().Get("pomerium_redirect_uri")) + assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id")) + }) + t.Run("ExtraQueryParams", func(t *testing.T) { + redirectURL := &url.URL{Scheme: "https", Host: "example.com"} + q := url.Values{} + q.Set("foo", "bar") + u, err := flow.AuthenticateSignInURL(nil, q, redirectURL, "fake-idp-id") + assert.NoError(t, err) + parsed, _ := url.Parse(u) + assert.NoError(t, urlutil.NewSignedURL(key, parsed).Validate()) + assert.Equal(t, "https", parsed.Scheme) + assert.Equal(t, "authenticate.example.com", parsed.Host) + assert.Equal(t, "/.pomerium/sign_in", parsed.Path) + q = parsed.Query() + assert.Equal(t, "https://example.com", q.Get("pomerium_redirect_uri")) + assert.Equal(t, "fake-idp-id", q.Get("pomerium_idp_id")) + assert.Equal(t, "bar", q.Get("foo")) + }) +} + +func TestStatefulGetIdentityProviderIDForURLValues(t *testing.T) { + flow := Stateful{defaultIdentityProviderID: "default-id"} + assert.Equal(t, "default-id", flow.GetIdentityProviderIDForURLValues(nil)) + q := url.Values{"pomerium_idp_id": []string{"idp-id"}} + assert.Equal(t, "idp-id", flow.GetIdentityProviderIDForURLValues(q)) +} + +const goodEncryptionString = "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg=" + +func TestStatefulCallback(t *testing.T) { + opts := config.NewDefaultOptions() + opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=" + tests := []struct { + name string + + qp map[string]string + validSignature bool + cipher encoding.MarshalUnmarshaler + sessionStore sessions.SessionStore + + wantErrorMsg string + }{ + { + "good", + map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, + true, + &mock.Encoder{MarshalResponse: []byte("x")}, + &mstore.Store{Session: &sessions.State{}}, + "", + }, + { + "good programmatic", + map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, + true, + &mock.Encoder{MarshalResponse: []byte("x")}, + &mstore.Store{Session: &sessions.State{}}, + "", + }, + { + "invalid signature", + map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, + false, + &mock.Encoder{MarshalResponse: []byte("x")}, + &mstore.Store{Session: &sessions.State{}}, + "Bad Request:", + }, + { + "bad decrypt", + map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, + true, + &mock.Encoder{MarshalResponse: []byte("x")}, + &mstore.Store{Session: &sessions.State{}}, + "proxy: callback token decrypt error:", + }, + { + "bad save session", + map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, + true, + &mock.Encoder{MarshalResponse: []byte("x")}, + &mstore.Store{SaveError: errors.New("hi")}, + "Internal Server Error: proxy: error saving session state:", + }, + { + "bad base64", + map[string]string{urlutil.QuerySessionEncrypted: "^"}, + true, + &mock.Encoder{MarshalResponse: []byte("x")}, + &mstore.Store{Session: &sessions.State{}}, + "proxy: malfromed callback token:", + }, + { + "malformed redirect", + nil, + true, + &mock.Encoder{}, + &mstore.Store{Session: &sessions.State{}}, + "Bad Request:", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + flow, err := NewStateful(&config.Config{Options: opts}, tt.sessionStore) + if err != nil { + t.Fatal(err) + } + flow.sharedEncoder = tt.cipher + redirectURI := &url.URL{Scheme: "http", Host: "example.com", Path: "/"} + queryString := redirectURI.Query() + for k, v := range tt.qp { + queryString.Set(k, v) + } + redirectURI.RawQuery = queryString.Encode() + + uri := &url.URL{Scheme: "https", Host: "example.com", Path: "/"} + if tt.qp != nil { + qu := uri.Query() + for k, v := range tt.qp { + qu.Set(k, v) + } + qu.Set(urlutil.QueryRedirectURI, redirectURI.String()) + uri.RawQuery = qu.Encode() + } + if tt.validSignature { + sharedKey, _ := opts.GetSharedKey() + uri = urlutil.NewSignedURL(sharedKey, uri).Sign() + } + + r := httptest.NewRequest(http.MethodGet, uri.String(), nil) + //fmt.Println(uri.String()) + r.Host = r.URL.Host + + r.Header.Set("Accept", "application/json") + + w := httptest.NewRecorder() + err = flow.Callback(w, r) + if tt.wantErrorMsg == "" { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } else { + if err == nil || !strings.Contains(err.Error(), tt.wantErrorMsg) { + t.Errorf("expected error containing %q; got %v", tt.wantErrorMsg, err) + } + } + + // XXX: assert redirect URL + }) + } +}