diff --git a/config/session_test.go b/config/session_test.go index 6ff7fc7f0..c4a615bde 100644 --- a/config/session_test.go +++ b/config/session_test.go @@ -1,8 +1,11 @@ package config import ( + "context" "encoding/base64" + "encoding/json" "net/http" + "net/http/httptest" "net/url" "testing" "time" @@ -18,11 +21,14 @@ import ( "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/authenticateapi" "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" identitypb "github.com/pomerium/pomerium/pkg/grpc/identity" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/identity" + "github.com/pomerium/pomerium/pkg/storage" ) func TestSessionStore_LoadSessionState(t *testing.T) { @@ -440,3 +446,89 @@ func Test_newUserFromIDPClaims(t *testing.T) { }) } } + +func TestIncomingIDPTokenSessionCreator_CreateSession(t *testing.T) { + t.Parallel() + + t.Run("access_token", func(t *testing.T) { + t.Parallel() + + mux := http.NewServeMux() + mux.HandleFunc("/.pomerium/verify-access-token", func(w http.ResponseWriter, _ *http.Request) { + json.NewEncoder(w).Encode(&authenticateapi.VerifyTokenResponse{ + Valid: true, + Claims: jwtutil.Claims{"sub": "U1"}, + }) + }) + srv := httptest.NewTLSServer(mux) + + ctx := testutil.GetContext(t, time.Minute) + cfg := &Config{Options: NewDefaultOptions()} + cfg.Options.AuthenticateURLString = srv.URL + cfg.Options.ClientSecret = "CLIENT_SECRET_1" + cfg.Options.ClientID = "CLIENT_ID_1" + route := &Policy{} + route.IDPClientSecret = "CLIENT_SECRET_2" + route.IDPClientID = "CLIENT_ID_2" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.example.com", nil) + require.NoError(t, err) + req.Header.Set(httputil.HeaderPomeriumIDPAccessToken, "ACCESS_TOKEN") + c := NewIncomingIDPTokenSessionCreator( + func(_ context.Context, recordType, _ string) (*databroker.Record, error) { + assert.Equal(t, "type.googleapis.com/session.Session", recordType) + return nil, storage.ErrNotFound + }, + func(_ context.Context, records []*databroker.Record) error { + if assert.Len(t, records, 2, "should put session and user") { + assert.Equal(t, "type.googleapis.com/session.Session", records[0].Type) + assert.Equal(t, "type.googleapis.com/user.User", records[1].Type) + } + return nil + }, + ) + s, err := c.CreateSession(ctx, cfg, route, req) + assert.NoError(t, err) + assert.Equal(t, "U1", s.GetUserId()) + assert.Equal(t, "ACCESS_TOKEN", s.GetOauthToken().GetAccessToken()) + }) + t.Run("identity_token", func(t *testing.T) { + t.Parallel() + + mux := http.NewServeMux() + mux.HandleFunc("/.pomerium/verify-identity-token", func(w http.ResponseWriter, _ *http.Request) { + json.NewEncoder(w).Encode(&authenticateapi.VerifyTokenResponse{ + Valid: true, + Claims: jwtutil.Claims{"sub": "U1"}, + }) + }) + srv := httptest.NewTLSServer(mux) + + ctx := testutil.GetContext(t, time.Minute) + cfg := &Config{Options: NewDefaultOptions()} + cfg.Options.AuthenticateURLString = srv.URL + cfg.Options.ClientSecret = "CLIENT_SECRET_1" + cfg.Options.ClientID = "CLIENT_ID_1" + route := &Policy{} + route.IDPClientSecret = "CLIENT_SECRET_2" + route.IDPClientID = "CLIENT_ID_2" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.example.com", nil) + require.NoError(t, err) + req.Header.Set(httputil.HeaderPomeriumIDPIdentityToken, "IDENTITY_TOKEN") + c := NewIncomingIDPTokenSessionCreator( + func(_ context.Context, recordType, _ string) (*databroker.Record, error) { + assert.Equal(t, "type.googleapis.com/session.Session", recordType) + return nil, storage.ErrNotFound + }, + func(_ context.Context, records []*databroker.Record) error { + if assert.Len(t, records, 2, "should put session and user") { + assert.Equal(t, "type.googleapis.com/session.Session", records[0].Type) + assert.Equal(t, "type.googleapis.com/user.User", records[1].Type) + } + return nil + }, + ) + s, err := c.CreateSession(ctx, cfg, route, req) + assert.NoError(t, err) + assert.Equal(t, "U1", s.GetUserId()) + }) +}