Split guest user feature

This commit is contained in:
Luke Vella 2024-11-02 18:51:34 +00:00
parent bd8029774e
commit 401e132f11
No known key found for this signature in database
GPG key ID: 469CAD687F0D784C
42 changed files with 246 additions and 510 deletions

View file

@ -3,7 +3,7 @@ import type { TimeFormat } from "@rallly/database";
import { extend } from "lodash";
import type { DefaultSession, DefaultUser } from "next-auth";
import NextAuth from "next-auth";
import type { DefaultJWT} from "next-auth/jwt";
import type { DefaultJWT } from "next-auth/jwt";
import { JWT } from "next-auth/jwt";
declare module "next-auth" {
@ -11,7 +11,7 @@ declare module "next-auth" {
* Returned by `useSession`, `getSession` and received as a prop on the `SessionProvider` React Context
*/
interface Session {
user: {
user?: {
id: string;
timeZone?: string | null;
timeFormat?: TimeFormat | null;

View file

@ -23,6 +23,7 @@ const nextConfig = {
"@rallly/ui",
"@rallly/tailwind-config",
"@rallly/posthog",
"@rallly/guest-user",
"@rallly/emails",
],
webpack(config) {

View file

@ -34,7 +34,7 @@
"@rallly/posthog": "*",
"@rallly/tailwind-config": "*",
"@rallly/ui": "*",
"@sentry/nextjs": "*",
"@sentry/nextjs": "^8.32.0",
"@svgr/webpack": "^6.5.1",
"@t3-oss/env-nextjs": "^0.11.0",
"@tanstack/react-query": "^4.0.0",

View file

@ -3,7 +3,7 @@
import { Badge } from "@rallly/ui/badge";
import { Trans } from "@/components/trans";
import { useUser } from "@/components/user-provider";
import { useUser } from "@/auth/client/user-provider";
export function ProBadge() {
const { user } = useUser();

View file

@ -12,13 +12,13 @@ import { useTranslation } from "next-i18next";
import { DeleteAccountDialog } from "@/app/[locale]/(admin)/settings/profile/delete-account-dialog";
import { ProfileSettings } from "@/app/[locale]/(admin)/settings/profile/profile-settings";
import { LogoutButton } from "@/app/components/logout-button";
import { useUser } from "@/auth/client/user-provider";
import {
Settings,
SettingsContent,
SettingsSection,
} from "@/components/settings/settings";
import { Trans } from "@/components/trans";
import { useUser } from "@/components/user-provider";
export const ProfilePage = () => {
const { t } = useTranslation();

View file

@ -7,7 +7,7 @@ import { z } from "zod";
import { OptimizedAvatarImage } from "@/components/optimized-avatar-image";
import { Trans } from "@/components/trans";
import { useUser } from "@/components/user-provider";
import { useUser } from "@/auth/client/user-provider";
import { IfCloudHosted } from "@/contexts/environment";
import { useTranslation } from "@/i18n/client";
import { trpc } from "@/trpc/client";

View file

@ -11,7 +11,7 @@ import { useForm } from "react-hook-form";
import { ProfilePicture } from "@/app/[locale]/(admin)/settings/profile/profile-picture";
import { Trans } from "@/components/trans";
import { useUser } from "@/components/user-provider";
import { useUser } from "@/auth/client/user-provider";
import { trpc } from "@/trpc/client";
export const ProfileSettings = () => {

View file

@ -24,7 +24,7 @@ import { OptimizedAvatarImage } from "@/components/optimized-avatar-image";
import { PayWallDialog } from "@/components/pay-wall-dialog";
import { ProBadge } from "@/components/pro-badge";
import { Trans } from "@/components/trans";
import { IfGuest, useUser } from "@/components/user-provider";
import { IfGuest, useUser } from "@/auth/client/user-provider";
import { IfFreeUser } from "@/contexts/plan";
import type { IconComponent } from "@/types";

View file

@ -11,7 +11,7 @@ import { ScheduledEvent } from "@/components/poll/scheduled-event";
import { useTouchBeacon } from "@/components/poll/use-touch-beacon";
import { VotingForm } from "@/components/poll/voting-form";
import { Trans } from "@/components/trans";
import { useUser } from "@/components/user-provider";
import { useUser } from "@/auth/client/user-provider";
import { usePoll } from "@/contexts/poll";
const GoToApp = () => {

View file

@ -6,7 +6,7 @@ import Link from "next/link";
import { PageHeader } from "@/app/components/page-layout";
import { Trans } from "@/components/trans";
import { UserDropdown } from "@/components/user-dropdown";
import { useUser } from "@/components/user-provider";
import { useUser } from "@/auth/client/user-provider";
import { usePoll } from "@/contexts/poll";
export const Nav = () => {

View file

@ -12,6 +12,8 @@ import React from "react";
import { TimeZoneChangeDetector } from "@/app/[locale]/timezone-change-detector";
import { Providers } from "@/app/providers";
import { getServerSession } from "@/auth";
import { getGuestUser } from "@/auth/next";
import type { User } from "@/auth/schema";
import { SessionProvider } from "@/auth/session-provider";
const PostHogPageView = dynamic(() => import("@rallly/posthog/next"), {
@ -38,8 +40,11 @@ export default async function Root({
params: { locale: string };
}) {
let session: Session | null = null;
let guestUser: User | null = null;
try {
session = await getServerSession();
guestUser = await getGuestUser();
} catch (error) {
console.error(error);
}
@ -49,7 +54,7 @@ export default async function Root({
<body>
<Toaster />
<SessionProvider session={session}>
<PostHogProvider distinctId={session?.user?.id}>
<PostHogProvider distinctId={session?.user?.id ?? guestUser?.id}>
<PostHogPageView />
<Providers>
{children}

View file

@ -5,7 +5,7 @@ import { Trans } from "next-i18next";
import { LoginLink } from "@/components/login-link";
import { RegisterLink } from "@/components/register-link";
import { useUser } from "@/components/user-provider";
import { useUser } from "@/auth/client/user-provider";
import { usePoll } from "@/contexts/poll";
export const GuestPollAlert = () => {

View file

@ -1,15 +1,17 @@
import languages from "@rallly/languages";
import { absoluteUrl } from "@rallly/utils/absolute-url";
import { randomid } from "@rallly/utils/nanoid";
import languageParser from "accept-language-parser";
import type { NextRequest, NextResponse } from "next/server";
import type { JWT } from "next-auth/jwt";
import { encode } from "next-auth/jwt";
import { decode, encode } from "next-auth/jwt";
import { randomid } from "@/utils/nanoid";
import { GUEST_USER_COOKIE } from "@/auth/constants";
import { createGuestUser } from "@/auth/lib/create-guest-user";
const supportedLocales = Object.keys(languages);
function getCookieSettings() {
function getNextAuthCookieSettings() {
const secure = absoluteUrl().startsWith("https://");
const prefix = secure ? "__Secure-" : "";
const name = `${prefix}next-auth.session-token`;
@ -30,7 +32,7 @@ export async function getLocaleFromHeader(req: NextRequest) {
}
async function setCookie(res: NextResponse, jwt: JWT) {
const { name, secure } = getCookieSettings();
const { name, secure } = getNextAuthCookieSettings();
const token = await encode({
token: jwt,
@ -47,6 +49,33 @@ async function setCookie(res: NextResponse, jwt: JWT) {
});
}
export async function migrateGuestFromNextAuthCookie(
req: NextRequest,
res: NextResponse,
) {
const { name } = getNextAuthCookieSettings();
if (req.cookies.has(name)) {
// get user session token
const token = req.cookies.get(name)?.value;
if (token) {
const jwt = await decode({
token,
secret: process.env.SECRET_PASSWORD,
});
if (jwt?.sub && jwt?.locale) {
const user = await createGuestUser({
id: jwt.sub,
locale: jwt.locale,
timeZone: jwt.timeZone ?? undefined,
weekStart: jwt.weekStart ?? undefined,
timeFormat: jwt.timeFormat ?? undefined,
});
res.cookies.set(GUEST_USER_COOKIE, JSON.stringify(user));
}
}
}
}
export async function resetUser(req: NextRequest, res: NextResponse) {
// resets to a new guest user
const locale = await getLocaleFromHeader(req);
@ -61,7 +90,7 @@ export async function resetUser(req: NextRequest, res: NextResponse) {
}
export async function initGuest(req: NextRequest, res: NextResponse) {
const { name } = getCookieSettings();
const { name } = getNextAuthCookieSettings();
if (req.cookies.has(name)) {
// already has a session token

View file

@ -5,7 +5,7 @@ import { createTRPCReact } from "@trpc/react-query";
import { domMax, LazyMotion } from "framer-motion";
import { useState } from "react";
import { UserProvider } from "@/components/user-provider";
import { UserProvider } from "@/auth/client/user-provider";
import { I18nProvider } from "@/i18n/client";
import { trpcConfig } from "@/trpc/client/config";
import type { AppRouter } from "@/trpc/routers";

View file

@ -1,6 +1,7 @@
import { prisma } from "@rallly/database";
import { posthog } from "@rallly/posthog/server";
import { absoluteUrl } from "@rallly/utils/absolute-url";
import { generateOtp, randomid } from "@rallly/utils/nanoid";
import type {
GetServerSidePropsContext,
NextApiRequest,
@ -20,7 +21,6 @@ import { env } from "@/env";
import type { RegistrationTokenPayload } from "@/trpc/types";
import { getEmailClient } from "@/utils/emails";
import { getValueByPath } from "@/utils/get-value-by-path";
import { generateOtp, randomid } from "@/utils/nanoid";
import { decryptToken } from "@/utils/session";
import { CustomPrismaAdapter } from "./auth/custom-prisma-adapter";
@ -239,7 +239,7 @@ const getAuthOptions = (...args: GetServerSessionParams) =>
return false;
}
} else {
// merge guest user into newly logged in user
// merge guest user into newly logged in user`
const session = await getServerSession(...args);
if (session && session.user.email === null) {
await mergeGuestsIntoUser(user.id, [session.user.id]);
@ -264,6 +264,9 @@ const getAuthOptions = (...args: GetServerSessionParams) =>
return token;
},
async session({ session, token }) {
if (!session.user) {
return session;
}
// If the user is a guest, we don't need to fetch them from the database
if (token.sub?.startsWith("user-")) {
session.user.id = token.sub as string;

View file

@ -0,0 +1,13 @@
import Cookies from "js-cookie";
import { GUEST_USER_COOKIE } from "../constants";
import { safeParseGuestUser } from "../lib/parse-guest";
export function useGuestUser() {
const cookie = Cookies.get(GUEST_USER_COOKIE);
if (cookie) {
return safeParseGuestUser(cookie);
}
return null;
}

View file

@ -5,12 +5,13 @@ import { useSession } from "next-auth/react";
import React from "react";
import { Spinner } from "@/components/spinner";
import { useRequiredContext } from "@/components/use-required-context";
import { useSubscription } from "@/contexts/plan";
import { PreferencesProvider } from "@/contexts/preferences";
import { useTranslation } from "@/i18n/client";
import { trpc } from "@/trpc/client";
import { useRequiredContext } from "./use-required-context";
import { useGuestUser } from "./use-guest-user";
type UserData = {
id: string;
@ -55,16 +56,28 @@ export const IfGuest = (props: { children?: React.ReactNode }) => {
export const UserProvider = (props: { children?: React.ReactNode }) => {
const session = useSession();
const user = session.data?.user;
const guestUser = useGuestUser();
const authenticatedUser = session.data?.user;
const subscription = useSubscription();
const updatePreferences = trpc.user.updatePreferences.useMutation();
const { t, i18n } = useTranslation();
const posthog = usePostHog();
const isGuest = !user?.email;
const isGuest = !authenticatedUser?.email;
const tier = isGuest ? "guest" : subscription?.active ? "pro" : "hobby";
const user = {
id: authenticatedUser?.id ?? guestUser?.id,
name: authenticatedUser?.name,
email: authenticatedUser?.email,
timeZone: authenticatedUser?.timeZone ?? guestUser?.timeZone,
timeFormat: authenticatedUser?.timeFormat ?? guestUser?.timeFormat,
weekStart: authenticatedUser?.weekStart ?? guestUser?.weekStart,
image: authenticatedUser?.image,
locale: authenticatedUser?.locale ?? guestUser?.locale,
};
React.useEffect(() => {
if (user) {
posthog?.identify(user.id, {

View file

@ -0,0 +1 @@
export const GUEST_USER_COOKIE = "rallly-user";

36
apps/web/src/auth/edge.ts Normal file
View file

@ -0,0 +1,36 @@
import type { NextRequest, NextResponse } from "next/server";
import {
getLocaleFromHeader,
migrateGuestFromNextAuthCookie,
} from "@/app/guest";
import { GUEST_USER_COOKIE } from "./constants";
import { createGuestUser } from "./lib/create-guest-user";
import { safeParseGuestUser } from "./lib/parse-guest";
import { type User } from "./schema";
export async function initGuestUser(
req: NextRequest,
res: NextResponse,
): Promise<User> {
await migrateGuestFromNextAuthCookie(req, res);
const cookie = req.cookies.get(GUEST_USER_COOKIE);
if (cookie) {
const user = safeParseGuestUser(cookie.value);
if (user) {
return user;
}
}
const user = await createGuestUser({
locale: await getLocaleFromHeader(req),
});
res.cookies.set(GUEST_USER_COOKIE, JSON.stringify(user), {
httpOnly: false,
});
return user;
}

View file

@ -0,0 +1,12 @@
import { randomid } from "@rallly/utils/nanoid";
import type { User } from "../schema";
export async function createGuestUser(initialData: Partial<User>) {
const user: User = {
id: initialData.id ?? `user-${randomid()}`,
createdAt: new Date().toISOString(),
locale: initialData.locale ?? "en",
};
return user;
}

View file

@ -0,0 +1,13 @@
import { userSchema } from "../schema";
export function safeParseGuestUser(serialized: string) {
try {
const res = userSchema.safeParse(JSON.parse(serialized));
if (res.success) {
return res.data;
}
} catch (error) {
// TODO: Log error
}
return null;
}

31
apps/web/src/auth/next.ts Normal file
View file

@ -0,0 +1,31 @@
import type { NextApiRequest } from "next";
import { cookies } from "next/headers";
import { GUEST_USER_COOKIE } from "./constants";
import { safeParseGuestUser } from "./lib/parse-guest";
import { userSchema } from "./schema";
export async function getGuestUserFromApiRequest(req: NextApiRequest) {
const cookie = req.cookies[GUEST_USER_COOKIE];
if (cookie) {
try {
const res = userSchema.safeParse(JSON.parse(cookie));
if (res.success) {
return res.data;
}
} catch (error) {
console.error("Error parsing guest user cookie", error);
}
}
return null;
}
export async function getGuestUser() {
const cookie = cookies().get(GUEST_USER_COOKIE)?.value;
if (cookie) {
return safeParseGuestUser(cookie);
}
return null;
}

View file

@ -0,0 +1,12 @@
import { z } from "zod";
export const userSchema = z.object({
id: z.string(),
locale: z.string(),
createdAt: z.string(),
timeZone: z.string().optional().catch(undefined),
weekStart: z.number().optional().catch(undefined),
timeFormat: z.enum(["hours12", "hours24"]).optional().catch(undefined),
});
export type User = z.infer<typeof userSchema>;

View file

@ -17,11 +17,11 @@ import { useUnmount } from "react-use";
import { PollSettingsForm } from "@/components/forms/poll-settings";
import { Trans } from "@/components/trans";
import { useUser } from "@/components/user-provider";
import { useUser } from "@/auth/client/user-provider";
import { trpc } from "@/trpc/client";
import { setCookie } from "@/utils/cookies";
import type { NewEventData} from "./forms";
import type { NewEventData } from "./forms";
import { PollDetailsForm, PollOptionsForm } from "./forms";
const required = <T,>(v: T | undefined): T => {

View file

@ -41,7 +41,7 @@ import { trpc } from "@/trpc/client";
import { requiredString } from "../../utils/form-validation";
import TruncatedLinkify from "../poll/truncated-linkify";
import { useUser } from "../user-provider";
import { useUser } from "../../auth/client/user-provider";
interface CommentForm {
authorName: string;

View file

@ -29,7 +29,7 @@ import ManagePoll from "@/components/poll/manage-poll";
import NotificationsToggle from "@/components/poll/notifications-toggle";
import { LegacyPollContextProvider } from "@/components/poll/poll-context-provider";
import { Trans } from "@/components/trans";
import { useUser } from "@/components/user-provider";
import { useUser } from "@/auth/client/user-provider";
import { usePlan } from "@/contexts/plan";
import { usePoll } from "@/contexts/poll";
import { trpc } from "@/trpc/client";

View file

@ -11,7 +11,7 @@ import { Participant, ParticipantName } from "@/components/participant";
import { ParticipantDropdown } from "@/components/participant-dropdown";
import { usePoll } from "@/components/poll-context";
import { Trans } from "@/components/trans";
import { useUser } from "@/components/user-provider";
import { useUser } from "@/auth/client/user-provider";
import { usePermissions } from "@/contexts/permissions";
import type { Vote } from "@/trpc/client/types";

View file

@ -26,7 +26,7 @@ import { Trans } from "@/components/trans";
import { usePermissions } from "@/contexts/permissions";
import { useVisibleParticipants } from "../participants-provider";
import { useUser } from "../user-provider";
import { useUser } from "../../auth/client/user-provider";
import GroupedOptions from "./mobile-poll/grouped-options";
if (typeof window !== "undefined") {

View file

@ -9,7 +9,7 @@ import * as React from "react";
import { Skeleton } from "@/components/skeleton";
import { Trans } from "@/components/trans";
import { useUser } from "@/components/user-provider";
import { useUser } from "@/auth/client/user-provider";
import { trpc } from "@/trpc/client";
import { usePoll } from "../poll-context";

View file

@ -34,7 +34,11 @@ import { IfCloudHosted, IfSelfHosted } from "@/contexts/environment";
import { Plan, usePlan } from "@/contexts/plan";
import { isFeedbackEnabled } from "@/utils/constants";
import { IfAuthenticated, IfGuest, useUser } from "./user-provider";
import {
IfAuthenticated,
IfGuest,
useUser,
} from "../auth/client/user-provider";
function logout() {
// programmtically submit form with name="logout"

View file

@ -1,7 +1,7 @@
import React from "react";
import { useParticipants } from "@/components/participants-provider";
import { useUser } from "@/components/user-provider";
import { useUser } from "@/auth/client/user-provider";
import { usePoll } from "@/contexts/poll";
import { useRole } from "@/contexts/role";

View file

@ -2,7 +2,8 @@ import languages from "@rallly/languages";
import { NextResponse } from "next/server";
import withAuth from "next-auth/middleware";
import { getLocaleFromHeader, initGuest } from "@/app/guest";
import { getLocaleFromHeader } from "@/app/guest";
import { initGuestUser } from "@/auth/edge";
import { isSelfHosted } from "@/utils/constants";
const supportedLocales = Object.keys(languages);
@ -34,7 +35,7 @@ export const middleware = withAuth(
const res = NextResponse.rewrite(newUrl);
await initGuest(req, res);
await initGuestUser(req, res);
return res;
},

View file

@ -12,7 +12,7 @@ import { SessionProvider, signIn, useSession } from "next-auth/react";
import React from "react";
import Maintenance from "@/components/maintenance";
import { UserProvider } from "@/components/user-provider";
import { UserProvider } from "@/auth/client/user-provider";
import { I18nProvider } from "@/i18n/client";
import { trpc } from "@/trpc/client";
import { ConnectedDayjsProvider } from "@/utils/dayjs";

View file

@ -4,6 +4,7 @@ import { TRPCError } from "@trpc/server";
import { createNextApiHandler } from "@trpc/server/adapters/next";
import { getServerSession } from "@/auth";
import { getGuestUserFromApiRequest } from "@/auth/next";
import type { AppRouter } from "@/trpc/routers";
import { appRouter } from "@/trpc/routers";
import { getEmailClient } from "@/utils/emails";
@ -19,27 +20,28 @@ const trpcApiHandler = createNextApiHandler<AppRouter>({
router: appRouter,
createContext: async (opts) => {
const session = await getServerSession(opts.req, opts.res);
const guestUser = await getGuestUserFromApiRequest(opts.req);
if (!session) {
throw new TRPCError({
code: "UNAUTHORIZED",
message: "Unauthorized",
});
const id = session?.user?.id || guestUser?.id;
const isGuest = !session?.user?.email;
const locale = session?.user?.locale ?? guestUser?.locale;
const image = session?.user?.image ?? undefined;
if (!id) {
throw new TRPCError({ code: "UNAUTHORIZED" });
}
const res = {
return {
user: {
id: session.user.id,
isGuest: session.user.email === null,
locale: session.user.locale ?? undefined,
image: session.user.image ?? undefined,
getEmailClient: () => getEmailClient(session.user.locale ?? undefined),
id,
isGuest,
locale,
image,
getEmailClient: () => getEmailClient(locale),
},
req: opts.req,
res: opts.res,
};
return res;
},
onError({ error }) {
if (error.code === "INTERNAL_SERVER_ERROR") {

View file

@ -3,7 +3,7 @@ import { posthog } from "@rallly/posthog/server";
import { z } from "zod";
import { isEmailBlocked } from "@/auth";
import { generateOtp } from "@/utils/nanoid";
import { generateOtp } from "@rallly/utils/nanoid";
import { createToken, decryptToken } from "@/utils/session";
import { publicProcedure, rateLimitMiddleware, router } from "../trpc";

View file

@ -2,6 +2,7 @@ import type { PollStatus } from "@rallly/database";
import { prisma } from "@rallly/database";
import { posthog } from "@rallly/posthog/server";
import { absoluteUrl, shortUrl } from "@rallly/utils/absolute-url";
import { nanoid } from "@rallly/utils/nanoid";
import { TRPCError } from "@trpc/server";
import dayjs from "dayjs";
import * as ics from "ics";
@ -10,7 +11,6 @@ import { z } from "zod";
import { getEmailClient } from "@/utils/emails";
import { getTimeZoneAbbreviation } from "../../utils/date";
import { nanoid } from "../../utils/nanoid";
import {
possiblyPublicProcedure,
proProcedure,

View file

@ -1,13 +0,0 @@
import { customAlphabet } from "nanoid";
export const nanoid = customAlphabet(
"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
12,
);
export const randomid = customAlphabet(
"0123456789abcdefghijklmnopqrstuvwxyz",
12,
);
export const generateOtp = customAlphabet("0123456789", 6);