♻️ Make user session optional (#1515)

This commit is contained in:
Luke Vella 2025-01-27 13:02:34 +00:00 committed by GitHub
parent f6a0bca4f8
commit 58d5c42a6c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 343 additions and 549 deletions

View file

@ -5,4 +5,5 @@ SECRET_PASSWORD=abcdef1234567890abcdef1234567890
DATABASE_URL=postgres://postgres:postgres@localhost:5450/rallly DATABASE_URL=postgres://postgres:postgres@localhost:5450/rallly
SUPPORT_EMAIL=support@rallly.co SUPPORT_EMAIL=support@rallly.co
SMTP_HOST=localhost SMTP_HOST=localhost
SMTP_PORT=1025 SMTP_PORT=1025
QUICK_CREATE_ENABLED=true

View file

@ -10,7 +10,7 @@ declare module "next-auth" {
* Returned by `useSession`, `getSession` and received as a prop on the `SessionProvider` React Context * Returned by `useSession`, `getSession` and received as a prop on the `SessionProvider` React Context
*/ */
interface Session { interface Session {
user: { user?: {
id: string; id: string;
timeZone?: string | null; timeZone?: string | null;
timeFormat?: TimeFormat | null; timeFormat?: TimeFormat | null;

View file

@ -1,12 +1,8 @@
"use client"; "use client";
import { Alert, AlertDescription, AlertTitle } from "@rallly/ui/alert";
import { Button } from "@rallly/ui/button"; import { Button } from "@rallly/ui/button";
import { DialogTrigger } from "@rallly/ui/dialog"; import { DialogTrigger } from "@rallly/ui/dialog";
import { Input } from "@rallly/ui/input"; import { LogOutIcon, TrashIcon } from "lucide-react";
import { Label } from "@rallly/ui/label";
import { InfoIcon, LogOutIcon, TrashIcon, UserXIcon } from "lucide-react";
import Head from "next/head"; import Head from "next/head";
import Link from "next/link";
import { useTranslation } from "next-i18next"; import { useTranslation } from "next-i18next";
import { DeleteAccountDialog } from "@/app/[locale]/(admin)/settings/profile/delete-account-dialog"; import { DeleteAccountDialog } from "@/app/[locale]/(admin)/settings/profile/delete-account-dialog";
@ -31,112 +27,71 @@ export const ProfilePage = () => {
<Head> <Head>
<title>{t("profile")}</title> <title>{t("profile")}</title>
</Head> </Head>
{user.isGuest ? ( <SettingsContent>
<SettingsContent> <SettingsSection
<SettingsSection title={<Trans i18nKey="profile" defaults="Profile" />}
title={<Trans i18nKey="profile" />} description={
description={<Trans i18nKey="profileDescription" />} <Trans
> i18nKey="profileDescription"
<Label className="mb-2.5"> defaults="Set your public profile information"
<Trans i18nKey="userId" defaults="User ID" />
</Label>
<Input
className="w-full"
value={user.id.substring(0, 10)}
readOnly
disabled
/> />
<Alert className="mt-4" icon={InfoIcon}> }
<AlertTitle> >
<Trans i18nKey="aboutGuest" defaults="Guest User" /> <ProfileSettings />
</AlertTitle> </SettingsSection>
<AlertDescription> <SettingsSection
<Trans title={
i18nKey="aboutGuestDescription" <Trans i18nKey="profileEmailAddress" defaults="Email Address" />
defaults="Profile settings are not available for guest users. <0>Sign in</0> to your existing account or <1>create a new account</1> to customize your profile." }
components={[ description={
<Link className="text-link" key={0} href="/login" />, <Trans
<Link className="text-link" key={1} href="/register" />, i18nKey="profileEmailAddressDescription"
]} defaults="Your email address is used to log in to your account"
/> />
</AlertDescription> }
</Alert> >
<LogoutButton className="mt-6" variant="destructive"> <ProfileEmailAddress />
<UserXIcon className="size-4" /> </SettingsSection>
<Trans i18nKey="forgetMe" /> <hr />
</LogoutButton>
</SettingsSection>
</SettingsContent>
) : (
<SettingsContent>
<SettingsSection
title={<Trans i18nKey="profile" defaults="Profile" />}
description={
<Trans
i18nKey="profileDescription"
defaults="Set your public profile information"
/>
}
>
<ProfileSettings />
</SettingsSection>
<SettingsSection
title={
<Trans i18nKey="profileEmailAddress" defaults="Email Address" />
}
description={
<Trans
i18nKey="profileEmailAddressDescription"
defaults="Your email address is used to log in to your account"
/>
}
>
<ProfileEmailAddress />
</SettingsSection>
<hr />
<SettingsSection <SettingsSection
title={<Trans i18nKey="logout" />} title={<Trans i18nKey="logout" />}
description={ description={
<Trans <Trans
i18nKey="logoutDescription" i18nKey="logoutDescription"
defaults="Sign out of your existing session" defaults="Sign out of your existing session"
/> />
} }
> >
<LogoutButton> <LogoutButton>
<LogOutIcon className="size-4" /> <LogOutIcon className="size-4" />
<Trans i18nKey="logout" defaults="Logout" /> <Trans i18nKey="logout" defaults="Logout" />
</LogoutButton> </LogoutButton>
</SettingsSection> </SettingsSection>
{user.email ? ( {user.email ? (
<> <>
<hr /> <hr />
<SettingsSection <SettingsSection
title={<Trans i18nKey="dangerZone" defaults="Danger Zone" />} title={<Trans i18nKey="dangerZone" defaults="Danger Zone" />}
description={ description={
<Trans <Trans
i18nKey="dangerZoneAccount" i18nKey="dangerZoneAccount"
defaults="Delete your account permanently. This action cannot be undone." defaults="Delete your account permanently. This action cannot be undone."
/> />
} }
> >
<DeleteAccountDialog email={user.email}> <DeleteAccountDialog email={user.email}>
<DialogTrigger asChild> <DialogTrigger asChild>
<Button className="text-destructive"> <Button className="text-destructive">
<TrashIcon className="size-4" /> <TrashIcon className="size-4" />
<Trans <Trans i18nKey="deleteAccount" defaults="Delete Account" />
i18nKey="deleteAccount" </Button>
defaults="Delete Account" </DialogTrigger>
/> </DeleteAccountDialog>
</Button> </SettingsSection>
</DialogTrigger> </>
</DeleteAccountDialog> ) : null}
</SettingsSection> </SettingsContent>
</>
) : null}
</SettingsContent>
)}
</Settings> </Settings>
); );
}; };

View file

@ -24,7 +24,7 @@ export const LoginPage = ({ magicLink, email }: PageProps) => {
if (!data.url.includes("auth/error")) { if (!data.url.includes("auth/error")) {
// if login was successful, update the session // if login was successful, update the session
const updatedSession = await session.update(); const updatedSession = await session.update();
if (updatedSession) { if (updatedSession?.user) {
// identify the user in posthog // identify the user in posthog
posthog?.identify(updatedSession.user.id, { posthog?.identify(updatedSession.user.id, {
email: updatedSession.user.email, email: updatedSession.user.email,

View file

@ -1,10 +0,0 @@
import type { NextRequest } from "next/server";
import { NextResponse } from "next/server";
import { resetUser } from "@/app/guest";
export async function POST(req: NextRequest) {
const res = NextResponse.json({ ok: 1 });
await resetUser(req, res);
return res;
}

View file

@ -16,7 +16,7 @@ export const GET = async (req: NextRequest) => {
const session = await getServerSession(); const session = await getServerSession();
if (!session || !session.user.email) { if (!session || !session.user?.email) {
return NextResponse.redirect(new URL("/login", req.url)); return NextResponse.redirect(new URL("/login", req.url));
} }

View file

@ -20,7 +20,7 @@ export async function POST(request: NextRequest) {
Object.fromEntries(formData.entries()), Object.fromEntries(formData.entries()),
); );
if (!userSession || userSession.user.email === null) { if (!userSession?.user || userSession.user?.email === null) {
// You need to be logged in to subscribe // You need to be logged in to subscribe
return NextResponse.redirect( return NextResponse.redirect(
new URL( new URL(

View file

@ -33,7 +33,7 @@ export async function GET(request: NextRequest) {
} }
} else { } else {
const userSession = await getServerSession(); const userSession = await getServerSession();
if (!userSession || userSession.user.email === null) { if (!userSession?.user || userSession.user.email === null) {
Sentry.captureException(new Error("User not logged in")); Sentry.captureException(new Error("User not logged in"));
return NextResponse.json( return NextResponse.json(
{ error: "User not logged in" }, { error: "User not logged in" },

View file

@ -1,38 +1,37 @@
import * as Sentry from "@sentry/nextjs"; import * as Sentry from "@sentry/nextjs";
import { TRPCError } from "@trpc/server";
import { fetchRequestHandler } from "@trpc/server/adapters/fetch"; import { fetchRequestHandler } from "@trpc/server/adapters/fetch";
import { ipAddress } from "@vercel/functions"; import { ipAddress } from "@vercel/functions";
import type { NextRequest } from "next/server";
import { getLocaleFromHeader } from "@/app/guest";
import { getServerSession } from "@/auth"; import { getServerSession } from "@/auth";
import type { TRPCContext } from "@/trpc/context"; import type { TRPCContext } from "@/trpc/context";
import { appRouter } from "@/trpc/routers"; import { appRouter } from "@/trpc/routers";
import { getEmailClient } from "@/utils/emails"; import { getEmailClient } from "@/utils/emails";
const handler = (request: Request) => { const handler = (req: NextRequest) => {
return fetchRequestHandler({ return fetchRequestHandler({
endpoint: "/api/trpc", endpoint: "/api/trpc",
req: request, req,
router: appRouter, router: appRouter,
createContext: async () => { createContext: async () => {
const session = await getServerSession(); const session = await getServerSession();
const locale = await getLocaleFromHeader(req);
if (!session?.user) { const user = session?.user
throw new TRPCError({ ? {
code: "UNAUTHORIZED", id: session.user.id,
message: "Unauthorized", isGuest: !session.user.email,
}); locale: session.user.locale ?? undefined,
} image: session.user.image ?? undefined,
getEmailClient: () =>
getEmailClient(session.user?.locale ?? undefined),
}
: undefined;
return { return {
user: { user,
id: session.user.id, locale,
isGuest: session.user.email === null, ip: ipAddress(req) ?? undefined,
locale: session.user.locale ?? undefined,
image: session.user.image ?? undefined,
getEmailClient: () =>
getEmailClient(session.user?.locale ?? undefined),
},
ip: ipAddress(request) ?? undefined,
} satisfies TRPCContext; } satisfies TRPCContext;
}, },
onError({ error }) { onError({ error }) {

View file

@ -52,7 +52,7 @@ export const GET = async (request: NextRequest) => {
const session = await getServerSession(); const session = await getServerSession();
if (!session || !session.user.email) { if (!session?.user || !session.user.email) {
return NextResponse.redirect( return NextResponse.redirect(
new URL(`/login?callbackUrl=${request.url}`, request.url), new URL(`/login?callbackUrl=${request.url}`, request.url),
); );

View file

@ -1,23 +1,9 @@
import languages from "@rallly/languages"; import languages from "@rallly/languages";
import { absoluteUrl } from "@rallly/utils/absolute-url";
import { randomid } from "@rallly/utils/nanoid";
import languageParser from "accept-language-parser"; import languageParser from "accept-language-parser";
import type { NextRequest, NextResponse } from "next/server"; import type { NextRequest } from "next/server";
import type { JWT } from "next-auth/jwt";
import { decode, encode } from "next-auth/jwt";
const supportedLocales = Object.keys(languages); const supportedLocales = Object.keys(languages);
function getCookieSettings() {
const secure = absoluteUrl().startsWith("https://");
const prefix = secure ? "__Secure-" : "";
const name = `${prefix}next-auth.session-token`;
return {
secure,
name,
};
}
export async function getLocaleFromHeader(req: NextRequest) { export async function getLocaleFromHeader(req: NextRequest) {
// Check if locale is specified in header // Check if locale is specified in header
const headers = req.headers; const headers = req.headers;
@ -27,65 +13,3 @@ export async function getLocaleFromHeader(req: NextRequest) {
: null; : null;
return localeFromHeader ?? "en"; return localeFromHeader ?? "en";
} }
async function setCookie(res: NextResponse, jwt: JWT) {
const { name, secure } = getCookieSettings();
const token = await encode({
token: jwt,
secret: process.env.SECRET_PASSWORD,
});
res.cookies.set({
name,
value: token,
httpOnly: true,
secure,
sameSite: "lax",
path: "/",
});
}
export async function resetUser(req: NextRequest, res: NextResponse) {
// resets to a new guest user
const locale = await getLocaleFromHeader(req);
const jwt: JWT = {
sub: `user-${randomid()}`,
email: null,
locale,
};
await setCookie(res, jwt);
}
export async function initGuest(req: NextRequest, res: NextResponse) {
const { name } = getCookieSettings();
const token = req.cookies.get(name)?.value;
if (token) {
try {
const jwt = await decode({
token,
secret: process.env.SECRET_PASSWORD,
});
if (jwt) {
return jwt;
}
} catch (error) {
// invalid token
console.error(error);
}
}
const locale = await getLocaleFromHeader(req);
const jwt: JWT = {
sub: `user-${randomid()}`,
email: null,
locale,
};
await setCookie(res, jwt);
return jwt;
}

View file

@ -174,7 +174,7 @@ const getAuthOptions = (...args: GetServerSessionParams) =>
adapter: CustomPrismaAdapter(prisma, { adapter: CustomPrismaAdapter(prisma, {
migrateData: async (userId) => { migrateData: async (userId) => {
const session = await getServerSession(...args); const session = await getServerSession(...args);
if (session && session.user.email === null) { if (session?.user && session.user.email === null) {
await mergeGuestsIntoUser(userId, [session.user.id]); await mergeGuestsIntoUser(userId, [session.user.id]);
} }
}, },
@ -255,7 +255,7 @@ const getAuthOptions = (...args: GetServerSessionParams) =>
if (!isInitialSocialLogin) { if (!isInitialSocialLogin) {
// merge guest user into newly logged in user // merge guest user into newly logged in user
const session = await getServerSession(...args); const session = await getServerSession(...args);
if (session && session.user.email === null) { if (session?.user && session.user.email === null) {
await mergeGuestsIntoUser(user.id, [session.user.id]); await mergeGuestsIntoUser(user.id, [session.user.id]);
} }
} }
@ -273,13 +273,18 @@ const getAuthOptions = (...args: GetServerSessionParams) =>
return token; return token;
}, },
async session({ session, token }) { async session({ session, token }) {
if (!token.sub) {
return session;
}
if (token.sub?.startsWith("user-")) { if (token.sub?.startsWith("user-")) {
session.user.id = token.sub as string; session.user = {
session.user.locale = token.locale; id: token.sub as string,
session.user.timeFormat = token.timeFormat; locale: token.locale,
session.user.timeZone = token.timeZone; timeFormat: token.timeFormat,
session.user.locale = token.locale; timeZone: token.timeZone,
session.user.weekStart = token.weekStart; weekStart: token.weekStart,
};
return session; return session;
} }
@ -300,21 +305,18 @@ const getAuthOptions = (...args: GetServerSessionParams) =>
}); });
if (user) { if (user) {
session.user.id = user.id; session.user = {
session.user.name = user.name; id: user.id,
session.user.email = user.email; name: user.name,
session.user.image = user.image; email: user.email,
} else { image: user.image,
session.user.id = token.sub || `user-${randomid()}`; locale: user.locale,
timeFormat: user.timeFormat,
timeZone: user.timeZone,
weekStart: user.weekStart,
};
} }
const source = user ?? token;
session.user.locale = source.locale;
session.user.timeFormat = source.timeFormat;
session.user.timeZone = source.timeZone;
session.user.weekStart = source.weekStart;
return session; return session;
}, },
}, },

View file

@ -16,26 +16,27 @@ import type { Adapter, AdapterAccount } from "next-auth/adapters";
export function CustomPrismaAdapter( export function CustomPrismaAdapter(
client: ExtendedPrismaClient, client: ExtendedPrismaClient,
options: { migrateData: (userId: string) => Promise<void> }, options: { migrateData: (userId: string) => Promise<void> },
): Adapter { ) {
const adapter = PrismaAdapter(client as PrismaClient);
return { return {
...PrismaAdapter(client as PrismaClient), ...adapter,
linkAccount: async (data) => { linkAccount: async (account: AdapterAccount) => {
await options.migrateData(data.userId); await options.migrateData(account.userId);
return client.account.create({ return (await client.account.create({
data: { data: {
userId: data.userId, userId: account.userId,
type: data.type, type: account.type,
provider: data.provider, provider: account.provider,
providerAccountId: data.providerAccountId, providerAccountId: account.providerAccountId,
access_token: data.access_token as string, access_token: account.access_token as string,
expires_at: data.expires_at as number, expires_at: account.expires_at as number,
id_token: data.id_token as string, id_token: account.id_token as string,
token_type: data.token_type as string, token_type: account.token_type as string,
refresh_token: data.refresh_token as string, refresh_token: account.refresh_token as string,
scope: data.scope as string, scope: account.scope as string,
session_state: data.session_state as string, session_state: account.session_state as string,
}, },
}) as unknown as AdapterAccount; })) as AdapterAccount;
}, },
}; } satisfies Adapter;
} }

View file

@ -10,6 +10,7 @@ import {
} from "@rallly/ui/card"; } from "@rallly/ui/card";
import { Form } from "@rallly/ui/form"; import { Form } from "@rallly/ui/form";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
import { signIn, useSession } from "next-auth/react";
import React from "react"; import React from "react";
import { useForm } from "react-hook-form"; import { useForm } from "react-hook-form";
import useFormPersist from "react-hook-form-persist"; import useFormPersist from "react-hook-form-persist";
@ -19,7 +20,6 @@ import { PollSettingsForm } from "@/components/forms/poll-settings";
import { Trans } from "@/components/trans"; import { Trans } from "@/components/trans";
import { useUser } from "@/components/user-provider"; import { useUser } from "@/components/user-provider";
import { trpc } from "@/trpc/client"; import { trpc } from "@/trpc/client";
import { setCookie } from "@/utils/cookies";
import type { NewEventData } from "./forms"; import type { NewEventData } from "./forms";
import { PollDetailsForm, PollOptionsForm } from "./forms"; import { PollDetailsForm, PollOptionsForm } from "./forms";
@ -42,6 +42,7 @@ export interface CreatePollPageProps {
export const CreatePoll: React.FunctionComponent = () => { export const CreatePoll: React.FunctionComponent = () => {
const router = useRouter(); const router = useRouter();
const { user } = useUser(); const { user } = useUser();
const session = useSession();
const form = useForm<NewEventData>({ const form = useForm<NewEventData>({
defaultValues: { defaultValues: {
title: "", title: "",
@ -66,8 +67,12 @@ export const CreatePoll: React.FunctionComponent = () => {
const posthog = usePostHog(); const posthog = usePostHog();
const createPoll = trpc.polls.create.useMutation({ const createPoll = trpc.polls.create.useMutation({
networkMode: "always", networkMode: "always",
onSuccess: () => { onMutate: async () => {
setCookie("new-poll", "1"); if (session.status !== "authenticated") {
await signIn("guest", {
redirect: false,
});
}
}, },
}); });
@ -76,7 +81,6 @@ export const CreatePoll: React.FunctionComponent = () => {
<form <form
onSubmit={form.handleSubmit(async (formData) => { onSubmit={form.handleSubmit(async (formData) => {
const title = required(formData?.title); const title = required(formData?.title);
await createPoll.mutateAsync( await createPoll.mutateAsync(
{ {
title: title, title: title,

View file

@ -26,6 +26,7 @@ import {
MoreHorizontalIcon, MoreHorizontalIcon,
TrashIcon, TrashIcon,
} from "lucide-react"; } from "lucide-react";
import { signIn, useSession } from "next-auth/react";
import { useTranslation } from "next-i18next"; import { useTranslation } from "next-i18next";
import * as React from "react"; import * as React from "react";
import { Controller, useForm } from "react-hook-form"; import { Controller, useForm } from "react-hook-form";
@ -57,6 +58,7 @@ function NewCommentForm({
const { t } = useTranslation(); const { t } = useTranslation();
const poll = usePoll(); const poll = usePoll();
const { user } = useUser(); const { user } = useUser();
const session = useSession();
const { participants } = useParticipants(); const { participants } = useParticipants();
const authorName = React.useMemo(() => { const authorName = React.useMemo(() => {
@ -72,8 +74,6 @@ function NewCommentForm({
const posthog = usePostHog(); const posthog = usePostHog();
const session = useUser();
const { register, reset, control, handleSubmit, formState } = const { register, reset, control, handleSubmit, formState } =
useForm<CommentForm>({ useForm<CommentForm>({
defaultValues: { defaultValues: {
@ -84,6 +84,13 @@ function NewCommentForm({
const { toast } = useToast(); const { toast } = useToast();
const addComment = trpc.polls.comments.add.useMutation({ const addComment = trpc.polls.comments.add.useMutation({
onMutate: async () => {
if (session.status !== "authenticated") {
await signIn("guest", {
redirect: false,
});
}
},
onSuccess: () => { onSuccess: () => {
posthog?.capture("created comment"); posthog?.capture("created comment");
}, },
@ -119,7 +126,7 @@ function NewCommentForm({
> >
<Controller <Controller
name="authorName" name="authorName"
key={session.user?.id} key={user?.id}
control={control} control={control}
rules={{ validate: requiredString }} rules={{ validate: requiredString }}
render={({ field }) => ( render={({ field }) => (

View file

@ -14,7 +14,6 @@ import { useParams, usePathname } from "next/navigation";
import React from "react"; import React from "react";
import { GroupPollIcon } from "@/app/[locale]/(admin)/app-card"; import { GroupPollIcon } from "@/app/[locale]/(admin)/app-card";
import Loader from "@/app/[locale]/poll/[urlId]/skeleton";
import { LogoutButton } from "@/app/components/logout-button"; import { LogoutButton } from "@/app/components/logout-button";
import { InviteDialog } from "@/components/invite-dialog"; import { InviteDialog } from "@/components/invite-dialog";
import { LoginLink } from "@/components/login-link"; import { LoginLink } from "@/components/login-link";
@ -30,9 +29,7 @@ import NotificationsToggle from "@/components/poll/notifications-toggle";
import { LegacyPollContextProvider } from "@/components/poll/poll-context-provider"; import { LegacyPollContextProvider } from "@/components/poll/poll-context-provider";
import { Trans } from "@/components/trans"; import { Trans } from "@/components/trans";
import { useUser } from "@/components/user-provider"; import { useUser } from "@/components/user-provider";
import { usePlan } from "@/contexts/plan";
import { usePoll } from "@/contexts/poll"; import { usePoll } from "@/contexts/poll";
import { trpc } from "@/trpc/client";
const AdminControls = () => { const AdminControls = () => {
return ( return (
@ -86,8 +83,9 @@ const Layout = ({ children }: React.PropsWithChildren) => {
}; };
const PermissionGuard = ({ children }: React.PropsWithChildren) => { const PermissionGuard = ({ children }: React.PropsWithChildren) => {
const poll = usePoll();
const { user } = useUser(); const { user } = useUser();
const poll = usePoll();
if (!poll.adminUrlId) { if (!poll.adminUrlId) {
return ( return (
<PageDialog icon={ShieldCloseIcon}> <PageDialog icon={ShieldCloseIcon}>
@ -139,30 +137,6 @@ const PermissionGuard = ({ children }: React.PropsWithChildren) => {
return <>{children}</>; return <>{children}</>;
}; };
const Prefetch = ({ children }: React.PropsWithChildren) => {
const params = useParams();
const urlId = params?.urlId as string;
const poll = trpc.polls.get.useQuery({ urlId });
const participants = trpc.polls.participants.list.useQuery({ pollId: urlId });
const watchers = trpc.polls.getWatchers.useQuery({ pollId: urlId });
const comments = trpc.polls.comments.list.useQuery({ pollId: urlId });
usePlan(); // prefetch plan
if (
!poll.isFetched ||
!watchers.isFetched ||
!participants.isFetched ||
!comments.isFetched
) {
return <Loader />;
}
return <>{children}</>;
};
export const PollLayout = ({ children }: React.PropsWithChildren) => { export const PollLayout = ({ children }: React.PropsWithChildren) => {
const params = useParams(); const params = useParams();
@ -174,12 +148,10 @@ export const PollLayout = ({ children }: React.PropsWithChildren) => {
} }
return ( return (
<Prefetch> <LegacyPollContextProvider>
<LegacyPollContextProvider> <PermissionGuard>
<PermissionGuard> <Layout>{children}</Layout>
<Layout>{children}</Layout> </PermissionGuard>
</PermissionGuard> </LegacyPollContextProvider>
</LegacyPollContextProvider>
</Prefetch>
); );
}; };

View file

@ -92,7 +92,6 @@ export const NewParticipantForm = (props: NewParticipantModalProps) => {
const { user } = useUser(); const { user } = useUser();
const isLoggedIn = !user.isGuest; const isLoggedIn = !user.isGuest;
const { register, setError, formState, handleSubmit } = const { register, setError, formState, handleSubmit } =
useForm<NewParticipantFormData>({ useForm<NewParticipantFormData>({
resolver: zodResolver(schema), resolver: zodResolver(schema),

View file

@ -1,4 +1,5 @@
import { usePostHog } from "@rallly/posthog/client"; import { usePostHog } from "@rallly/posthog/client";
import { signIn, useSession } from "next-auth/react";
import { usePoll } from "@/components/poll-context"; import { usePoll } from "@/components/poll-context";
import { trpc } from "@/trpc/client"; import { trpc } from "@/trpc/client";
@ -18,9 +19,16 @@ export const normalizeVotes = (
export const useAddParticipantMutation = () => { export const useAddParticipantMutation = () => {
const posthog = usePostHog(); const posthog = usePostHog();
const queryClient = trpc.useUtils(); const queryClient = trpc.useUtils();
const session = useSession();
return trpc.polls.participants.add.useMutation({ return trpc.polls.participants.add.useMutation({
onSuccess: (newParticipant, input) => { onMutate: async () => {
if (session.status !== "authenticated") {
await signIn("guest", {
redirect: false,
});
}
},
onSuccess: async (newParticipant, input) => {
const { pollId, name, email } = newParticipant; const { pollId, name, email } = newParticipant;
queryClient.polls.participants.list.setData( queryClient.polls.participants.list.setData(
{ pollId }, { pollId },
@ -31,6 +39,7 @@ export const useAddParticipantMutation = () => {
]; ];
}, },
); );
posthog?.capture("add participant", { posthog?.capture("add participant", {
pollId, pollId,
name, name,

View file

@ -44,7 +44,7 @@ const NotificationsToggle: React.FunctionComponent = () => {
queryClient.polls.getWatchers.setData( queryClient.polls.getWatchers.setData(
{ pollId: poll.id }, { pollId: poll.id },
(oldWatchers) => { (oldWatchers) => {
if (!oldWatchers) { if (!oldWatchers || !user.id) {
return; return;
} }
return [...oldWatchers, { userId: user.id }]; return [...oldWatchers, { userId: user.id }];
@ -124,11 +124,11 @@ const NotificationsToggle: React.FunctionComponent = () => {
values={{ values={{
value: isWatching value: isWatching
? t("notificationsOn", { ? t("notificationsOn", {
defaultValue: "On", defaultValue: "On",
}) })
: t("notificationsOff", { : t("notificationsOff", {
defaultValue: "Off", defaultValue: "Off",
}), }),
}} }}
/> />
)} )}

View file

@ -62,9 +62,11 @@ export const UserDropdown = ({ className }: { className?: string }) => {
<DropdownMenuLabel className="flex items-center gap-2"> <DropdownMenuLabel className="flex items-center gap-2">
<div className="grow"> <div className="grow">
<div>{user.isGuest ? <Trans i18nKey="guest" /> : user.name}</div> <div>{user.isGuest ? <Trans i18nKey="guest" /> : user.name}</div>
<div className="text-muted-foreground text-xs font-normal"> {user.email ? (
{!user.isGuest ? user.email : user.id.substring(0, 10)} <div className="text-muted-foreground text-xs font-normal">
</div> {user.email}
</div>
) : null}
</div> </div>
<div className="ml-4"> <div className="ml-4">
<Plan /> <Plan />

View file

@ -4,7 +4,6 @@ import type { Session } from "next-auth";
import { signOut, useSession } from "next-auth/react"; import { signOut, useSession } from "next-auth/react";
import React from "react"; import React from "react";
import { Spinner } from "@/components/spinner";
import { useSubscription } from "@/contexts/plan"; import { useSubscription } from "@/contexts/plan";
import { PreferencesProvider } from "@/contexts/preferences"; import { PreferencesProvider } from "@/contexts/preferences";
import { useTranslation } from "@/i18n/client"; import { useTranslation } from "@/i18n/client";
@ -14,7 +13,7 @@ import { isOwner } from "@/utils/permissions";
import { useRequiredContext } from "./use-required-context"; import { useRequiredContext } from "./use-required-context";
type UserData = { type UserData = {
id: string; id?: string;
name: string; name: string;
email?: string | null; email?: string | null;
isGuest: boolean; isGuest: boolean;
@ -71,7 +70,7 @@ export const UserProvider = (props: { children?: React.ReactNode }) => {
const tier = isGuest ? "guest" : subscription?.active ? "pro" : "hobby"; const tier = isGuest ? "guest" : subscription?.active ? "pro" : "hobby";
React.useEffect(() => { React.useEffect(() => {
if (user?.email) { if (user) {
posthog?.identify(user.id, { posthog?.identify(user.id, {
email: user.email, email: user.email,
name: user.name, name: user.name,
@ -84,26 +83,18 @@ export const UserProvider = (props: { children?: React.ReactNode }) => {
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, [user?.id]); }, [user?.id]);
if (!user) {
return (
<div className="flex h-screen items-center justify-center">
<Spinner />
</div>
);
}
return ( return (
<UserContext.Provider <UserContext.Provider
value={{ value={{
user: { user: {
id: user.id as string, id: user?.id,
name: user.name ?? t("guest"), name: user?.name ?? t("guest"),
email: user.email || null, email: user?.email || null,
isGuest, isGuest,
tier, tier,
timeZone: user.timeZone ?? null, timeZone: user?.timeZone ?? null,
image: user.image ?? null, image: user?.image ?? null,
locale: user.locale ?? i18n.language, locale: user?.locale ?? i18n.language,
}, },
refresh: session.update, refresh: session.update,
logout: async () => { logout: async () => {
@ -112,16 +103,16 @@ export const UserProvider = (props: { children?: React.ReactNode }) => {
posthog?.reset(); posthog?.reset();
}, },
ownsObject: (resource) => { ownsObject: (resource) => {
return isOwner(resource, { id: user.id, isGuest }); return user ? isOwner(resource, { id: user.id, isGuest }) : false;
}, },
}} }}
> >
<PreferencesProvider <PreferencesProvider
initialValue={{ initialValue={{
locale: user.locale ?? undefined, locale: user?.locale ?? undefined,
timeZone: user.timeZone ?? undefined, timeZone: user?.timeZone ?? undefined,
timeFormat: user.timeFormat ?? undefined, timeFormat: user?.timeFormat ?? undefined,
weekStart: user.weekStart ?? undefined, weekStart: user?.weekStart ?? undefined,
}} }}
onUpdate={async (newPreferences) => { onUpdate={async (newPreferences) => {
if (!isGuest) { if (!isGuest) {

View file

@ -56,9 +56,7 @@ export async function QuickCreateWidget() {
<GroupPollIcon size="lg" /> <GroupPollIcon size="lg" />
</div> </div>
<div className="min-w-0 flex-1"> <div className="min-w-0 flex-1">
<div className="truncate font-medium"> <div className="truncate font-medium">{poll.title}</div>
<Link href={`/poll/${poll.id}`}>{poll.title}</Link>
</div>
<div className="text-muted-foreground whitespace-nowrap text-sm"> <div className="text-muted-foreground whitespace-nowrap text-sm">
<RelativeDate date={poll.createdAt} /> <RelativeDate date={poll.createdAt} />
</div> </div>

View file

@ -3,7 +3,7 @@ import { withPostHog } from "@rallly/posthog/next/middleware";
import { NextResponse } from "next/server"; import { NextResponse } from "next/server";
import withAuth from "next-auth/middleware"; import withAuth from "next-auth/middleware";
import { getLocaleFromHeader, initGuest } from "@/app/guest"; import { getLocaleFromHeader } from "@/app/guest";
import { isSelfHosted } from "@/utils/constants"; import { isSelfHosted } from "@/utils/constants";
const supportedLocales = Object.keys(languages); const supportedLocales = Object.keys(languages);
@ -55,10 +55,9 @@ export const middleware = withAuth(
const res = NextResponse.rewrite(newUrl); const res = NextResponse.rewrite(newUrl);
res.headers.set("x-pathname", newUrl.pathname); res.headers.set("x-pathname", newUrl.pathname);
const jwt = await initGuest(req, res);
if (jwt?.sub) { if (req.nextauth.token) {
await withPostHog(res, { distinctID: jwt.sub }); await withPostHog(res, { distinctID: req.nextauth.token.sub });
} }
return res; return res;

View file

@ -1,12 +1,15 @@
import type { EmailClient } from "@rallly/emails"; import type { EmailClient } from "@rallly/emails";
type User = {
id: string;
isGuest: boolean;
locale?: string;
getEmailClient: (locale?: string) => EmailClient;
image?: string;
};
export type TRPCContext = { export type TRPCContext = {
user: { user?: User;
id: string; locale?: string;
isGuest: boolean;
locale?: string;
getEmailClient: (locale?: string) => EmailClient;
image?: string;
};
ip?: string; ip?: string;
}; };

View file

@ -4,6 +4,7 @@ import { generateOtp } from "@rallly/utils/nanoid";
import { z } from "zod"; import { z } from "zod";
import { isEmailBlocked } from "@/auth"; import { isEmailBlocked } from "@/auth";
import { getEmailClient } from "@/utils/emails";
import { createToken, decryptToken } from "@/utils/session"; import { createToken, decryptToken } from "@/utils/session";
import { publicProcedure, rateLimitMiddleware, router } from "../trpc"; import { publicProcedure, rateLimitMiddleware, router } from "../trpc";
@ -66,7 +67,7 @@ export const auth = router({
code, code,
}); });
await ctx.user.getEmailClient().sendTemplate("RegisterEmail", { await getEmailClient(ctx.locale).sendTemplate("RegisterEmail", {
to: input.email, to: input.email,
props: { props: {
code, code,

View file

@ -17,6 +17,7 @@ import {
proProcedure, proProcedure,
publicProcedure, publicProcedure,
rateLimitMiddleware, rateLimitMiddleware,
requireUserMiddleware,
router, router,
} from "../trpc"; } from "../trpc";
import { comments } from "./polls/comments"; import { comments } from "./polls/comments";
@ -130,6 +131,7 @@ export const polls = router({
// START LEGACY ROUTES // START LEGACY ROUTES
create: possiblyPublicProcedure create: possiblyPublicProcedure
.use(rateLimitMiddleware) .use(rateLimitMiddleware)
.use(requireUserMiddleware)
.input( .input(
z.object({ z.object({
title: z.string().trim().min(1), title: z.string().trim().min(1),
@ -332,7 +334,7 @@ export const polls = router({
}); });
}), }),
// END LEGACY ROUTES // END LEGACY ROUTES
getWatchers: possiblyPublicProcedure getWatchers: publicProcedure
.input( .input(
z.object({ z.object({
pollId: z.string(), pollId: z.string(),
@ -348,16 +350,9 @@ export const polls = router({
}, },
}); });
}), }),
watch: possiblyPublicProcedure watch: privateProcedure
.input(z.object({ pollId: z.string() })) .input(z.object({ pollId: z.string() }))
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
if (ctx.user.isGuest) {
throw new TRPCError({
code: "BAD_REQUEST",
message: "Guests can't watch polls",
});
}
await prisma.watcher.create({ await prisma.watcher.create({
data: { data: {
pollId: input.pollId, pollId: input.pollId,
@ -365,16 +360,9 @@ export const polls = router({
}, },
}); });
}), }),
unwatch: possiblyPublicProcedure unwatch: privateProcedure
.input(z.object({ pollId: z.string() })) .input(z.object({ pollId: z.string() }))
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
if (ctx.user.isGuest) {
throw new TRPCError({
code: "BAD_REQUEST",
message: "Guests can't unwatch polls",
});
}
const watcher = await prisma.watcher.findFirst({ const watcher = await prisma.watcher.findFirst({
where: { where: {
pollId: input.pollId, pollId: input.pollId,
@ -393,31 +381,6 @@ export const polls = router({
}); });
} }
}), }),
getByAdminUrlId: possiblyPublicProcedure
.input(
z.object({
urlId: z.string(),
}),
)
.query(async ({ input }) => {
const res = await prisma.poll.findUnique({
select: {
id: true,
},
where: {
adminUrlId: input.urlId,
},
});
if (!res) {
throw new TRPCError({
code: "NOT_FOUND",
message: "Poll not found",
});
}
return res;
}),
get: publicProcedure get: publicProcedure
.input( .input(
z.object({ z.object({
@ -479,9 +442,11 @@ export const polls = router({
} }
const inviteLink = shortUrl(`/invite/${res.id}`); const inviteLink = shortUrl(`/invite/${res.id}`);
const isOwner = ctx.user.isGuest const userId = ctx.user?.id;
? ctx.user.id === res.guestId
: ctx.user.id === res.userId; const isOwner = ctx.user?.isGuest
? userId === res.guestId
: userId === res.userId;
if (isOwner || res.adminUrlId === input.adminToken) { if (isOwner || res.adminUrlId === input.adminToken) {
return { ...res, inviteLink }; return { ...res, inviteLink };
@ -489,93 +454,6 @@ export const polls = router({
return { ...res, adminUrlId: "", inviteLink }; return { ...res, adminUrlId: "", inviteLink };
} }
}), }),
transfer: possiblyPublicProcedure
.input(
z.object({
pollId: z.string(),
}),
)
.mutation(async ({ input, ctx }) => {
await prisma.poll.update({
where: {
id: input.pollId,
},
data: {
userId: ctx.user.id,
},
});
}),
getParticipating: possiblyPublicProcedure
.input(
z.object({
pagination: z.object({
pageIndex: z.number(),
pageSize: z.number(),
}),
}),
)
.query(async ({ ctx, input }) => {
const [total, rows] = await Promise.all([
prisma.poll.count({
where: {
participants: {
some: {
userId: ctx.user.id,
},
},
},
}),
prisma.poll.findMany({
where: {
deletedAt: null,
participants: {
some: {
userId: ctx.user.id,
},
},
},
select: {
id: true,
title: true,
location: true,
createdAt: true,
timeZone: true,
adminUrlId: true,
participantUrlId: true,
status: true,
event: {
select: {
start: true,
duration: true,
},
},
closed: true,
participants: {
select: {
id: true,
name: true,
},
orderBy: [
{
createdAt: "desc",
},
{ name: "desc" },
],
},
},
orderBy: [
{
createdAt: "desc",
},
{ title: "asc" },
],
skip: input.pagination.pageIndex * input.pagination.pageSize,
take: input.pagination.pageSize,
}),
]);
return { total, rows };
}),
book: proProcedure book: proProcedure
.input( .input(
z.object({ z.object({

View file

@ -5,7 +5,12 @@ import { z } from "zod";
import { getEmailClient } from "@/utils/emails"; import { getEmailClient } from "@/utils/emails";
import { createToken } from "@/utils/session"; import { createToken } from "@/utils/session";
import { publicProcedure, rateLimitMiddleware, router } from "../../trpc"; import {
publicProcedure,
rateLimitMiddleware,
requireUserMiddleware,
router,
} from "../../trpc";
import type { DisableNotificationsPayload } from "../../types"; import type { DisableNotificationsPayload } from "../../types";
export const comments = router({ export const comments = router({
@ -13,12 +18,52 @@ export const comments = router({
.input( .input(
z.object({ z.object({
pollId: z.string(), pollId: z.string(),
hideParticipants: z.boolean().optional(), hideParticipants: z.boolean().optional(), // @deprecated
}), }),
) )
.query(async ({ input: { pollId, hideParticipants }, ctx }) => { .query(async ({ input: { pollId }, ctx }) => {
const poll = await prisma.poll.findUnique({
where: {
id: pollId,
},
select: {
userId: true,
guestId: true,
hideParticipants: true,
},
});
const isOwner = ctx.user?.isGuest
? poll?.guestId === ctx.user.id
: poll?.userId === ctx.user?.id;
const hideParticipants = poll?.hideParticipants && !isOwner;
if (hideParticipants && !isOwner) {
// if hideParticipants is enabled and the user is not the owner
if (!ctx.user) {
// cannot see any comments if there is no user
return [];
} else {
// only show comments created by the current users
return await prisma.comment.findMany({
where: {
pollId,
...(ctx.user.isGuest
? { guestId: ctx.user.id }
: { userId: ctx.user.id }),
},
orderBy: [
{
createdAt: "asc",
},
],
});
}
}
// return all comments
return await prisma.comment.findMany({ return await prisma.comment.findMany({
where: { pollId, userId: hideParticipants ? ctx.user.id : undefined }, where: { pollId },
orderBy: [ orderBy: [
{ {
createdAt: "asc", createdAt: "asc",
@ -28,6 +73,7 @@ export const comments = router({
}), }),
add: publicProcedure add: publicProcedure
.use(rateLimitMiddleware) .use(rateLimitMiddleware)
.use(requireUserMiddleware)
.input( .input(
z.object({ z.object({
pollId: z.string(), pollId: z.string(),

View file

@ -5,7 +5,12 @@ import { z } from "zod";
import { createToken } from "@/utils/session"; import { createToken } from "@/utils/session";
import { publicProcedure, rateLimitMiddleware, router } from "../../trpc"; import {
publicProcedure,
rateLimitMiddleware,
requireUserMiddleware,
router,
} from "../../trpc";
import type { DisableNotificationsPayload } from "../../types"; import type { DisableNotificationsPayload } from "../../types";
const MAX_PARTICIPANTS = 1000; const MAX_PARTICIPANTS = 1000;
@ -59,6 +64,7 @@ export const participants = router({
}), }),
add: publicProcedure add: publicProcedure
.use(rateLimitMiddleware) .use(rateLimitMiddleware)
.use(requireUserMiddleware)
.input( .input(
z.object({ z.object({
pollId: z.string(), pollId: z.string(),

View file

@ -12,7 +12,6 @@ import { createToken } from "@/utils/session";
import { getSubscriptionStatus } from "@/utils/subscription"; import { getSubscriptionStatus } from "@/utils/subscription";
import { import {
possiblyPublicProcedure,
privateProcedure, privateProcedure,
publicProcedure, publicProcedure,
rateLimitMiddleware, rateLimitMiddleware,
@ -68,9 +67,9 @@ export const user = router({
}, },
}); });
}), }),
subscription: possiblyPublicProcedure.query( subscription: publicProcedure.query(
async ({ ctx }): Promise<{ legacy?: boolean; active: boolean }> => { async ({ ctx }): Promise<{ legacy?: boolean; active: boolean }> => {
if (ctx.user.isGuest) { if (!ctx.user || ctx.user.isGuest) {
// guest user can't have an active subscription // guest user can't have an active subscription
return { return {
active: false, active: false,

View file

@ -1,5 +1,4 @@
import { createServerSideHelpers } from "@trpc/react-query/server"; import { createServerSideHelpers } from "@trpc/react-query/server";
import { TRPCError } from "@trpc/server";
import { redirect } from "next/navigation"; import { redirect } from "next/navigation";
import superjson from "superjson"; import superjson from "superjson";
@ -11,22 +10,17 @@ import { appRouter } from "../routers";
async function createContext(): Promise<TRPCContext> { async function createContext(): Promise<TRPCContext> {
const session = await getServerSession(); const session = await getServerSession();
if (!session) {
throw new TRPCError({
code: "UNAUTHORIZED",
message: "Unauthorized",
});
}
return { return {
user: { user: session?.user
id: session.user.id, ? {
isGuest: session.user.email === null, id: session.user.id,
locale: session.user.locale ?? undefined, isGuest: !session.user.email,
image: session.user.image ?? undefined, locale: session.user.locale ?? undefined,
getEmailClient: () => getEmailClient(session.user.locale ?? undefined), image: session.user.image ?? undefined,
}, getEmailClient: () =>
getEmailClient(session.user?.locale ?? undefined),
}
: undefined,
}; };
} }

View file

@ -3,6 +3,7 @@ import { Ratelimit } from "@upstash/ratelimit";
import { kv } from "@vercel/kv"; import { kv } from "@vercel/kv";
import superjson from "superjson"; import superjson from "superjson";
import { isQuickCreateEnabled } from "@/features/quick-create";
import { isSelfHosted } from "@/utils/constants"; import { isSelfHosted } from "@/utils/constants";
import { getSubscriptionStatus } from "@/utils/subscription"; import { getSubscriptionStatus } from "@/utils/subscription";
@ -23,48 +24,9 @@ export const middleware = t.middleware;
export const possiblyPublicProcedure = t.procedure.use( export const possiblyPublicProcedure = t.procedure.use(
middleware(async ({ ctx, next }) => { middleware(async ({ ctx, next }) => {
// On self-hosted instances, these procedures require login // These procedurs are public if Quick Create is enabled
if (isSelfHosted && ctx.user.isGuest) { const isGuest = !ctx.user || ctx.user.isGuest;
throw new TRPCError({ if (isGuest && !isQuickCreateEnabled) {
code: "UNAUTHORIZED",
message: "Login is required",
});
}
return next();
}),
);
export const proProcedure = t.procedure.use(
middleware(async ({ ctx, next }) => {
if (ctx.user.isGuest) {
throw new TRPCError({
code: "UNAUTHORIZED",
message: "Login is required",
});
}
if (isSelfHosted) {
// Self-hosted instances don't have paid subscriptions
return next();
}
const { active: isPro } = await getSubscriptionStatus(ctx.user.id);
if (!isPro) {
throw new TRPCError({
code: "UNAUTHORIZED",
message:
"You must have an active paid subscription to perform this action",
});
}
return next();
}),
);
export const privateProcedure = t.procedure.use(
middleware(async ({ ctx, next }) => {
if (ctx.user.isGuest !== false) {
throw new TRPCError({ throw new TRPCError({
code: "UNAUTHORIZED", code: "UNAUTHORIZED",
message: "Login is required", message: "Login is required",
@ -75,6 +37,58 @@ export const privateProcedure = t.procedure.use(
}), }),
); );
// This procedure guarantees that a user will exist in the context by
// creating a guest user if needed
export const requireUserMiddleware = middleware(async ({ ctx, next }) => {
if (!ctx.user) {
throw new TRPCError({
code: "UNAUTHORIZED",
message: "This method requires a user",
});
}
return next({
ctx: {
user: ctx.user,
},
});
});
export const privateProcedure = t.procedure.use(async ({ ctx, next }) => {
const { user } = ctx;
if (!user || user.isGuest !== false) {
throw new TRPCError({
code: "UNAUTHORIZED",
message: "Login is required",
});
}
return next({
ctx: {
user,
},
});
});
export const proProcedure = privateProcedure.use(async ({ ctx, next }) => {
if (isSelfHosted) {
// Self-hosted instances don't have paid subscriptions
return next();
}
const { active: isPro } = await getSubscriptionStatus(ctx.user.id);
if (!isPro) {
throw new TRPCError({
code: "UNAUTHORIZED",
message:
"You must have an active paid subscription to perform this action",
});
}
return next();
});
export const rateLimitMiddleware = middleware(async ({ ctx, next }) => { export const rateLimitMiddleware = middleware(async ({ ctx, next }) => {
if (!process.env.KV_REST_API_URL) { if (!process.env.KV_REST_API_URL) {
return next(); return next();