diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 9c536a6cc..30d649c50 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -100,7 +100,12 @@ func (a *Authenticate) mountDashboard(r *mux.Router) { sr.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut)) sr.Path("/webauthn").Handler(a.webauthn) sr.Path("/device-enrolled").Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - handlers.DeviceEnrolled().ServeHTTP(w, r) + userInfoData, err := a.getUserInfoData(r) + if err != nil { + return err + } + + handlers.DeviceEnrolled(userInfoData).ServeHTTP(w, r) return nil })) @@ -505,6 +510,7 @@ func (a *Authenticate) getSessionFromCtx(ctx context.Context) (*sessions.State, func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error { ctx, span := trace.StartSpan(r.Context(), "authenticate.userInfo") defer span.End() + r = r.WithContext(ctx) // if we came in with a redirect URI, save it to a cookie so it doesn't expire with the HMAC if redirectURI := r.FormValue(urlutil.QueryRedirectURI); redirectURI != "" { @@ -519,32 +525,42 @@ func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error { return nil } - state := a.state.Load() - - authenticateURL, err := a.options.Load().GetAuthenticateURL() + userInfoData, err := a.getUserInfoData(r) if err != nil { return err } - s, err := a.getSessionFromCtx(ctx) + handlers.UserInfo(userInfoData).ServeHTTP(w, r) + return nil +} + +func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData, error) { + state := a.state.Load() + + authenticateURL, err := a.options.Load().GetAuthenticateURL() + if err != nil { + return handlers.UserInfoData{}, err + } + + s, err := a.getSessionFromCtx(r.Context()) if err != nil { s.ID = uuid.New().String() } - pbSession, isImpersonated, err := a.getCurrentSession(ctx) + pbSession, isImpersonated, err := a.getCurrentSession(r.Context()) if err != nil { pbSession = &session.Session{ Id: s.ID, } } - pbUser, err := a.getUser(ctx, pbSession.GetUserId()) + pbUser, err := a.getUser(r.Context(), pbSession.GetUserId()) if err != nil { pbUser = &user.User{ Id: pbSession.GetUserId(), } } - pbDirectoryUser, err := a.getDirectoryUser(ctx, pbSession.GetUserId()) + pbDirectoryUser, err := a.getDirectoryUser(r.Context(), pbSession.GetUserId()) if err != nil { pbDirectoryUser = &directory.User{ Id: pbSession.GetUserId(), @@ -552,7 +568,7 @@ func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error { } var groups []*directory.Group for _, groupID := range pbDirectoryUser.GetGroupIds() { - pbDirectoryGroup, err := directory.GetGroup(ctx, state.dataBrokerClient, groupID) + pbDirectoryGroup, err := directory.GetGroup(r.Context(), state.dataBrokerClient, groupID) if err != nil { pbDirectoryGroup = &directory.Group{ Id: groupID, @@ -563,9 +579,9 @@ func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error { groups = append(groups, pbDirectoryGroup) } - creationOptions, requestOptions, _ := a.webauthn.GetOptions(ctx) + creationOptions, requestOptions, _ := a.webauthn.GetOptions(r.Context()) - handlers.UserInfo(handlers.UserInfoData{ + return handlers.UserInfoData{ CSRFToken: csrf.Token(r), DirectoryGroups: groups, DirectoryUser: pbDirectoryUser, @@ -577,8 +593,7 @@ func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error { WebAuthnRequestOptions: requestOptions, WebAuthnURL: urlutil.WebAuthnURL(r, authenticateURL, state.sharedKey, r.URL.Query()), PomeriumVersion: version.FullVersion(), - }).ServeHTTP(w, r) - return nil + }, nil } func (a *Authenticate) saveSessionToDataBroker( diff --git a/authenticate/handlers/device-enrolled.go b/authenticate/handlers/device-enrolled.go index 604be4af2..d8740dd49 100644 --- a/authenticate/handlers/device-enrolled.go +++ b/authenticate/handlers/device-enrolled.go @@ -8,8 +8,8 @@ import ( ) // DeviceEnrolled displays an HTML page informing the user that they've successfully enrolled a device. -func DeviceEnrolled() http.Handler { +func DeviceEnrolled(data UserInfoData) http.Handler { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - return ui.ServePage(w, r, "DeviceEnrolled", map[string]interface{}{}) + return ui.ServePage(w, r, "DeviceEnrolled", data.ToJSON()) }) } diff --git a/ui/src/App.tsx b/ui/src/App.tsx index 047c57b8c..3df126663 100644 --- a/ui/src/App.tsx +++ b/ui/src/App.tsx @@ -1,4 +1,9 @@ -import DeviceEnrolledPage from "./components/DeviceEnrolledPage"; +import Box from "@mui/material/Box"; +import CssBaseline from "@mui/material/CssBaseline"; +import { ThemeProvider } from "@mui/material/styles"; +import { get } from "lodash"; +import React, { FC } from "react"; + import ErrorPage from "./components/ErrorPage"; import Footer from "./components/Footer"; import Header from "./components/Header"; @@ -8,12 +13,7 @@ import UserInfoPage from "./components/UserInfoPage"; import WebAuthnRegistrationPage from "./components/WebAuthnRegistrationPage"; import { SubpageContextProvider } from "./context/Subpage"; import { createTheme } from "./theme"; -import {PageData, UserInfoPageData} from "./types"; -import Box from "@mui/material/Box"; -import CssBaseline from "@mui/material/CssBaseline"; -import { ThemeProvider } from "@mui/material/styles"; -import React, { FC } from "react"; -import {get} from "lodash"; +import { PageData, UserInfoPageData } from "./types"; const theme = createTheme(); @@ -21,15 +21,13 @@ const App: FC = () => { const data = (window["POMERIUM_DATA"] || {}) as PageData; let body: React.ReactNode = <>; switch (data?.page) { - case "DeviceEnrolled": - body = ; - break; case "Error": body = ; break; case "SignOutConfirm": body = ; break; + case "DeviceEnrolled": case "UserInfo": body = ; break; @@ -40,7 +38,7 @@ const App: FC = () => { return ( - +
@@ -55,7 +53,7 @@ const App: FC = () => { -