package authenticate

import (
	"context"
	"encoding/base64"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"
	"time"

	"github.com/go-jose/go-jose/v3/jwt"
	"github.com/golang/mock/gomock"
	"github.com/golang/protobuf/ptypes/empty"
	"github.com/google/go-cmp/cmp"
	"github.com/stretchr/testify/assert"
	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/oauth2"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	"github.com/pomerium/pomerium/authenticate/handlers/webauthn"
	"github.com/pomerium/pomerium/config"
	"github.com/pomerium/pomerium/internal/encoding"
	"github.com/pomerium/pomerium/internal/encoding/jws"
	"github.com/pomerium/pomerium/internal/encoding/mock"
	"github.com/pomerium/pomerium/internal/httputil"
	"github.com/pomerium/pomerium/internal/identity"
	"github.com/pomerium/pomerium/internal/identity/oidc"
	"github.com/pomerium/pomerium/internal/sessions"
	mstore "github.com/pomerium/pomerium/internal/sessions/mock"
	"github.com/pomerium/pomerium/internal/urlutil"
	"github.com/pomerium/pomerium/pkg/cryptutil"
	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/grpc/directory"
	"github.com/pomerium/pomerium/pkg/grpc/session"
)

// testAuthenticate returns a minimal Authenticate suitable for handler tests.
// Its state carries a fixed redirect URL and a freshly generated cookie secret,
// and its options carry a freshly generated base64 shared key.
func testAuthenticate() *Authenticate {
	// The literal below is a well-formed URL, so the parse error is ignored.
	callbackURL, _ := url.Parse("https://auth.example.com/oauth/callback")

	opts := config.NewAtomicOptions()
	opts.Store(&config.Options{
		SharedKey: cryptutil.NewBase64Key(),
	})

	return &Authenticate{
		state: newAtomicAuthenticateState(&authenticateState{
			redirectURL:  callbackURL,
			cookieSecret: cryptutil.NewKey(),
		}),
		options: opts,
	}
}

// TestAuthenticate_RobotsTxt calls the RobotsTxt handler directly and checks
// that it answers 200 OK with a body that disallows every user agent.
func TestAuthenticate_RobotsTxt(t *testing.T) {
	auth := testAuthenticate()
	req, err := http.NewRequest(http.MethodGet, "/robots.txt", nil)
	if err != nil {
		t.Fatal(err)
	}
	rr := httptest.NewRecorder()
	handler := http.HandlerFunc(auth.RobotsTxt)
	handler.ServeHTTP(rr, req)
	if status := rr.Code; status != http.StatusOK {
		t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK)
	}
	// A plain constant: there is nothing to format, so fmt.Sprintf is not needed.
	const expected = "User-agent: *\nDisallow: /"
	if got := rr.Body.String(); got != expected {
		t.Errorf("handler returned wrong body: got %v want %v", got, expected)
	}
}

// TestAuthenticate_Handler exercises the router returned by Handler. It first
// requests /robots.txt and checks the body, then sends a CORS preflight
// (OPTIONS) request to /.pomerium/sign_in and checks that the response has a
// 2xx status and a non-empty Vary header.
func TestAuthenticate_Handler(t *testing.T) {
	auth := testAuthenticate()

	h := auth.Handler()
	if h == nil {
		// Fatal, not Error: every following step would dereference h.
		t.Fatal("handler cannot be nil")
	}
	req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
	req.Header.Set("Accept", "application/json")

	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, req)
	const expected = "User-agent: *\nDisallow: /"

	if body := rr.Body.String(); body != expected {
		t.Errorf("handler returned unexpected body: got %v want %v", body, expected)
	}

	// cors preflight
	req = httptest.NewRequest(http.MethodOptions, "/.pomerium/sign_in", nil)
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Access-Control-Request-Method", "GET")
	req.Header.Set("Access-Control-Request-Headers", "X-Requested-With")
	rr = httptest.NewRecorder()
	h.ServeHTTP(rr, req)
	if code := rr.Code; code/100 != 2 {
		t.Errorf("bad preflight code %v", code)
	}
	resp := rr.Result()
	defer resp.Body.Close()
	if vary := resp.Header.Get("Vary"); vary == "" {
		t.Error("preflight response is missing the Vary header")
	}
}

// TestAuthenticate_SignIn is a table-driven test of the SignIn handler. Each
// case builds a request from a scheme, host and query parameters, loads the
// session from the case's session store into the request context, and checks
// only the HTTP status code the handler produces.
func TestAuthenticate_SignIn(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string

		scheme string
		host   string
		qp     map[string]string

		session  sessions.SessionStore
		provider identity.MockProvider
		encoder  encoding.MarshalUnmarshaler
		wantCode int
	}{
		{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
		{"good alternate port", "https", "corp.example.example:8443", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
		{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
		{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
		{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
		{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
		{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
		{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
		{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
		{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
		{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
		{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &mstore.Store{Session: &sessions.State{}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			// Fail fast if the cipher cannot be built instead of running the
			// handler with a nil cipher.
			sharedCipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
			if err != nil {
				t.Fatal(err)
			}

			a := &Authenticate{
				cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
					return tt.provider, nil
				})),
				state: newAtomicAuthenticateState(&authenticateState{
					sharedCipher:     sharedCipher,
					sessionStore:     tt.session,
					redirectURL:      uriParseHelper("https://some.example"),
					sharedEncoder:    tt.encoder,
					encryptedEncoder: tt.encoder,
					dataBrokerClient: mockDataBrokerServiceClient{
						get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
							return &databroker.GetResponse{
								Record: databroker.NewRecord(&session.Session{
									Id: "SESSION_ID",
								}),
							}, nil
						},
					},
					directoryClient: new(mockDirectoryServiceClient),
				}),

				options: config.NewAtomicOptions(),
			}
			a.options.Store(&config.Options{SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey())})
			uri := &url.URL{Scheme: tt.scheme, Host: tt.host}

			queryString := uri.Query()
			for k, v := range tt.qp {
				queryString.Set(k, v)
			}
			uri.RawQuery = queryString.Encode()
			r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
			r.Header.Set("Accept", "application/json")
			// The load error is deliberately not fatal: it is handed to the
			// handler through the request context (see the "session error" case).
			state, err := tt.session.LoadSession(r)
			ctx := r.Context()
			ctx = sessions.NewContext(ctx, state, err)
			r = r.WithContext(ctx)

			w := httptest.NewRecorder()
			httputil.HandlerFunc(a.SignIn).ServeHTTP(w, r)
			if status := w.Code; status != tt.wantCode {
				t.Errorf("handler returned wrong status code: got %v want %v %s", status, tt.wantCode, uri)
				t.Errorf("\n%+v", w.Body)
			}
		})
	}
}

// uriParseHelper parses s into a *url.URL for use in test fixtures.
//
// It panics if s cannot be parsed. Returning nil instead would only defer the
// failure to a harder-to-diagnose nil dereference at call sites that do
// *uriParseHelper(...).
func uriParseHelper(s string) *url.URL {
	uri, err := url.Parse(s)
	if err != nil {
		// The *url.Error returned by url.Parse already names the offending input.
		panic(err)
	}
	return uri
}

// TestAuthenticate_SignOut is a table-driven test of signOutRedirect. Each case
// issues a request to /sign_out carrying "sig", "ts" and redirect-URI query
// parameters, with a session (and optional error) placed in the request
// context, and then checks the status code and the exact response body. When
// a sign-out redirect URL is configured in the options, the Location header
// must contain its query-escaped form.
func TestAuthenticate_SignOut(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name   string
		method string

		ctxError           error
		redirectURL        string
		signoutRedirectURL string
		sig                string
		ts                 string

		provider     identity.Authenticator
		sessionStore sessions.SessionStore
		wantCode     int
		wantBody     string
	}{
		{
			"good post",
			http.MethodPost,
			nil,
			"https://corp.pomerium.io/",
			"",
			"sig",
			"ts",
			identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))},
			&mstore.Store{Encrypted: true, Session: &sessions.State{}},
			http.StatusFound,
			"",
		},
		{
			"signout redirect url",
			http.MethodPost,
			nil,
			"",
			"https://signout-redirect-url.example.com",
			"sig",
			"ts",
			identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))},
			&mstore.Store{Encrypted: true, Session: &sessions.State{}},
			http.StatusFound,
			"",
		},
		{
			"failed revoke",
			http.MethodPost,
			nil,
			"https://corp.pomerium.io/",
			"",
			"sig",
			"ts",
			identity.MockProvider{RevokeError: errors.New("OH NO")},
			&mstore.Store{Encrypted: true, Session: &sessions.State{}},
			http.StatusFound,
			"",
		},
		{
			"load session error",
			http.MethodPost,
			errors.New("error"),
			"https://corp.pomerium.io/",
			"",
			"sig",
			"ts",
			identity.MockProvider{RevokeError: errors.New("OH NO")},
			&mstore.Store{Encrypted: true, Session: &sessions.State{}},
			http.StatusFound,
			"",
		},
		{
			"bad redirect uri",
			http.MethodPost,
			nil,
			"corp.pomerium.io/",
			"",
			"sig",
			"ts",
			identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented},
			&mstore.Store{Encrypted: true, Session: &sessions.State{}},
			http.StatusFound,
			"",
		},
		{
			"no redirect uri",
			http.MethodPost,
			nil,
			"",
			"",
			"sig",
			"ts",
			identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))},
			&mstore.Store{Encrypted: true, Session: &sessions.State{}},
			http.StatusOK,
			"{\"Status\":200,\"Error\":\"OK: user logged out\"}\n",
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			// Data broker stub: every Get yields the same session record and
			// every Put returns a nil response and nil error.
			broker := mockDataBrokerServiceClient{
				get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
					record := databroker.NewRecord(&session.Session{
						Id: "SESSION_ID",
					})
					return &databroker.GetResponse{Record: record}, nil
				},
				put: func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
					return nil, nil
				},
			}
			providerFor := func(options *config.Options, idpID string) (identity.Authenticator, error) {
				return tt.provider, nil
			}

			a := &Authenticate{
				cfg: getAuthenticateConfig(WithGetIdentityProvider(providerFor)),
				state: newAtomicAuthenticateState(&authenticateState{
					sessionStore:     tt.sessionStore,
					encryptedEncoder: mock.Encoder{},
					sharedEncoder:    mock.Encoder{},
					dataBrokerClient: broker,
					directoryClient:  new(mockDirectoryServiceClient),
				}),
				options: config.NewAtomicOptions(),
			}
			if tt.signoutRedirectURL != "" {
				current := a.options.Load()
				current.SignOutRedirectURLString = tt.signoutRedirectURL
				a.options.Store(current)
			}

			// Assemble /sign_out?sig=...&ts=...&<redirect uri>=...
			query := url.Values{}
			query.Add("sig", tt.sig)
			query.Add("ts", tt.ts)
			query.Add(urlutil.QueryRedirectURI, tt.redirectURL)
			target := url.URL{Path: "/sign_out", RawQuery: query.Encode()}

			r := httptest.NewRequest(tt.method, target.String(), nil)
			state, err := tt.sessionStore.LoadSession(r)
			if err != nil {
				t.Fatal(err)
			}
			// The context error comes from the table, not from LoadSession.
			r = r.WithContext(sessions.NewContext(r.Context(), state, tt.ctxError))
			r.Header.Set("Accept", "application/json")

			w := httptest.NewRecorder()
			httputil.HandlerFunc(a.signOutRedirect).ServeHTTP(w, r)

			if status := w.Code; status != tt.wantCode {
				t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
			}
			if diff := cmp.Diff(w.Body.String(), tt.wantBody); diff != "" {
				t.Errorf("handler returned wrong body Body: %s", diff)
			}
			if tt.signoutRedirectURL != "" {
				assert.Contains(t, w.Header().Get("Location"), url.QueryEscape(tt.signoutRedirectURL))
			}
		})
	}
}

// TestAuthenticate_OAuthCallback is a table-driven test of the OAuthCallback
// handler. For each case it builds a "state" query parameter from
// "nonce|timestamp|extraMac" followed by the redirect URI encrypted with the
// cookie cipher (the prefix is passed as additional data), base64url-encodes
// it, optionally appends extraState or replaces the whole value with
// stateOverride, and then checks the status code of the response.
func TestAuthenticate_OAuthCallback(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name   string
		method string

		ts            int64
		stateOverride string
		extraMac      string
		extraState    string
		paramErr      string
		code          string
		redirectURI   string

		authenticateURL string
		session         sessions.SessionStore
		provider        identity.MockProvider

		want     string
		wantCode int
	}{
		{"good", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusFound},
		{"failed authenticate", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}, AuthenticateError: errors.New("error")}, "", http.StatusInternalServerError},
		{"failed save session", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{SaveError: errors.New("error")}, identity.MockProvider{}, "", http.StatusInternalServerError},
		{"provider returned error", http.MethodGet, time.Now().Unix(), "", "", "", "idp error", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
		{"provider returned error imply 401", http.MethodGet, time.Now().Unix(), "", "", "", "access_denied", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusUnauthorized},
		{"empty code", http.MethodGet, time.Now().Unix(), "", "", "", "", "", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
		{"invalid redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "", http.StatusBadRequest},
		{"bad redirect uri", http.MethodGet, time.Now().Unix(), "", "", "", "", "code", "http://^^^", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
		{"bad timing - too soon", http.MethodGet, time.Now().Add(1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
		{"bad timing - expired", http.MethodGet, time.Now().Add(-1 * time.Hour).Unix(), "", "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
		{"bad base64", http.MethodGet, time.Now().Unix(), "", "", "^", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
		{"too many seperators", http.MethodGet, time.Now().Unix(), "", "", "|ok|now|what", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
		{"bad hmac", http.MethodGet, time.Now().Unix(), "", "NOTMAC", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
		// Named distinctly from the case above so failures identify the right
		// row; this one replaces the whole state value rather than the MAC.
		{"malformed state", http.MethodGet, time.Now().Unix(), base64.URLEncoding.EncodeToString([]byte("malformed_state")), "", "", "", "code", "https://corp.pomerium.io", "https://authenticate.pomerium.io", &mstore.Store{}, identity.MockProvider{AuthenticateResponse: oauth2.Token{}}, "https://corp.pomerium.io", http.StatusBadRequest},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()
			aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
			if err != nil {
				t.Fatal(err)
			}
			signer, err := jws.NewHS256Signer(nil)
			if err != nil {
				t.Fatal(err)
			}
			authURL, err := url.Parse(tt.authenticateURL)
			if err != nil {
				t.Fatal(err)
			}
			a := &Authenticate{
				cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
					return tt.provider, nil
				})),
				state: newAtomicAuthenticateState(&authenticateState{
					dataBrokerClient: mockDataBrokerServiceClient{
						get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
							return nil, fmt.Errorf("not implemented")
						},
						put: func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
							return nil, nil
						},
					},
					directoryClient:  new(mockDirectoryServiceClient),
					redirectURL:      authURL,
					sessionStore:     tt.session,
					cookieCipher:     aead,
					encryptedEncoder: signer,
				}),
				options: config.NewAtomicOptions(),
			}

			params := url.Values{}
			params.Add("error", tt.paramErr)
			params.Add("code", tt.code)
			nonce := cryptutil.NewBase64Key() // mock csrf
			// State layout before encoding: "nonce|timestamp|extraMac" followed
			// by the encrypted redirect URI.
			b := []byte(fmt.Sprintf("%s|%d|%s", nonce, tt.ts, tt.extraMac))

			enc := cryptutil.Encrypt(a.state.Load().cookieCipher, []byte(tt.redirectURI), b)
			b = append(b, enc...)
			encodedState := base64.URLEncoding.EncodeToString(b)
			if tt.extraState != "" {
				encodedState += tt.extraState
			}
			if tt.stateOverride != "" {
				encodedState = tt.stateOverride
			}
			params.Add("state", encodedState)

			target := url.URL{Path: "/oauthGet", RawQuery: params.Encode()}

			r := httptest.NewRequest(tt.method, target.String(), nil)
			r.Header.Set("Accept", "application/json")
			w := httptest.NewRecorder()
			httputil.HandlerFunc(a.OAuthCallback).ServeHTTP(w, r)

			res := w.Result()
			defer res.Body.Close()
			if res.StatusCode != tt.wantCode {
				t.Errorf("Authenticate.OAuthCallback() error = %v, wantErr %v\n%v", res.StatusCode, tt.wantCode, w.Body.String())
			}
		})
	}
}

// TestAuthenticate_SessionValidatorMiddleware is a table-driven test of the
// VerifySession middleware. Each case wraps a stub handler that answers
// 200 OK, places a session and an optional error in the request context,
// applies any extra request headers, and checks the resulting status code.
func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
	t.Parallel()
	fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		// The status is written before the body: a WriteHeader call issued
		// after the first Write is superfluous and has no effect.
		w.WriteHeader(http.StatusOK)
		fmt.Fprintln(w, "RVSI FILIVS CAISAR")
	})

	tests := []struct {
		name    string
		headers map[string]string

		session  sessions.SessionStore
		ctxError error
		provider identity.Authenticator

		wantStatus int
	}{
		{
			"good",
			nil,
			&mstore.Store{Session: &sessions.State{ID: "xyz"}},
			nil,
			identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
			http.StatusOK,
		},
		{
			"invalid session",
			nil,
			&mstore.Store{Session: &sessions.State{ID: "xyz"}},
			errors.New("hi"),
			identity.MockProvider{},
			http.StatusFound,
		},
		{
			"good refresh expired",
			nil,
			&mstore.Store{Session: &sessions.State{ID: "xyz"}},
			nil,
			identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
			http.StatusOK,
		},
		{
			"expired,refresh error",
			nil,
			&mstore.Store{Session: &sessions.State{ID: "xyz"}},
			sessions.ErrExpired,
			identity.MockProvider{RefreshError: errors.New("error")},
			http.StatusFound,
		},
		{
			"expired,save error",
			nil,
			&mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{ID: "xyz"}},
			sessions.ErrExpired,
			identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
			http.StatusFound,
		},
		{
			"expired XHR,refresh error",
			map[string]string{"X-Requested-With": "XmlHttpRequest"},
			&mstore.Store{Session: &sessions.State{ID: "xyz"}},
			sessions.ErrExpired,
			identity.MockProvider{RefreshError: errors.New("error")},
			http.StatusUnauthorized,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			aead, err := chacha20poly1305.NewX(cryptutil.NewKey())
			if err != nil {
				t.Fatal(err)
			}
			signer, err := jws.NewHS256Signer(nil)
			if err != nil {
				t.Fatal(err)
			}
			a := &Authenticate{
				cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
					return tt.provider, nil
				})),
				state: newAtomicAuthenticateState(&authenticateState{
					cookieSecret:     cryptutil.NewKey(),
					redirectURL:      uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
					sessionStore:     tt.session,
					cookieCipher:     aead,
					encryptedEncoder: signer,
					sharedEncoder:    signer,
					dataBrokerClient: mockDataBrokerServiceClient{
						get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
							return &databroker.GetResponse{
								Record: databroker.NewRecord(&session.Session{
									Id: "SESSION_ID",
								}),
							}, nil
						},
					},
					directoryClient: new(mockDirectoryServiceClient),
				}),
				options: config.NewAtomicOptions(),
			}
			r := httptest.NewRequest(http.MethodGet, "/", nil)
			state, err := tt.session.LoadSession(r)
			if err != nil {
				t.Fatal(err)
			}
			// The context error comes from the table, not from LoadSession.
			ctx := r.Context()
			ctx = sessions.NewContext(ctx, state, tt.ctxError)
			r = r.WithContext(ctx)

			r.Header.Set("Accept", "application/json")
			// Ranging over a nil map is a no-op, so no length guard is needed.
			for k, v := range tt.headers {
				r.Header.Set(k, v)
			}
			w := httptest.NewRecorder()

			got := a.VerifySession(fn)
			got.ServeHTTP(w, r)
			if status := w.Code; status != tt.wantStatus {
				t.Errorf("VerifySession() error = %v, wantErr %v\n%v\n%v", status, tt.wantStatus, w.Header(), w.Body.String())
			}
		})
	}
}

// TestWellKnownEndpoint requests /.well-known/pomerium/ from the router
// returned by Handler and compares the response body with the expected JSON
// document byte for byte.
func TestWellKnownEndpoint(t *testing.T) {
	auth := testAuthenticate()

	h := auth.Handler()
	if h == nil {
		// Fatal, not Error: the request below would dereference h.
		t.Fatal("handler cannot be nil")
	}
	req := httptest.NewRequest(http.MethodGet, "/.well-known/pomerium/", nil)
	req.Header.Set("Accept", "application/json")
	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, req)
	body := rr.Body.String()
	expected := "{\"authentication_callback_endpoint\":\"https://auth.example.com/oauth2/callback\",\"jwks_uri\":\"https://auth.example.com/.well-known/pomerium/jwks.json\",\"frontchannel_logout_uri\":\"https://auth.example.com/.pomerium/sign_out\"}\n"
	// assert.Equal takes (expected, actual); the order matters for the
	// failure output.
	assert.Equal(t, expected, body)
}

// TestJwksEndpoint builds an Authenticate from test options with a fixed
// signing key, requests /.well-known/pomerium/jwks.json from its router, and
// compares the response body with the expected JWKS document byte for byte.
func TestJwksEndpoint(t *testing.T) {
	o := newTestOptions(t)
	o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
	auth, err := New(&config.Config{Options: o})
	if err != nil {
		// t.Fatal stops the test, so no explicit return is needed.
		t.Fatal(err)
	}
	h := auth.Handler()
	if h == nil {
		// Fatal, not Error: the request below would dereference h.
		t.Fatal("handler cannot be nil")
	}
	req := httptest.NewRequest(http.MethodGet, "/.well-known/pomerium/jwks.json", nil)
	req.Header.Set("Accept", "application/json")
	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, req)
	body := rr.Body.String()
	expected := "{\"keys\":[{\"use\":\"sig\",\"kty\":\"EC\",\"kid\":\"5b419ade1895fec2d2def6cd33b1b9a018df60db231dc5ecb85cbed6d942813c\",\"crv\":\"P-256\",\"alg\":\"ES256\",\"x\":\"UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo\",\"y\":\"KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ\"}]}\n"
	assert.Equal(t, expected, body)
}

// TestAuthenticate_userInfo is a table-driven test of the userInfo handler
// wrapped in requireValidSignatureOnRedirect. Each case sends a request for
// the case's URL with a session in the request context, then checks the
// status code and that the response body contains wantBody.
func TestAuthenticate_userInfo(t *testing.T) {
	t.Parallel()

	now := time.Now()
	tests := []struct {
		name         string
		url          *url.URL
		method       string
		sessionStore sessions.SessionStore
		wantCode     int
		wantBody     string
	}{
		{
			"good",
			mustParseURL("/"),
			http.MethodGet,
			&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
			http.StatusOK,
			"",
		},
		{
			"missing signature",
			mustParseURL("/?pomerium_redirect_uri=http://example.com"),
			http.MethodGet,
			&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
			http.StatusBadRequest,
			"",
		},
		{
			"bad signature",
			urlutil.NewSignedURL([]byte("BAD KEY"), mustParseURL("/?pomerium_redirect_uri=http://example.com")).Sign(),
			http.MethodGet,
			&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
			http.StatusBadRequest,
			"",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			signer, err := jws.NewHS256Signer(nil)
			if err != nil {
				t.Fatal(err)
			}

			atomicOpts := config.NewAtomicOptions()
			atomicOpts.Store(&config.Options{
				AuthenticateURLString: "https://authenticate.localhost.pomerium.io",
				SharedKey:             "SHARED KEY",
			})

			// Data broker stub: every Get yields the same session record,
			// whose ID token is stamped with the time captured above.
			broker := mockDataBrokerServiceClient{
				get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
					record := databroker.NewRecord(&session.Session{
						Id:      "SESSION_ID",
						UserId:  "USER_ID",
						IdToken: &session.IDToken{IssuedAt: timestamppb.New(now)},
					})
					return &databroker.GetResponse{Record: record}, nil
				},
			}

			a := &Authenticate{
				options: atomicOpts,
				state: newAtomicAuthenticateState(&authenticateState{
					sessionStore:     tt.sessionStore,
					encryptedEncoder: signer,
					sharedEncoder:    signer,
					dataBrokerClient: broker,
					directoryClient:  new(mockDirectoryServiceClient),
				}),
			}
			a.webauthn = webauthn.New(a.getWebauthnState)

			r := httptest.NewRequest(tt.method, tt.url.String(), nil)
			state, err := tt.sessionStore.LoadSession(r)
			if err != nil {
				t.Fatal(err)
			}
			r = r.WithContext(sessions.NewContext(r.Context(), state, nil))
			r.Header.Set("Accept", "application/json")

			w := httptest.NewRecorder()
			a.requireValidSignatureOnRedirect(a.userInfo).ServeHTTP(w, r)

			if status := w.Code; status != tt.wantCode {
				t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
			}
			if body := w.Body.String(); !strings.Contains(body, tt.wantBody) {
				t.Errorf("Unexpected body, contains: %s, got: %s", tt.wantBody, body)
			}
		})
	}
}

// mockDataBrokerServiceClient is a test double used wherever the tests need a
// data broker client. It embeds databroker.DataBrokerServiceClient, which the
// tests in this file leave unset, and overrides Get and Put so that they call
// the function fields below.
type mockDataBrokerServiceClient struct {
	databroker.DataBrokerServiceClient

	// get is the implementation invoked by Get.
	get func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error)
	// put is the implementation invoked by Put.
	put func(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error)
}

// Get forwards the call to the configured get function. When no get function
// was supplied it returns a gRPC Unimplemented error instead of panicking on
// a nil function value, mirroring mockDirectoryServiceClient.RefreshUser.
func (m mockDataBrokerServiceClient) Get(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
	if m.get != nil {
		return m.get(ctx, in, opts...)
	}
	return nil, status.Error(codes.Unimplemented, "")
}

// Put forwards the call to the configured put function. When no put function
// was supplied it returns a gRPC Unimplemented error instead of panicking on
// a nil function value, mirroring mockDirectoryServiceClient.RefreshUser.
func (m mockDataBrokerServiceClient) Put(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
	if m.put != nil {
		return m.put(ctx, in, opts...)
	}
	return nil, status.Error(codes.Unimplemented, "")
}

// mockDirectoryServiceClient is a test double used wherever the tests need a
// directory client. It embeds directory.DirectoryServiceClient, which the
// tests in this file leave unset, and overrides RefreshUser.
type mockDirectoryServiceClient struct {
	directory.DirectoryServiceClient

	// refreshUser, when non-nil, is the implementation invoked by RefreshUser.
	refreshUser func(ctx context.Context, in *directory.RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error)
}

// RefreshUser delegates to the refreshUser function when one is configured;
// otherwise it reports a gRPC Unimplemented error with an empty message.
func (m mockDirectoryServiceClient) RefreshUser(ctx context.Context, in *directory.RefreshUserRequest, opts ...grpc.CallOption) (*empty.Empty, error) {
	if m.refreshUser == nil {
		return nil, status.Error(codes.Unimplemented, "")
	}
	return m.refreshUser(ctx, in, opts...)
}

// mustParseURL parses rawurl and returns the resulting URL. It panics with
// the parse error when rawurl is not a valid URL.
func mustParseURL(rawurl string) *url.URL {
	parsed, parseErr := url.Parse(rawurl)
	if parseErr == nil {
		return parsed
	}
	panic(parseErr)
}