♻️ Make user session optional (#1515)

This commit is contained in:
Luke Vella 2025-01-27 13:02:34 +00:00 committed by GitHub
parent f6a0bca4f8
commit 58d5c42a6c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 343 additions and 549 deletions

View file

@ -6,3 +6,4 @@ DATABASE_URL=postgres://postgres:postgres@localhost:5450/rallly
SUPPORT_EMAIL=support@rallly.co
SMTP_HOST=localhost
SMTP_PORT=1025
QUICK_CREATE_ENABLED=true

View file

@ -10,7 +10,7 @@ declare module "next-auth" {
* Returned by `useSession`, `getSession` and received as a prop on the `SessionProvider` React Context
*/
interface Session {
user: {
user?: {
id: string;
timeZone?: string | null;
timeFormat?: TimeFormat | null;

View file

@ -1,12 +1,8 @@
"use client";
import { Alert, AlertDescription, AlertTitle } from "@rallly/ui/alert";
import { Button } from "@rallly/ui/button";
import { DialogTrigger } from "@rallly/ui/dialog";
import { Input } from "@rallly/ui/input";
import { Label } from "@rallly/ui/label";
import { InfoIcon, LogOutIcon, TrashIcon, UserXIcon } from "lucide-react";
import { LogOutIcon, TrashIcon } from "lucide-react";
import Head from "next/head";
import Link from "next/link";
import { useTranslation } from "next-i18next";
import { DeleteAccountDialog } from "@/app/[locale]/(admin)/settings/profile/delete-account-dialog";
@ -31,43 +27,6 @@ export const ProfilePage = () => {
<Head>
<title>{t("profile")}</title>
</Head>
{user.isGuest ? (
<SettingsContent>
<SettingsSection
title={<Trans i18nKey="profile" />}
description={<Trans i18nKey="profileDescription" />}
>
<Label className="mb-2.5">
<Trans i18nKey="userId" defaults="User ID" />
</Label>
<Input
className="w-full"
value={user.id.substring(0, 10)}
readOnly
disabled
/>
<Alert className="mt-4" icon={InfoIcon}>
<AlertTitle>
<Trans i18nKey="aboutGuest" defaults="Guest User" />
</AlertTitle>
<AlertDescription>
<Trans
i18nKey="aboutGuestDescription"
defaults="Profile settings are not available for guest users. <0>Sign in</0> to your existing account or <1>create a new account</1> to customize your profile."
components={[
<Link className="text-link" key={0} href="/login" />,
<Link className="text-link" key={1} href="/register" />,
]}
/>
</AlertDescription>
</Alert>
<LogoutButton className="mt-6" variant="destructive">
<UserXIcon className="size-4" />
<Trans i18nKey="forgetMe" />
</LogoutButton>
</SettingsSection>
</SettingsContent>
) : (
<SettingsContent>
<SettingsSection
title={<Trans i18nKey="profile" defaults="Profile" />}
@ -125,10 +84,7 @@ export const ProfilePage = () => {
<DialogTrigger asChild>
<Button className="text-destructive">
<TrashIcon className="size-4" />
<Trans
i18nKey="deleteAccount"
defaults="Delete Account"
/>
<Trans i18nKey="deleteAccount" defaults="Delete Account" />
</Button>
</DialogTrigger>
</DeleteAccountDialog>
@ -136,7 +92,6 @@ export const ProfilePage = () => {
</>
) : null}
</SettingsContent>
)}
</Settings>
);
};

View file

@ -24,7 +24,7 @@ export const LoginPage = ({ magicLink, email }: PageProps) => {
if (!data.url.includes("auth/error")) {
// if login was successful, update the session
const updatedSession = await session.update();
if (updatedSession) {
if (updatedSession?.user) {
// identify the user in posthog
posthog?.identify(updatedSession.user.id, {
email: updatedSession.user.email,

View file

@ -1,10 +0,0 @@
import type { NextRequest } from "next/server";
import { NextResponse } from "next/server";
import { resetUser } from "@/app/guest";
export async function POST(req: NextRequest) {
const res = NextResponse.json({ ok: 1 });
await resetUser(req, res);
return res;
}

View file

@ -16,7 +16,7 @@ export const GET = async (req: NextRequest) => {
const session = await getServerSession();
if (!session || !session.user.email) {
if (!session || !session.user?.email) {
return NextResponse.redirect(new URL("/login", req.url));
}

View file

@ -20,7 +20,7 @@ export async function POST(request: NextRequest) {
Object.fromEntries(formData.entries()),
);
if (!userSession || userSession.user.email === null) {
if (!userSession?.user || userSession.user?.email === null) {
// You need to be logged in to subscribe
return NextResponse.redirect(
new URL(

View file

@ -33,7 +33,7 @@ export async function GET(request: NextRequest) {
}
} else {
const userSession = await getServerSession();
if (!userSession || userSession.user.email === null) {
if (!userSession?.user || userSession.user.email === null) {
Sentry.captureException(new Error("User not logged in"));
return NextResponse.json(
{ error: "User not logged in" },

View file

@ -1,38 +1,37 @@
import * as Sentry from "@sentry/nextjs";
import { TRPCError } from "@trpc/server";
import { fetchRequestHandler } from "@trpc/server/adapters/fetch";
import { ipAddress } from "@vercel/functions";
import type { NextRequest } from "next/server";
import { getLocaleFromHeader } from "@/app/guest";
import { getServerSession } from "@/auth";
import type { TRPCContext } from "@/trpc/context";
import { appRouter } from "@/trpc/routers";
import { getEmailClient } from "@/utils/emails";
const handler = (request: Request) => {
const handler = (req: NextRequest) => {
return fetchRequestHandler({
endpoint: "/api/trpc",
req: request,
req,
router: appRouter,
createContext: async () => {
const session = await getServerSession();
if (!session?.user) {
throw new TRPCError({
code: "UNAUTHORIZED",
message: "Unauthorized",
});
}
return {
user: {
const locale = await getLocaleFromHeader(req);
const user = session?.user
? {
id: session.user.id,
isGuest: session.user.email === null,
isGuest: !session.user.email,
locale: session.user.locale ?? undefined,
image: session.user.image ?? undefined,
getEmailClient: () =>
getEmailClient(session.user?.locale ?? undefined),
},
ip: ipAddress(request) ?? undefined,
}
: undefined;
return {
user,
locale,
ip: ipAddress(req) ?? undefined,
} satisfies TRPCContext;
},
onError({ error }) {

View file

@ -52,7 +52,7 @@ export const GET = async (request: NextRequest) => {
const session = await getServerSession();
if (!session || !session.user.email) {
if (!session?.user || !session.user.email) {
return NextResponse.redirect(
new URL(`/login?callbackUrl=${request.url}`, request.url),
);

View file

@ -1,23 +1,9 @@
import languages from "@rallly/languages";
import { absoluteUrl } from "@rallly/utils/absolute-url";
import { randomid } from "@rallly/utils/nanoid";
import languageParser from "accept-language-parser";
import type { NextRequest, NextResponse } from "next/server";
import type { JWT } from "next-auth/jwt";
import { decode, encode } from "next-auth/jwt";
import type { NextRequest } from "next/server";
const supportedLocales = Object.keys(languages);
function getCookieSettings() {
const secure = absoluteUrl().startsWith("https://");
const prefix = secure ? "__Secure-" : "";
const name = `${prefix}next-auth.session-token`;
return {
secure,
name,
};
}
export async function getLocaleFromHeader(req: NextRequest) {
// Check if locale is specified in header
const headers = req.headers;
@ -27,65 +13,3 @@ export async function getLocaleFromHeader(req: NextRequest) {
: null;
return localeFromHeader ?? "en";
}
async function setCookie(res: NextResponse, jwt: JWT) {
const { name, secure } = getCookieSettings();
const token = await encode({
token: jwt,
secret: process.env.SECRET_PASSWORD,
});
res.cookies.set({
name,
value: token,
httpOnly: true,
secure,
sameSite: "lax",
path: "/",
});
}
export async function resetUser(req: NextRequest, res: NextResponse) {
// resets to a new guest user
const locale = await getLocaleFromHeader(req);
const jwt: JWT = {
sub: `user-${randomid()}`,
email: null,
locale,
};
await setCookie(res, jwt);
}
export async function initGuest(req: NextRequest, res: NextResponse) {
const { name } = getCookieSettings();
const token = req.cookies.get(name)?.value;
if (token) {
try {
const jwt = await decode({
token,
secret: process.env.SECRET_PASSWORD,
});
if (jwt) {
return jwt;
}
} catch (error) {
// invalid token
console.error(error);
}
}
const locale = await getLocaleFromHeader(req);
const jwt: JWT = {
sub: `user-${randomid()}`,
email: null,
locale,
};
await setCookie(res, jwt);
return jwt;
}

View file

@ -174,7 +174,7 @@ const getAuthOptions = (...args: GetServerSessionParams) =>
adapter: CustomPrismaAdapter(prisma, {
migrateData: async (userId) => {
const session = await getServerSession(...args);
if (session && session.user.email === null) {
if (session?.user && session.user.email === null) {
await mergeGuestsIntoUser(userId, [session.user.id]);
}
},
@ -255,7 +255,7 @@ const getAuthOptions = (...args: GetServerSessionParams) =>
if (!isInitialSocialLogin) {
// merge guest user into newly logged in user
const session = await getServerSession(...args);
if (session && session.user.email === null) {
if (session?.user && session.user.email === null) {
await mergeGuestsIntoUser(user.id, [session.user.id]);
}
}
@ -273,13 +273,18 @@ const getAuthOptions = (...args: GetServerSessionParams) =>
return token;
},
async session({ session, token }) {
if (!token.sub) {
return session;
}
if (token.sub?.startsWith("user-")) {
session.user.id = token.sub as string;
session.user.locale = token.locale;
session.user.timeFormat = token.timeFormat;
session.user.timeZone = token.timeZone;
session.user.locale = token.locale;
session.user.weekStart = token.weekStart;
session.user = {
id: token.sub as string,
locale: token.locale,
timeFormat: token.timeFormat,
timeZone: token.timeZone,
weekStart: token.weekStart,
};
return session;
}
@ -300,21 +305,18 @@ const getAuthOptions = (...args: GetServerSessionParams) =>
});
if (user) {
session.user.id = user.id;
session.user.name = user.name;
session.user.email = user.email;
session.user.image = user.image;
} else {
session.user.id = token.sub || `user-${randomid()}`;
session.user = {
id: user.id,
name: user.name,
email: user.email,
image: user.image,
locale: user.locale,
timeFormat: user.timeFormat,
timeZone: user.timeZone,
weekStart: user.weekStart,
};
}
const source = user ?? token;
session.user.locale = source.locale;
session.user.timeFormat = source.timeFormat;
session.user.timeZone = source.timeZone;
session.user.weekStart = source.weekStart;
return session;
},
},

View file

@ -16,26 +16,27 @@ import type { Adapter, AdapterAccount } from "next-auth/adapters";
export function CustomPrismaAdapter(
client: ExtendedPrismaClient,
options: { migrateData: (userId: string) => Promise<void> },
): Adapter {
) {
const adapter = PrismaAdapter(client as PrismaClient);
return {
...PrismaAdapter(client as PrismaClient),
linkAccount: async (data) => {
await options.migrateData(data.userId);
return client.account.create({
...adapter,
linkAccount: async (account: AdapterAccount) => {
await options.migrateData(account.userId);
return (await client.account.create({
data: {
userId: data.userId,
type: data.type,
provider: data.provider,
providerAccountId: data.providerAccountId,
access_token: data.access_token as string,
expires_at: data.expires_at as number,
id_token: data.id_token as string,
token_type: data.token_type as string,
refresh_token: data.refresh_token as string,
scope: data.scope as string,
session_state: data.session_state as string,
userId: account.userId,
type: account.type,
provider: account.provider,
providerAccountId: account.providerAccountId,
access_token: account.access_token as string,
expires_at: account.expires_at as number,
id_token: account.id_token as string,
token_type: account.token_type as string,
refresh_token: account.refresh_token as string,
scope: account.scope as string,
session_state: account.session_state as string,
},
}) as unknown as AdapterAccount;
})) as AdapterAccount;
},
};
} satisfies Adapter;
}

View file

@ -10,6 +10,7 @@ import {
} from "@rallly/ui/card";
import { Form } from "@rallly/ui/form";
import { useRouter } from "next/navigation";
import { signIn, useSession } from "next-auth/react";
import React from "react";
import { useForm } from "react-hook-form";
import useFormPersist from "react-hook-form-persist";
@ -19,7 +20,6 @@ import { PollSettingsForm } from "@/components/forms/poll-settings";
import { Trans } from "@/components/trans";
import { useUser } from "@/components/user-provider";
import { trpc } from "@/trpc/client";
import { setCookie } from "@/utils/cookies";
import type { NewEventData } from "./forms";
import { PollDetailsForm, PollOptionsForm } from "./forms";
@ -42,6 +42,7 @@ export interface CreatePollPageProps {
export const CreatePoll: React.FunctionComponent = () => {
const router = useRouter();
const { user } = useUser();
const session = useSession();
const form = useForm<NewEventData>({
defaultValues: {
title: "",
@ -66,8 +67,12 @@ export const CreatePoll: React.FunctionComponent = () => {
const posthog = usePostHog();
const createPoll = trpc.polls.create.useMutation({
networkMode: "always",
onSuccess: () => {
setCookie("new-poll", "1");
onMutate: async () => {
if (session.status !== "authenticated") {
await signIn("guest", {
redirect: false,
});
}
},
});
@ -76,7 +81,6 @@ export const CreatePoll: React.FunctionComponent = () => {
<form
onSubmit={form.handleSubmit(async (formData) => {
const title = required(formData?.title);
await createPoll.mutateAsync(
{
title: title,

View file

@ -26,6 +26,7 @@ import {
MoreHorizontalIcon,
TrashIcon,
} from "lucide-react";
import { signIn, useSession } from "next-auth/react";
import { useTranslation } from "next-i18next";
import * as React from "react";
import { Controller, useForm } from "react-hook-form";
@ -57,6 +58,7 @@ function NewCommentForm({
const { t } = useTranslation();
const poll = usePoll();
const { user } = useUser();
const session = useSession();
const { participants } = useParticipants();
const authorName = React.useMemo(() => {
@ -72,8 +74,6 @@ function NewCommentForm({
const posthog = usePostHog();
const session = useUser();
const { register, reset, control, handleSubmit, formState } =
useForm<CommentForm>({
defaultValues: {
@ -84,6 +84,13 @@ function NewCommentForm({
const { toast } = useToast();
const addComment = trpc.polls.comments.add.useMutation({
onMutate: async () => {
if (session.status !== "authenticated") {
await signIn("guest", {
redirect: false,
});
}
},
onSuccess: () => {
posthog?.capture("created comment");
},
@ -119,7 +126,7 @@ function NewCommentForm({
>
<Controller
name="authorName"
key={session.user?.id}
key={user?.id}
control={control}
rules={{ validate: requiredString }}
render={({ field }) => (

View file

@ -14,7 +14,6 @@ import { useParams, usePathname } from "next/navigation";
import React from "react";
import { GroupPollIcon } from "@/app/[locale]/(admin)/app-card";
import Loader from "@/app/[locale]/poll/[urlId]/skeleton";
import { LogoutButton } from "@/app/components/logout-button";
import { InviteDialog } from "@/components/invite-dialog";
import { LoginLink } from "@/components/login-link";
@ -30,9 +29,7 @@ import NotificationsToggle from "@/components/poll/notifications-toggle";
import { LegacyPollContextProvider } from "@/components/poll/poll-context-provider";
import { Trans } from "@/components/trans";
import { useUser } from "@/components/user-provider";
import { usePlan } from "@/contexts/plan";
import { usePoll } from "@/contexts/poll";
import { trpc } from "@/trpc/client";
const AdminControls = () => {
return (
@ -86,8 +83,9 @@ const Layout = ({ children }: React.PropsWithChildren) => {
};
const PermissionGuard = ({ children }: React.PropsWithChildren) => {
const poll = usePoll();
const { user } = useUser();
const poll = usePoll();
if (!poll.adminUrlId) {
return (
<PageDialog icon={ShieldCloseIcon}>
@ -139,30 +137,6 @@ const PermissionGuard = ({ children }: React.PropsWithChildren) => {
return <>{children}</>;
};
const Prefetch = ({ children }: React.PropsWithChildren) => {
const params = useParams();
const urlId = params?.urlId as string;
const poll = trpc.polls.get.useQuery({ urlId });
const participants = trpc.polls.participants.list.useQuery({ pollId: urlId });
const watchers = trpc.polls.getWatchers.useQuery({ pollId: urlId });
const comments = trpc.polls.comments.list.useQuery({ pollId: urlId });
usePlan(); // prefetch plan
if (
!poll.isFetched ||
!watchers.isFetched ||
!participants.isFetched ||
!comments.isFetched
) {
return <Loader />;
}
return <>{children}</>;
};
export const PollLayout = ({ children }: React.PropsWithChildren) => {
const params = useParams();
@ -174,12 +148,10 @@ export const PollLayout = ({ children }: React.PropsWithChildren) => {
}
return (
<Prefetch>
<LegacyPollContextProvider>
<PermissionGuard>
<Layout>{children}</Layout>
</PermissionGuard>
</LegacyPollContextProvider>
</Prefetch>
);
};

View file

@ -92,7 +92,6 @@ export const NewParticipantForm = (props: NewParticipantModalProps) => {
const { user } = useUser();
const isLoggedIn = !user.isGuest;
const { register, setError, formState, handleSubmit } =
useForm<NewParticipantFormData>({
resolver: zodResolver(schema),

View file

@ -1,4 +1,5 @@
import { usePostHog } from "@rallly/posthog/client";
import { signIn, useSession } from "next-auth/react";
import { usePoll } from "@/components/poll-context";
import { trpc } from "@/trpc/client";
@ -18,9 +19,16 @@ export const normalizeVotes = (
export const useAddParticipantMutation = () => {
const posthog = usePostHog();
const queryClient = trpc.useUtils();
const session = useSession();
return trpc.polls.participants.add.useMutation({
onSuccess: (newParticipant, input) => {
onMutate: async () => {
if (session.status !== "authenticated") {
await signIn("guest", {
redirect: false,
});
}
},
onSuccess: async (newParticipant, input) => {
const { pollId, name, email } = newParticipant;
queryClient.polls.participants.list.setData(
{ pollId },
@ -31,6 +39,7 @@ export const useAddParticipantMutation = () => {
];
},
);
posthog?.capture("add participant", {
pollId,
name,

View file

@ -44,7 +44,7 @@ const NotificationsToggle: React.FunctionComponent = () => {
queryClient.polls.getWatchers.setData(
{ pollId: poll.id },
(oldWatchers) => {
if (!oldWatchers) {
if (!oldWatchers || !user.id) {
return;
}
return [...oldWatchers, { userId: user.id }];

View file

@ -62,9 +62,11 @@ export const UserDropdown = ({ className }: { className?: string }) => {
<DropdownMenuLabel className="flex items-center gap-2">
<div className="grow">
<div>{user.isGuest ? <Trans i18nKey="guest" /> : user.name}</div>
{user.email ? (
<div className="text-muted-foreground text-xs font-normal">
{!user.isGuest ? user.email : user.id.substring(0, 10)}
{user.email}
</div>
) : null}
</div>
<div className="ml-4">
<Plan />

View file

@ -4,7 +4,6 @@ import type { Session } from "next-auth";
import { signOut, useSession } from "next-auth/react";
import React from "react";
import { Spinner } from "@/components/spinner";
import { useSubscription } from "@/contexts/plan";
import { PreferencesProvider } from "@/contexts/preferences";
import { useTranslation } from "@/i18n/client";
@ -14,7 +13,7 @@ import { isOwner } from "@/utils/permissions";
import { useRequiredContext } from "./use-required-context";
type UserData = {
id: string;
id?: string;
name: string;
email?: string | null;
isGuest: boolean;
@ -71,7 +70,7 @@ export const UserProvider = (props: { children?: React.ReactNode }) => {
const tier = isGuest ? "guest" : subscription?.active ? "pro" : "hobby";
React.useEffect(() => {
if (user?.email) {
if (user) {
posthog?.identify(user.id, {
email: user.email,
name: user.name,
@ -84,26 +83,18 @@ export const UserProvider = (props: { children?: React.ReactNode }) => {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [user?.id]);
if (!user) {
return (
<div className="flex h-screen items-center justify-center">
<Spinner />
</div>
);
}
return (
<UserContext.Provider
value={{
user: {
id: user.id as string,
name: user.name ?? t("guest"),
email: user.email || null,
id: user?.id,
name: user?.name ?? t("guest"),
email: user?.email || null,
isGuest,
tier,
timeZone: user.timeZone ?? null,
image: user.image ?? null,
locale: user.locale ?? i18n.language,
timeZone: user?.timeZone ?? null,
image: user?.image ?? null,
locale: user?.locale ?? i18n.language,
},
refresh: session.update,
logout: async () => {
@ -112,16 +103,16 @@ export const UserProvider = (props: { children?: React.ReactNode }) => {
posthog?.reset();
},
ownsObject: (resource) => {
return isOwner(resource, { id: user.id, isGuest });
return user ? isOwner(resource, { id: user.id, isGuest }) : false;
},
}}
>
<PreferencesProvider
initialValue={{
locale: user.locale ?? undefined,
timeZone: user.timeZone ?? undefined,
timeFormat: user.timeFormat ?? undefined,
weekStart: user.weekStart ?? undefined,
locale: user?.locale ?? undefined,
timeZone: user?.timeZone ?? undefined,
timeFormat: user?.timeFormat ?? undefined,
weekStart: user?.weekStart ?? undefined,
}}
onUpdate={async (newPreferences) => {
if (!isGuest) {

View file

@ -56,9 +56,7 @@ export async function QuickCreateWidget() {
<GroupPollIcon size="lg" />
</div>
<div className="min-w-0 flex-1">
<div className="truncate font-medium">
<Link href={`/poll/${poll.id}`}>{poll.title}</Link>
</div>
<div className="truncate font-medium">{poll.title}</div>
<div className="text-muted-foreground whitespace-nowrap text-sm">
<RelativeDate date={poll.createdAt} />
</div>

View file

@ -3,7 +3,7 @@ import { withPostHog } from "@rallly/posthog/next/middleware";
import { NextResponse } from "next/server";
import withAuth from "next-auth/middleware";
import { getLocaleFromHeader, initGuest } from "@/app/guest";
import { getLocaleFromHeader } from "@/app/guest";
import { isSelfHosted } from "@/utils/constants";
const supportedLocales = Object.keys(languages);
@ -55,10 +55,9 @@ export const middleware = withAuth(
const res = NextResponse.rewrite(newUrl);
res.headers.set("x-pathname", newUrl.pathname);
const jwt = await initGuest(req, res);
if (jwt?.sub) {
await withPostHog(res, { distinctID: jwt.sub });
if (req.nextauth.token) {
await withPostHog(res, { distinctID: req.nextauth.token.sub });
}
return res;

View file

@ -1,12 +1,15 @@
import type { EmailClient } from "@rallly/emails";
export type TRPCContext = {
user: {
type User = {
id: string;
isGuest: boolean;
locale?: string;
getEmailClient: (locale?: string) => EmailClient;
image?: string;
};
export type TRPCContext = {
user?: User;
locale?: string;
ip?: string;
};

View file

@ -4,6 +4,7 @@ import { generateOtp } from "@rallly/utils/nanoid";
import { z } from "zod";
import { isEmailBlocked } from "@/auth";
import { getEmailClient } from "@/utils/emails";
import { createToken, decryptToken } from "@/utils/session";
import { publicProcedure, rateLimitMiddleware, router } from "../trpc";
@ -66,7 +67,7 @@ export const auth = router({
code,
});
await ctx.user.getEmailClient().sendTemplate("RegisterEmail", {
await getEmailClient(ctx.locale).sendTemplate("RegisterEmail", {
to: input.email,
props: {
code,

View file

@ -17,6 +17,7 @@ import {
proProcedure,
publicProcedure,
rateLimitMiddleware,
requireUserMiddleware,
router,
} from "../trpc";
import { comments } from "./polls/comments";
@ -130,6 +131,7 @@ export const polls = router({
// START LEGACY ROUTES
create: possiblyPublicProcedure
.use(rateLimitMiddleware)
.use(requireUserMiddleware)
.input(
z.object({
title: z.string().trim().min(1),
@ -332,7 +334,7 @@ export const polls = router({
});
}),
// END LEGACY ROUTES
getWatchers: possiblyPublicProcedure
getWatchers: publicProcedure
.input(
z.object({
pollId: z.string(),
@ -348,16 +350,9 @@ export const polls = router({
},
});
}),
watch: possiblyPublicProcedure
watch: privateProcedure
.input(z.object({ pollId: z.string() }))
.mutation(async ({ input, ctx }) => {
if (ctx.user.isGuest) {
throw new TRPCError({
code: "BAD_REQUEST",
message: "Guests can't watch polls",
});
}
await prisma.watcher.create({
data: {
pollId: input.pollId,
@ -365,16 +360,9 @@ export const polls = router({
},
});
}),
unwatch: possiblyPublicProcedure
unwatch: privateProcedure
.input(z.object({ pollId: z.string() }))
.mutation(async ({ input, ctx }) => {
if (ctx.user.isGuest) {
throw new TRPCError({
code: "BAD_REQUEST",
message: "Guests can't unwatch polls",
});
}
const watcher = await prisma.watcher.findFirst({
where: {
pollId: input.pollId,
@ -393,31 +381,6 @@ export const polls = router({
});
}
}),
getByAdminUrlId: possiblyPublicProcedure
.input(
z.object({
urlId: z.string(),
}),
)
.query(async ({ input }) => {
const res = await prisma.poll.findUnique({
select: {
id: true,
},
where: {
adminUrlId: input.urlId,
},
});
if (!res) {
throw new TRPCError({
code: "NOT_FOUND",
message: "Poll not found",
});
}
return res;
}),
get: publicProcedure
.input(
z.object({
@ -479,9 +442,11 @@ export const polls = router({
}
const inviteLink = shortUrl(`/invite/${res.id}`);
const isOwner = ctx.user.isGuest
? ctx.user.id === res.guestId
: ctx.user.id === res.userId;
const userId = ctx.user?.id;
const isOwner = ctx.user?.isGuest
? userId === res.guestId
: userId === res.userId;
if (isOwner || res.adminUrlId === input.adminToken) {
return { ...res, inviteLink };
@ -489,93 +454,6 @@ export const polls = router({
return { ...res, adminUrlId: "", inviteLink };
}
}),
transfer: possiblyPublicProcedure
.input(
z.object({
pollId: z.string(),
}),
)
.mutation(async ({ input, ctx }) => {
await prisma.poll.update({
where: {
id: input.pollId,
},
data: {
userId: ctx.user.id,
},
});
}),
getParticipating: possiblyPublicProcedure
.input(
z.object({
pagination: z.object({
pageIndex: z.number(),
pageSize: z.number(),
}),
}),
)
.query(async ({ ctx, input }) => {
const [total, rows] = await Promise.all([
prisma.poll.count({
where: {
participants: {
some: {
userId: ctx.user.id,
},
},
},
}),
prisma.poll.findMany({
where: {
deletedAt: null,
participants: {
some: {
userId: ctx.user.id,
},
},
},
select: {
id: true,
title: true,
location: true,
createdAt: true,
timeZone: true,
adminUrlId: true,
participantUrlId: true,
status: true,
event: {
select: {
start: true,
duration: true,
},
},
closed: true,
participants: {
select: {
id: true,
name: true,
},
orderBy: [
{
createdAt: "desc",
},
{ name: "desc" },
],
},
},
orderBy: [
{
createdAt: "desc",
},
{ title: "asc" },
],
skip: input.pagination.pageIndex * input.pagination.pageSize,
take: input.pagination.pageSize,
}),
]);
return { total, rows };
}),
book: proProcedure
.input(
z.object({

View file

@ -5,7 +5,12 @@ import { z } from "zod";
import { getEmailClient } from "@/utils/emails";
import { createToken } from "@/utils/session";
import { publicProcedure, rateLimitMiddleware, router } from "../../trpc";
import {
publicProcedure,
rateLimitMiddleware,
requireUserMiddleware,
router,
} from "../../trpc";
import type { DisableNotificationsPayload } from "../../types";
export const comments = router({
@ -13,12 +18,52 @@ export const comments = router({
.input(
z.object({
pollId: z.string(),
hideParticipants: z.boolean().optional(),
hideParticipants: z.boolean().optional(), // @deprecated
}),
)
.query(async ({ input: { pollId, hideParticipants }, ctx }) => {
.query(async ({ input: { pollId }, ctx }) => {
const poll = await prisma.poll.findUnique({
where: {
id: pollId,
},
select: {
userId: true,
guestId: true,
hideParticipants: true,
},
});
const isOwner = ctx.user?.isGuest
? poll?.guestId === ctx.user.id
: poll?.userId === ctx.user?.id;
const hideParticipants = poll?.hideParticipants && !isOwner;
if (hideParticipants && !isOwner) {
// if hideParticipants is enabled and the user is not the owner
if (!ctx.user) {
// cannot see any comments if there is no user
return [];
} else {
// only show comments created by the current users
return await prisma.comment.findMany({
where: { pollId, userId: hideParticipants ? ctx.user.id : undefined },
where: {
pollId,
...(ctx.user.isGuest
? { guestId: ctx.user.id }
: { userId: ctx.user.id }),
},
orderBy: [
{
createdAt: "asc",
},
],
});
}
}
// return all comments
return await prisma.comment.findMany({
where: { pollId },
orderBy: [
{
createdAt: "asc",
@ -28,6 +73,7 @@ export const comments = router({
}),
add: publicProcedure
.use(rateLimitMiddleware)
.use(requireUserMiddleware)
.input(
z.object({
pollId: z.string(),

View file

@ -5,7 +5,12 @@ import { z } from "zod";
import { createToken } from "@/utils/session";
import { publicProcedure, rateLimitMiddleware, router } from "../../trpc";
import {
publicProcedure,
rateLimitMiddleware,
requireUserMiddleware,
router,
} from "../../trpc";
import type { DisableNotificationsPayload } from "../../types";
const MAX_PARTICIPANTS = 1000;
@ -59,6 +64,7 @@ export const participants = router({
}),
add: publicProcedure
.use(rateLimitMiddleware)
.use(requireUserMiddleware)
.input(
z.object({
pollId: z.string(),

View file

@ -12,7 +12,6 @@ import { createToken } from "@/utils/session";
import { getSubscriptionStatus } from "@/utils/subscription";
import {
possiblyPublicProcedure,
privateProcedure,
publicProcedure,
rateLimitMiddleware,
@ -68,9 +67,9 @@ export const user = router({
},
});
}),
subscription: possiblyPublicProcedure.query(
subscription: publicProcedure.query(
async ({ ctx }): Promise<{ legacy?: boolean; active: boolean }> => {
if (ctx.user.isGuest) {
if (!ctx.user || ctx.user.isGuest) {
// guest user can't have an active subscription
return {
active: false,

View file

@ -1,5 +1,4 @@
import { createServerSideHelpers } from "@trpc/react-query/server";
import { TRPCError } from "@trpc/server";
import { redirect } from "next/navigation";
import superjson from "superjson";
@ -11,22 +10,17 @@ import { appRouter } from "../routers";
async function createContext(): Promise<TRPCContext> {
const session = await getServerSession();
if (!session) {
throw new TRPCError({
code: "UNAUTHORIZED",
message: "Unauthorized",
});
}
return {
user: {
user: session?.user
? {
id: session.user.id,
isGuest: session.user.email === null,
isGuest: !session.user.email,
locale: session.user.locale ?? undefined,
image: session.user.image ?? undefined,
getEmailClient: () => getEmailClient(session.user.locale ?? undefined),
},
getEmailClient: () =>
getEmailClient(session.user?.locale ?? undefined),
}
: undefined,
};
}

View file

@ -3,6 +3,7 @@ import { Ratelimit } from "@upstash/ratelimit";
import { kv } from "@vercel/kv";
import superjson from "superjson";
import { isQuickCreateEnabled } from "@/features/quick-create";
import { isSelfHosted } from "@/utils/constants";
import { getSubscriptionStatus } from "@/utils/subscription";
@ -23,26 +24,53 @@ export const middleware = t.middleware;
export const possiblyPublicProcedure = t.procedure.use(
middleware(async ({ ctx, next }) => {
// On self-hosted instances, these procedures require login
if (isSelfHosted && ctx.user.isGuest) {
// These procedurs are public if Quick Create is enabled
const isGuest = !ctx.user || ctx.user.isGuest;
if (isGuest && !isQuickCreateEnabled) {
throw new TRPCError({
code: "UNAUTHORIZED",
message: "Login is required",
});
}
return next();
}),
);
export const proProcedure = t.procedure.use(
middleware(async ({ ctx, next }) => {
if (ctx.user.isGuest) {
// This procedure guarantees that a user will exist in the context by
// creating a guest user if needed
export const requireUserMiddleware = middleware(async ({ ctx, next }) => {
if (!ctx.user) {
throw new TRPCError({
code: "UNAUTHORIZED",
message: "This method requires a user",
});
}
return next({
ctx: {
user: ctx.user,
},
});
});
export const privateProcedure = t.procedure.use(async ({ ctx, next }) => {
const { user } = ctx;
if (!user || user.isGuest !== false) {
throw new TRPCError({
code: "UNAUTHORIZED",
message: "Login is required",
});
}
return next({
ctx: {
user,
},
});
});
export const proProcedure = privateProcedure.use(async ({ ctx, next }) => {
if (isSelfHosted) {
// Self-hosted instances don't have paid subscriptions
return next();
@ -59,21 +87,7 @@ export const proProcedure = t.procedure.use(
}
return next();
}),
);
export const privateProcedure = t.procedure.use(
middleware(async ({ ctx, next }) => {
if (ctx.user.isGuest !== false) {
throw new TRPCError({
code: "UNAUTHORIZED",
message: "Login is required",
});
}
return next();
}),
);
export const rateLimitMiddleware = middleware(async ({ ctx, next }) => {
if (!process.env.KV_REST_API_URL) {