Add instance settings and option to disable user registration (#1745)

This commit is contained in:
Luke Vella 2025-06-02 19:40:00 +01:00 committed by GitHub
parent 9e1f3c616e
commit 3c2e008579
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 552 additions and 153 deletions

View file

@ -1,4 +1,4 @@
import type { TimeFormat } from "@rallly/database"; import type { TimeFormat, UserRole } from "@rallly/database";
import type { DefaultSession, DefaultUser } from "next-auth"; import type { DefaultSession, DefaultUser } from "next-auth";
import type { DefaultJWT } from "next-auth/jwt"; import type { DefaultJWT } from "next-auth/jwt";
import type { NextRequest } from "next/server"; import type { NextRequest } from "next/server";
@ -23,6 +23,7 @@ declare module "next-auth" {
timeFormat?: TimeFormat | null; timeFormat?: TimeFormat | null;
weekStart?: number | null; weekStart?: number | null;
banned?: boolean | null; banned?: boolean | null;
role?: UserRole | null;
} }
interface NextAuthRequest extends NextRequest { interface NextAuthRequest extends NextRequest {

View file

@ -386,5 +386,14 @@
"licenseKeyErrorRateLimitExceeded": "Rate limit exceeded", "licenseKeyErrorRateLimitExceeded": "Rate limit exceeded",
"licenseKeyErrorInvalidLicenseKey": "Invalid license key", "licenseKeyErrorInvalidLicenseKey": "Invalid license key",
"licenseKeyGenericError": "An error occurred while validating the license key", "licenseKeyGenericError": "An error occurred while validating the license key",
"activate": "Activate" "activate": "Activate",
"authErrorsRegistrationDisabled": "Registration is currently disabled. Please try again later.",
"authErrorsEmailNotVerified": "Your email address is not verified. Please verify your email before logging in.",
"authErrorsUserBanned": "This account has been banned. Please contact support if you believe this is an error.",
"authErrorsEmailBlocked": "This email address is not allowed. Please use a different email or contact support.",
"authErrorsUserNotFound": "No account found with this email address. Please check the email or register for a new account.",
"disableUserRegistration": "Disable User Registration",
"disableUserRegistrationDescription": "Prevent new users from registering an account.",
"authenticationAndSecurity": "Authentication & Security",
"authenticationAndSecurityDescription": "Manage authentication and security settings"
} }

View file

@ -6,16 +6,62 @@ export function AuthErrors() {
const { t } = useTranslation(); const { t } = useTranslation();
const searchParams = useSearchParams(); const searchParams = useSearchParams();
const error = searchParams?.get("error"); const error = searchParams?.get("error");
if (error === "OAuthAccountNotLinked") { switch (error) {
return ( case "OAuthAccountNotLinked":
<p className="text-destructive text-sm"> return (
{t("accountNotLinkedDescription", { <p className="text-destructive text-sm">
defaultValue: {t("accountNotLinkedDescription", {
"A user with this email already exists. Please log in using the original method.", defaultValue:
})} "A user with this email already exists. Please log in using the original method.",
</p> })}
); </p>
);
case "RegistrationDisabled":
return (
<p className="text-destructive text-sm">
{t("authErrorsRegistrationDisabled", {
defaultValue:
"Registration is currently disabled. Please try again later.",
})}
</p>
);
case "EmailNotVerified":
return (
<p className="text-destructive text-sm">
{t("authErrorsEmailNotVerified", {
defaultValue:
"Your email address is not verified. Please verify your email before logging in.",
})}
</p>
);
case "Banned":
return (
<p className="text-destructive text-sm">
{t("authErrorsUserBanned", {
defaultValue:
"This account has been banned. Please contact support if you believe this is an error.",
})}
</p>
);
case "EmailBlocked":
return (
<p className="text-destructive text-sm">
{t("authErrorsEmailBlocked", {
defaultValue:
"This email address is not allowed. Please use a different email or contact support.",
})}
</p>
);
case "UserNotFound":
return (
<p className="text-destructive text-sm">
{t("authErrorsUserNotFound", {
defaultValue:
"No account found with this email address. Please check the email or register for a new account.",
})}
</p>
);
default:
return null;
} }
return null;
} }

View file

@ -6,6 +6,7 @@ import { MicrosoftProvider } from "@/auth/providers/microsoft";
import { OIDCProvider } from "@/auth/providers/oidc"; import { OIDCProvider } from "@/auth/providers/oidc";
import { getTranslation } from "@/i18n/server"; import { getTranslation } from "@/i18n/server";
import { getInstanceSettings } from "@/features/instance-settings/queries";
import { import {
AuthPageContainer, AuthPageContainer,
AuthPageContent, AuthPageContent,
@ -20,19 +21,33 @@ import { LoginWithOIDC } from "./components/login-with-oidc";
import { OrDivider } from "./components/or-divider"; import { OrDivider } from "./components/or-divider";
import { SSOProvider } from "./components/sso-provider"; import { SSOProvider } from "./components/sso-provider";
async function loadData() {
const [instanceSettings, { t }] = await Promise.all([
getInstanceSettings(),
getTranslation(),
]);
return {
instanceSettings,
t,
};
}
export default async function LoginPage(props: { export default async function LoginPage(props: {
searchParams?: Promise<{ searchParams?: Promise<{
redirectTo?: string; redirectTo?: string;
}>; }>;
}) { }) {
const searchParams = await props.searchParams; const searchParams = await props.searchParams;
const { t } = await getTranslation();
const { instanceSettings, t } = await loadData();
const oidcProvider = OIDCProvider(); const oidcProvider = OIDCProvider();
const socialProviders = [GoogleProvider(), MicrosoftProvider()].filter( const socialProviders = [GoogleProvider(), MicrosoftProvider()].filter(
Boolean, Boolean,
); );
const hasAlternateLoginMethods = socialProviders.length > 0 || !!oidcProvider; const hasAlternateLoginMethods = [...socialProviders, oidcProvider].some(
Boolean,
);
return ( return (
<AuthPageContainer> <AuthPageContainer>
@ -74,16 +89,18 @@ export default async function LoginPage(props: {
) : null} ) : null}
</AuthPageContent> </AuthPageContent>
<AuthErrors /> <AuthErrors />
<AuthPageExternal> {!instanceSettings?.disableUserRegistration ? (
<Trans <AuthPageExternal>
t={t} <Trans
i18nKey="loginFooter" t={t}
defaults="Don't have an account? <a>Sign up</a>" i18nKey="loginFooter"
components={{ defaults="Don't have an account? <a>Sign up</a>"
a: <Link className="text-link" href="/register" />, components={{
}} a: <Link className="text-link" href="/register" />,
/> }}
</AuthPageExternal> />
</AuthPageExternal>
) : null}
</AuthPageContainer> </AuthPageContainer>
); );
} }

View file

@ -3,6 +3,8 @@ import { Trans } from "react-i18next/TransWithoutContext";
import { getTranslation } from "@/i18n/server"; import { getTranslation } from "@/i18n/server";
import { getInstanceSettings } from "@/features/instance-settings/queries";
import { notFound } from "next/navigation";
import { import {
AuthPageContainer, AuthPageContainer,
AuthPageContent, AuthPageContent,
@ -18,6 +20,11 @@ export default async function Register(props: {
}) { }) {
const params = await props.params; const params = await props.params;
const { t } = await getTranslation(params.locale); const { t } = await getTranslation(params.locale);
const instanceSettings = await getInstanceSettings();
if (instanceSettings?.disableUserRegistration) {
return notFound();
}
return ( return (
<AuthPageContainer> <AuthPageContainer>

View file

@ -49,7 +49,7 @@ export default async function Layout({
</TopBar> </TopBar>
<LicenseLimitWarning /> <LicenseLimitWarning />
<div className="flex flex-1 flex-col"> <div className="flex flex-1 flex-col">
<div className="flex flex-1 flex-col p-4 md:p-8">{children}</div> <div className="flex flex-1 flex-col">{children}</div>
</div> </div>
<ActionBar /> <ActionBar />
</SidebarInset> </SidebarInset>

View file

@ -17,9 +17,7 @@ export default async function AdminLayout({
<ControlPanelSidebar /> <ControlPanelSidebar />
<SidebarInset> <SidebarInset>
<LicenseLimitWarning /> <LicenseLimitWarning />
<div className="min-w-0 p-4 md:p-8 flex-1 flex-col flex"> <div className="min-w-0 flex-1 flex-col flex">{children}</div>
{children}
</div>
</SidebarInset> </SidebarInset>
</ControlPanelSidebarProvider> </ControlPanelSidebarProvider>
); );

View file

@ -11,7 +11,12 @@ import { getLicense } from "@/features/licensing/queries";
import { prisma } from "@rallly/database"; import { prisma } from "@rallly/database";
import { cn } from "@rallly/ui"; import { cn } from "@rallly/ui";
import { Tile, TileGrid, TileTitle } from "@rallly/ui/tile"; import { Tile, TileGrid, TileTitle } from "@rallly/ui/tile";
import { GaugeIcon, KeySquareIcon, UsersIcon } from "lucide-react"; import {
GaugeIcon,
KeySquareIcon,
SettingsIcon,
UsersIcon,
} from "lucide-react";
import Link from "next/link"; import Link from "next/link";
async function loadData() { async function loadData() {
@ -47,6 +52,7 @@ export default async function AdminPage() {
<Trans i18nKey="homeNavTitle" defaults="Navigation" /> <Trans i18nKey="homeNavTitle" defaults="Navigation" />
</h2> </h2>
<TileGrid> <TileGrid>
{/* USERS */}
<Tile asChild> <Tile asChild>
<Link href="/control-panel/users"> <Link href="/control-panel/users">
<div className="flex justify-between"> <div className="flex justify-between">
@ -79,6 +85,7 @@ export default async function AdminPage() {
</div> </div>
</Link> </Link>
</Tile> </Tile>
{/* LICENSE */}
<Tile asChild> <Tile asChild>
<Link href="/control-panel/license"> <Link href="/control-panel/license">
<div className="flex justify-between"> <div className="flex justify-between">
@ -100,6 +107,19 @@ export default async function AdminPage() {
</TileTitle> </TileTitle>
</Link> </Link>
</Tile> </Tile>
{/* INSTANCE SETTINGS */}
<Tile asChild>
<Link href="/control-panel/settings">
<div className="flex justify-between">
<PageIcon color="darkGray">
<SettingsIcon />
</PageIcon>
</div>
<TileTitle>
<Trans i18nKey="settings" defaults="Settings" />
</TileTitle>
</Link>
</Tile>
</TileGrid> </TileGrid>
</div> </div>
</PageContent> </PageContent>

View file

@ -0,0 +1,21 @@
"use server";
import { requireAdmin } from "@/auth/queries";
import { prisma } from "@rallly/database";
export async function setDisableUserRegistration({
disableUserRegistration,
}: {
disableUserRegistration: boolean;
}) {
await requireAdmin();
await prisma.instanceSettings.update({
where: {
id: 1,
},
data: {
disableUserRegistration,
},
});
}

View file

@ -0,0 +1,36 @@
"use client";
import { Trans } from "@/components/trans";
import { Label } from "@rallly/ui/label";
import { Switch } from "@rallly/ui/switch";
import { setDisableUserRegistration } from "./actions";
export function DisableUserRegistration({
defaultValue,
}: { defaultValue: boolean }) {
return (
<div>
<div className="flex items-center gap-2">
<Switch
id="disable-user-registration"
onCheckedChange={(checked) => {
setDisableUserRegistration({ disableUserRegistration: checked });
}}
defaultChecked={defaultValue}
/>
<Label htmlFor="disable-user-registration">
<Trans
i18nKey="disableUserRegistration"
defaults="Disable User Registration"
/>
</Label>
</div>
<p className="text-sm mt-2 text-muted-foreground">
<Trans
i18nKey="disableUserRegistrationDescription"
defaults="Prevent new users from registering an account."
/>
</p>
</div>
);
}

View file

@ -0,0 +1,62 @@
import { PageIcon } from "@/app/components/page-icons";
import {
FullWidthLayout,
FullWidthLayoutContent,
FullWidthLayoutHeader,
FullWidthLayoutTitle,
} from "@/components/full-width-layout";
import { Trans } from "@/components/trans";
import { getInstanceSettings } from "@/features/instance-settings/queries";
import { SettingsIcon } from "lucide-react";
import { DisableUserRegistration } from "./disable-user-registration";
async function loadData() {
const instanceSettings = await getInstanceSettings();
return {
instanceSettings,
};
}
export default async function SettingsPage() {
const { instanceSettings } = await loadData();
return (
<FullWidthLayout>
<FullWidthLayoutHeader>
<FullWidthLayoutTitle
icon={
<PageIcon size="sm" color="darkGray">
<SettingsIcon />
</PageIcon>
}
>
<Trans i18nKey="settings" defaults="Settings" />
</FullWidthLayoutTitle>
</FullWidthLayoutHeader>
<FullWidthLayoutContent>
<div className="flex flex-col lg:flex-row p-6 gap-6 rounded-lg border">
<div className="lg:w-1/2">
<h2 className="text-base font-semibold">
<Trans
i18nKey="authenticationAndSecurity"
defaults="Authentication & Security"
/>
</h2>
<p className="mt-1 text-muted-foreground text-sm">
<Trans
i18nKey="authenticationAndSecurityDescription"
defaults="Manage authentication and security settings"
/>
</p>
</div>
<div className="flex-1">
<DisableUserRegistration
defaultValue={instanceSettings?.disableUserRegistration}
/>
</div>
</div>
</FullWidthLayoutContent>
</FullWidthLayout>
);
}

View file

@ -10,6 +10,7 @@ import {
ArrowLeftIcon, ArrowLeftIcon,
HomeIcon, HomeIcon,
KeySquareIcon, KeySquareIcon,
SettingsIcon,
UsersIcon, UsersIcon,
} from "lucide-react"; } from "lucide-react";
import type * as React from "react"; import type * as React from "react";
@ -49,6 +50,10 @@ export async function ControlPanelSidebar({
<KeySquareIcon className="size-4" /> <KeySquareIcon className="size-4" />
<Trans i18nKey="license" defaults="License" /> <Trans i18nKey="license" defaults="License" />
</NavItem> </NavItem>
<NavItem href="/control-panel/settings">
<SettingsIcon className="size-4" />
<Trans i18nKey="settings" defaults="Settings" />
</NavItem>
</SidebarMenu> </SidebarMenu>
</SidebarGroup> </SidebarGroup>
</SidebarContent> </SidebarContent>

View file

@ -1,6 +1,6 @@
import "../../style.css"; import "../../style.css";
import { defaultLocale, supportedLngs } from "@rallly/languages"; import { supportedLngs } from "@rallly/languages";
import { PostHogProvider, posthog } from "@rallly/posthog/client"; import { PostHogProvider, posthog } from "@rallly/posthog/client";
import { Toaster } from "@rallly/ui/toaster"; import { Toaster } from "@rallly/ui/toaster";
import { TooltipProvider } from "@rallly/ui/tooltip"; import { TooltipProvider } from "@rallly/ui/tooltip";
@ -15,7 +15,7 @@ import { PreferencesProvider } from "@/contexts/preferences";
import { TimezoneProvider } from "@/features/timezone/client/context"; import { TimezoneProvider } from "@/features/timezone/client/context";
import { I18nProvider } from "@/i18n/client"; import { I18nProvider } from "@/i18n/client";
import { getLocale } from "@/i18n/server/get-locale"; import { getLocale } from "@/i18n/server/get-locale";
import { auth, getUserId } from "@/next-auth"; import { auth } from "@/next-auth";
import { TRPCProvider } from "@/trpc/client/provider"; import { TRPCProvider } from "@/trpc/client/provider";
import { ConnectedDayjsProvider } from "@/utils/dayjs"; import { ConnectedDayjsProvider } from "@/utils/dayjs";
@ -34,25 +34,35 @@ export const viewport: Viewport = {
initialScale: 1, initialScale: 1,
}; };
async function loadData() {
const [session, locale] = await Promise.all([auth(), getLocale()]);
const userId = session?.user?.email ? session.user.id : undefined;
const user = userId ? await getUser(userId) : null;
return {
session,
locale,
user,
};
}
export default async function Root({ export default async function Root({
children, children,
}: { }: {
children: React.ReactNode; children: React.ReactNode;
}) { }) {
const session = await auth(); const { session, locale: fallbackLocale, user } = await loadData();
let locale = await getLocale(); let locale = fallbackLocale;
const userId = await getUserId();
const user = userId ? await getUser(userId) : null;
if (user?.locale) { if (user?.locale) {
locale = user.locale; locale = user.locale;
} }
if (!supportedLngs.includes(locale)) { if (!supportedLngs.includes(locale)) {
locale = defaultLocale; locale = fallbackLocale;
} }
return ( return (

View file

@ -8,10 +8,12 @@ import { UserDropdown } from "@/components/user-dropdown";
import { getTranslation } from "@/i18n/server"; import { getTranslation } from "@/i18n/server";
import { getLoggedIn } from "@/next-auth"; import { getLoggedIn } from "@/next-auth";
import { getInstanceSettings } from "@/features/instance-settings/queries";
import { BackButton } from "./back-button"; import { BackButton } from "./back-button";
export default async function Page() { export default async function Page() {
const isLoggedIn = await getLoggedIn(); const isLoggedIn = await getLoggedIn();
const instanceSettings = await getInstanceSettings();
return ( return (
<div> <div>
@ -42,13 +44,15 @@ export default async function Page() {
<Trans i18nKey="login" defaults="Login" /> <Trans i18nKey="login" defaults="Login" />
</Link> </Link>
</Button> </Button>
<Button variant="primary" asChild> {instanceSettings?.disableUserRegistration ? null : (
<Link <Button variant="primary" asChild>
href={`/register?redirectTo=${encodeURIComponent("/new")}`} <Link
> href={`/register?redirectTo=${encodeURIComponent("/new")}`}
<Trans i18nKey="signUp" defaults="Sign up" /> >
</Link> <Trans i18nKey="signUp" defaults="Sign up" />
</Button> </Link>
</Button>
)}
</div> </div>
)} )}
</div> </div>

View file

@ -29,7 +29,7 @@ const pageIconVariants = cva("inline-flex items-center justify-center", {
purple: "bg-purple-500 text-white", purple: "bg-purple-500 text-white",
}, },
size: { size: {
sm: "size-6 [&_svg]:size-3 rounded-md", sm: "size-7 [&_svg]:size-4 rounded-md",
md: "size-8 [&_svg]:size-5 rounded-lg", md: "size-8 [&_svg]:size-5 rounded-lg",
lg: "size-9 [&_svg]:size-5 rounded-lg", lg: "size-9 [&_svg]:size-5 rounded-lg",
xl: "size-10 [&_svg]:size-5 rounded-lg", xl: "size-10 [&_svg]:size-5 rounded-lg",

View file

@ -8,7 +8,9 @@ export function PageContainer({
className, className,
}: React.PropsWithChildren<{ className?: string }>) { }: React.PropsWithChildren<{ className?: string }>) {
return ( return (
<div className={cn("mx-auto w-full max-w-7xl", className)}>{children}</div> <div className={cn("mx-auto w-full p-4 md:p-8 max-w-7xl", className)}>
{children}
</div>
); );
} }

View file

@ -0,0 +1,31 @@
export function FullWidthLayout({ children }: { children: React.ReactNode }) {
return <div>{children}</div>;
}
export function FullWidthLayoutHeader({
children,
}: { children: React.ReactNode }) {
return (
<header className="py-4 rounded-t-lg bg-background/90 backdrop-blur-sm sticky top-0 z-10 px-6 border-b">
{children}
</header>
);
}
export function FullWidthLayoutContent({
children,
}: { children: React.ReactNode }) {
return <main className="p-6">{children}</main>;
}
export function FullWidthLayoutTitle({
children,
icon,
}: { children: React.ReactNode; icon?: React.ReactNode }) {
return (
<div className="flex items-center gap-2">
{icon}
<h1 className="text-xl font-semibold">{children}</h1>
</div>
);
}

View file

@ -1,17 +1,14 @@
"use client"; "use client";
import React from "react"; import React from "react";
import type { Feature, FeatureFlagConfig } from "./types";
interface Features { const FeatureFlagsContext = React.createContext<FeatureFlagConfig | undefined>(
storage: boolean;
}
const FeatureFlagsContext = React.createContext<Features | undefined>(
undefined, undefined,
); );
interface FeatureFlagsProviderProps { interface FeatureFlagsProviderProps {
value: Features; value: FeatureFlagConfig;
children: React.ReactNode; children: React.ReactNode;
} }
@ -26,7 +23,7 @@ export function FeatureFlagsProvider({
); );
} }
export function useFeatureFlag(featureName: keyof Features): boolean { export function useFeatureFlag(featureName: Feature): boolean {
const context = React.useContext(FeatureFlagsContext); const context = React.useContext(FeatureFlagsContext);
if (context === undefined) { if (context === undefined) {
throw new Error( throw new Error(
@ -35,3 +32,14 @@ export function useFeatureFlag(featureName: keyof Features): boolean {
} }
return context[featureName] ?? false; return context[featureName] ?? false;
} }
export function IfFeatureEnabled({
feature,
children,
}: {
feature: Feature;
children: React.ReactNode;
}) {
const featureEnabled = useFeatureFlag(feature);
return featureEnabled ? children : null;
}

View file

@ -0,0 +1,5 @@
export interface FeatureFlagConfig {
storage: boolean;
}
export type Feature = keyof FeatureFlagConfig;

View file

@ -0,0 +1,18 @@
"server-only";
import { prisma } from "@rallly/database";
import { cache } from "react";
export const getInstanceSettings = cache(async () => {
const instanceSettings = await prisma.instanceSettings.findUnique({
where: {
id: 1,
},
select: {
disableUserRegistration: true,
},
});
return {
disableUserRegistration: instanceSettings?.disableUserRegistration ?? false,
};
});

View file

@ -13,6 +13,7 @@ import {
ArrowRightIcon, ArrowRightIcon,
KeySquareIcon, KeySquareIcon,
PlusIcon, PlusIcon,
SettingsIcon,
UsersIcon, UsersIcon,
} from "lucide-react"; } from "lucide-react";
import { useRouter } from "next/navigation"; import { useRouter } from "next/navigation";
@ -155,6 +156,18 @@ export function CommandMenu() {
})} })}
/> />
</CommandItem> </CommandItem>
<CommandItem
onSelect={() => handleSelect("/control-panel/settings")}
>
<PageIcon size="sm">
<SettingsIcon />
</PageIcon>
<NavigationCommandLabel
label={t("settings", {
defaultValue: "Settings",
})}
/>
</CommandItem>
</CommandGroup> </CommandGroup>
)} )}
</CommandList> </CommandList>

View file

@ -1 +1,4 @@
export const isQuickCreateEnabled = process.env.QUICK_CREATE_ENABLED === "true"; import { isSelfHosted } from "@/utils/constants";
export const isQuickCreateEnabled =
!isSelfHosted && process.env.QUICK_CREATE_ENABLED === "true";

View file

@ -14,6 +14,7 @@ import { GuestProvider } from "./auth/providers/guest";
import { MicrosoftProvider } from "./auth/providers/microsoft"; import { MicrosoftProvider } from "./auth/providers/microsoft";
import { OIDCProvider } from "./auth/providers/oidc"; import { OIDCProvider } from "./auth/providers/oidc";
import { RegistrationTokenProvider } from "./auth/providers/registration-token"; import { RegistrationTokenProvider } from "./auth/providers/registration-token";
import { getInstanceSettings } from "./features/instance-settings/queries";
import { nextAuthConfig } from "./next-auth.config"; import { nextAuthConfig } from "./next-auth.config";
const { const {
@ -94,40 +95,7 @@ const {
}, },
callbacks: { callbacks: {
...nextAuthConfig.callbacks, ...nextAuthConfig.callbacks,
async signIn({ user, email, profile }) { async signIn({ user, email, profile, account }) {
const distinctId = user.id;
// prevent sign in if email is not verified
if (
profile &&
"email_verified" in profile &&
profile.email_verified === false &&
distinctId
) {
posthog?.capture({
distinctId,
event: "login failed",
properties: {
reason: "email not verified",
},
});
return false;
}
if (user.banned) {
return false;
}
// Make sure email is allowed
if (user.email) {
if (isEmailBlocked(user.email) || (await isEmailBanned(user.email))) {
return false;
}
}
// For now, we don't allow users to login unless they have
// registered an account. This is just because we need a name
// to display on the dashboard. The flow can be modified so that
// the name is requested after the user has logged in.
if (email?.verificationRequest) { if (email?.verificationRequest) {
const isRegisteredUser = const isRegisteredUser =
(await prisma.user.count({ (await prisma.user.count({
@ -135,19 +103,39 @@ const {
email: user.email as string, email: user.email as string,
}, },
})) > 0; })) > 0;
if (!isRegisteredUser) {
return isRegisteredUser; return "/login?error=EmailNotVerified";
}
} }
// when we login with a social account for the first time, the user is not created yet if (user.banned) {
// and the user id will be the same as the provider account id return "/login?error=Banned";
// we handle this case the the prisma adapter when we link accounts }
const isInitialSocialLogin = user.id === profile?.sub;
if (!isInitialSocialLogin) { // Make sure email is allowed
const emailToTest = user.email || profile?.email;
if (emailToTest) {
if (isEmailBlocked(emailToTest) || (await isEmailBanned(emailToTest))) {
return "/login?error=EmailBlocked";
}
}
const isNewUser = !user.role && profile;
// Check for new user login with OAuth provider
if (isNewUser) {
// If role isn't set than the user doesn't exist yet
// This can happen if logging in with an OAuth provider
const instanceSettings = await getInstanceSettings();
if (instanceSettings?.disableUserRegistration) {
return "/login?error=RegistrationDisabled";
}
}
if (!isNewUser && user.id) {
// merge guest user into newly logged in user // merge guest user into newly logged in user
const session = await auth(); const session = await auth();
if (user.id && session?.user && !session.user.email) { if (session?.user && !session.user.email) {
await mergeGuestsIntoUser(user.id, [session.user.id]); await mergeGuestsIntoUser(user.id, [session.user.id]);
} }
} }

View file

@ -11,6 +11,8 @@ import { getEmailClient } from "@/utils/emails";
import { isValidName } from "@/utils/is-valid-name"; import { isValidName } from "@/utils/is-valid-name";
import { createToken, decryptToken } from "@/utils/session"; import { createToken, decryptToken } from "@/utils/session";
import { getInstanceSettings } from "@/features/instance-settings/queries";
import { TRPCError } from "@trpc/server";
import { createRateLimitMiddleware, publicProcedure, router } from "../trpc"; import { createRateLimitMiddleware, publicProcedure, router } from "../trpc";
import type { RegistrationTokenPayload } from "../types"; import type { RegistrationTokenPayload } from "../types";
@ -52,6 +54,14 @@ export const auth = router({
| "temporaryEmailNotAllowed"; | "temporaryEmailNotAllowed";
} }
> => { > => {
const instanceSettings = await getInstanceSettings();
if (instanceSettings.disableUserRegistration) {
throw new TRPCError({
code: "BAD_REQUEST",
message: "User registration is disabled",
});
}
if (isEmailBlocked?.(input.email)) { if (isEmailBlocked?.(input.email)) {
return { ok: false, reason: "emailNotAllowed" }; return { ok: false, reason: "emailNotAllowed" };
} }

View file

@ -36,12 +36,12 @@ test.describe("Admin Setup Page Access", () => {
test("should allow access if user is the designated initial admin (and not yet admin role)", async ({ test("should allow access if user is the designated initial admin (and not yet admin role)", async ({
page, page,
}) => { }) => {
await createUserInDb( await createUserInDb({
INITIAL_ADMIN_TEST_EMAIL, email: INITIAL_ADMIN_TEST_EMAIL,
"Initial Admin User", name: "Initial Admin User",
"user", role: "user",
); });
await loginWithEmail(page, INITIAL_ADMIN_TEST_EMAIL); await loginWithEmail(page, { email: INITIAL_ADMIN_TEST_EMAIL });
await page.goto("/admin-setup"); await page.goto("/admin-setup");
await expect(page).toHaveURL(/.*\/admin-setup/); await expect(page).toHaveURL(/.*\/admin-setup/);
@ -54,8 +54,12 @@ test.describe("Admin Setup Page Access", () => {
test("should show 'not found' for a regular user (not initial admin, not admin role)", async ({ test("should show 'not found' for a regular user (not initial admin, not admin role)", async ({
page, page,
}) => { }) => {
await createUserInDb(REGULAR_USER_EMAIL, "Regular User", "user"); await createUserInDb({
await loginWithEmail(page, REGULAR_USER_EMAIL); email: REGULAR_USER_EMAIL,
name: "Regular User",
role: "user",
});
await loginWithEmail(page, { email: REGULAR_USER_EMAIL });
await page.goto("/admin-setup"); await page.goto("/admin-setup");
await expect(page.getByText("404 not found")).toBeVisible(); await expect(page.getByText("404 not found")).toBeVisible();
@ -64,8 +68,12 @@ test.describe("Admin Setup Page Access", () => {
test("should redirect an existing admin user to control-panel", async ({ test("should redirect an existing admin user to control-panel", async ({
page, page,
}) => { }) => {
await createUserInDb(SUBSEQUENT_ADMIN_EMAIL, "Existing Admin", "admin"); await createUserInDb({
await loginWithEmail(page, SUBSEQUENT_ADMIN_EMAIL); email: SUBSEQUENT_ADMIN_EMAIL,
name: "Existing Admin",
role: "admin",
});
await loginWithEmail(page, { email: SUBSEQUENT_ADMIN_EMAIL });
await page.goto("/admin-setup"); await page.goto("/admin-setup");
await expect(page).toHaveURL(/.*\/control-panel/); await expect(page).toHaveURL(/.*\/control-panel/);
@ -74,8 +82,12 @@ test.describe("Admin Setup Page Access", () => {
test("should show 'not found' if INITIAL_ADMIN_EMAIL in env is different from user's email", async ({ test("should show 'not found' if INITIAL_ADMIN_EMAIL in env is different from user's email", async ({
page, page,
}) => { }) => {
await createUserInDb(OTHER_USER_EMAIL, "Other User", "user"); await createUserInDb({
await loginWithEmail(page, OTHER_USER_EMAIL); email: OTHER_USER_EMAIL,
name: "Other User",
role: "user",
});
await loginWithEmail(page, { email: OTHER_USER_EMAIL });
await page.goto("/admin-setup"); await page.goto("/admin-setup");
await expect(page.getByText("404 not found")).toBeVisible(); await expect(page.getByText("404 not found")).toBeVisible();
@ -84,12 +96,12 @@ test.describe("Admin Setup Page Access", () => {
test("initial admin can make themselves admin using the button", async ({ test("initial admin can make themselves admin using the button", async ({
page, page,
}) => { }) => {
await createUserInDb( await createUserInDb({
INITIAL_ADMIN_TEST_EMAIL, email: INITIAL_ADMIN_TEST_EMAIL,
"Initial Admin To Be", name: "Initial Admin To Be",
"user", role: "user",
); });
await loginWithEmail(page, INITIAL_ADMIN_TEST_EMAIL); await loginWithEmail(page, { email: INITIAL_ADMIN_TEST_EMAIL });
await page.goto("/admin-setup"); await page.goto("/admin-setup");
await expect(page.getByText("Are you the admin?")).toBeVisible(); await expect(page.getByText("Are you the admin?")).toBeVisible();

View file

@ -4,21 +4,21 @@ import { load } from "cheerio";
import { captureEmailHTML } from "./mailpit/mailpit"; import { captureEmailHTML } from "./mailpit/mailpit";
import { RegisterPage } from "./register-page"; import { RegisterPage } from "./register-page";
import { createUserInDb, loginWithEmail } from "./test-utils";
import { getCode } from "./utils"; import { getCode } from "./utils";
const testUserEmail = "test@example.com"; const testUserEmail = "test@example.com";
const testExistingUserEmail = "existing-user-for-disabled-test@example.com";
test.describe.serial(() => { test.describe.serial(() => {
test.afterAll(async () => { test.afterAll(async () => {
try { await prisma.user.deleteMany({
await prisma.user.deleteMany({ where: {
where: { email: {
email: testUserEmail, in: [testUserEmail, testExistingUserEmail],
}, },
}); },
} catch { });
// User doesn't exist
}
}); });
test.describe("new user", () => { test.describe("new user", () => {
@ -140,4 +140,36 @@ test.describe.serial(() => {
await expect(page.getByText("Test User")).toBeVisible(); await expect(page.getByText("Test User")).toBeVisible();
}); });
}); });
test.describe("when user registration is disabled", () => {
test.beforeAll(async () => {
await prisma.instanceSettings.update({
where: { id: 1 },
data: {
disableUserRegistration: true,
},
});
});
test.afterAll(async () => {
await prisma.instanceSettings.update({
where: { id: 1 },
data: {
disableUserRegistration: false,
},
});
});
test("allows existing user to log in via email", async ({ page }) => {
await createUserInDb({
email: testExistingUserEmail,
name: "Existing User",
role: "user",
});
await loginWithEmail(page, { email: testExistingUserEmail });
await expect(page).toHaveURL("/");
});
});
}); });

View file

@ -1,12 +1,16 @@
import type { Page } from "@playwright/test"; import type { Page } from "@playwright/test";
import { prisma } from "@rallly/database"; import { type UserRole, prisma } from "@rallly/database";
import { LoginPage } from "./login-page"; import { LoginPage } from "./login-page";
export async function createUserInDb( export async function createUserInDb({
email: string, email,
name: string, name,
role: "user" | "admin" = "user", role = "user",
) { }: {
email: string;
name: string;
role?: UserRole;
}) {
return prisma.user.create({ return prisma.user.create({
data: { data: {
email, email,
@ -19,7 +23,7 @@ export async function createUserInDb(
}); });
} }
export async function loginWithEmail(page: Page, email: string) { export async function loginWithEmail(page: Page, { email }: { email: string }) {
const loginPage = new LoginPage(page); const loginPage = new LoginPage(page);
await loginPage.goto(); await loginPage.goto();
await loginPage.login({ await loginPage.login({

View file

@ -0,0 +1,12 @@
-- CreateTable
CREATE TABLE "instance_settings" (
"id" INTEGER NOT NULL DEFAULT 1,
"disable_user_registration" BOOLEAN NOT NULL DEFAULT false,
"created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT "instance_settings_pkey" PRIMARY KEY ("id")
);
-- Create default instance settings
INSERT INTO "instance_settings" ("id", "disable_user_registration") VALUES (1, false);

View file

@ -0,0 +1,16 @@
ALTER TABLE "instance_settings"
ADD CONSTRAINT instance_settings_singleton CHECK (id = 1);
CREATE OR REPLACE FUNCTION prevent_delete_instance_settings()
RETURNS TRIGGER AS $$
BEGIN
IF OLD.id = 1 THEN
RAISE EXCEPTION 'Deleting the instance_settings record (id=1) is not permitted.';
END IF;
RETURN OLD;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trg_prevent_instance_settings_deletion
BEFORE DELETE ON instance_settings
FOR EACH ROW EXECUTE FUNCTION prevent_delete_instance_settings();

View file

@ -0,0 +1,10 @@
model InstanceSettings {
id Int @id @default(1)
// Authentication & Security
disableUserRegistration Boolean @default(false) @map("disable_user_registration")
createdAt DateTime @default(now()) @map("created_at")
updatedAt DateTime @default(now()) @updatedAt @map("updated_at")
@@map("instance_settings")
}

View file

@ -10,16 +10,16 @@ enum LicenseStatus {
} }
model License { model License {
id String @id @default(cuid()) id String @id @default(cuid())
licenseKey String @unique @map("license_key") licenseKey String @unique @map("license_key")
version Int? @map("version") version Int? @map("version")
type LicenseType type LicenseType
seats Int? @map("seats") seats Int? @map("seats")
issuedAt DateTime @default(now()) @map("issued_at") issuedAt DateTime @default(now()) @map("issued_at")
expiresAt DateTime? @map("expires_at") expiresAt DateTime? @map("expires_at")
licenseeEmail String? @map("licensee_email") licenseeEmail String? @map("licensee_email")
licenseeName String? @map("licensee_name") licenseeName String? @map("licensee_name")
status LicenseStatus @default(ACTIVE) @map("status") status LicenseStatus @default(ACTIVE) @map("status")
validations LicenseValidation[] validations LicenseValidation[]
@ -38,18 +38,17 @@ model LicenseValidation {
@@map("license_validations") @@map("license_validations")
} }
model InstanceLicense { model InstanceLicense {
id String @id @default(cuid()) id String @id @default(cuid())
licenseKey String @unique @map("license_key") licenseKey String @unique @map("license_key")
version Int? @map("version") version Int? @map("version")
type LicenseType type LicenseType
seats Int? @map("seats") seats Int? @map("seats")
issuedAt DateTime @default(now()) @map("issued_at") issuedAt DateTime @default(now()) @map("issued_at")
expiresAt DateTime? @map("expires_at") expiresAt DateTime? @map("expires_at")
licenseeEmail String? @map("licensee_email") licenseeEmail String? @map("licensee_email")
licenseeName String? @map("licensee_name") licenseeName String? @map("licensee_name")
status LicenseStatus @default(ACTIVE) @map("status") status LicenseStatus @default(ACTIVE) @map("status")
@@map("instance_licenses") @@map("instance_licenses")
} }