diff --git a/src/lib/server/db/client.ts b/src/lib/server/db/client.ts index 7a7b58b..37a1054 100644 --- a/src/lib/server/db/client.ts +++ b/src/lib/server/db/client.ts @@ -119,26 +119,21 @@ export const registerUserClientChallenge = async ( }); }; -export const getUserClientChallenge = async (answer: string, ip: string) => { +export const consumeUserClientChallenge = async (userId: number, answer: string, ip: string) => { const challenges = await db - .select() - .from(userClientChallenge) + .delete(userClientChallenge) .where( and( + eq(userClientChallenge.userId, userId), eq(userClientChallenge.answer, answer), eq(userClientChallenge.allowedIp, ip), gt(userClientChallenge.expiresAt, new Date()), - eq(userClientChallenge.isUsed, false), ), ) - .limit(1); + .returning({ clientId: userClientChallenge.clientId }); return challenges[0] ?? null; }; -export const markUserClientChallengeAsUsed = async (id: number) => { - await db.update(userClientChallenge).set({ isUsed: true }).where(eq(userClientChallenge.id, id)); -}; - export const cleanupExpiredUserClientChallenges = async () => { await db.delete(userClientChallenge).where(lte(userClientChallenge.expiresAt, new Date())); }; diff --git a/src/lib/server/db/schema/client.ts b/src/lib/server/db/schema/client.ts index eacd9c9..1e9eb85 100644 --- a/src/lib/server/db/schema/client.ts +++ b/src/lib/server/db/schema/client.ts @@ -1,4 +1,11 @@ -import { sqliteTable, text, integer, primaryKey, unique } from "drizzle-orm/sqlite-core"; +import { + sqliteTable, + text, + integer, + primaryKey, + foreignKey, + unique, +} from "drizzle-orm/sqlite-core"; import { user } from "./user"; export const client = sqliteTable( @@ -31,16 +38,24 @@ export const userClient = sqliteTable( }), ); -export const userClientChallenge = sqliteTable("user_client_challenge", { - id: integer("id").primaryKey(), - userId: integer("user_id") - .notNull() - .references(() => user.id), - clientId: integer("client_id") - .notNull() - .references(() => client.id), - answer: text("answer").notNull().unique(), // Base64 - allowedIp: text("allowed_ip").notNull(), - expiresAt: integer("expires_at", { mode: "timestamp_ms" }).notNull(), - isUsed: integer("is_used", { mode: "boolean" }).notNull().default(false), -}); +export const userClientChallenge = sqliteTable( + "user_client_challenge", + { + id: integer("id").primaryKey(), + userId: integer("user_id") + .notNull() + .references(() => user.id), + clientId: integer("client_id") + .notNull() + .references(() => client.id), + answer: text("answer").notNull().unique(), // Base64 + allowedIp: text("allowed_ip").notNull(), + expiresAt: integer("expires_at", { mode: "timestamp_ms" }).notNull(), + }, + (t) => ({ + ref: foreignKey({ + columns: [t.userId, t.clientId], + foreignColumns: [userClient.userId, userClient.clientId], + }), + }), +); diff --git a/src/lib/server/services/client.ts b/src/lib/server/services/client.ts index 73973bb..cc706e1 100644 --- a/src/lib/server/services/client.ts +++ b/src/lib/server/services/client.ts @@ -8,8 +8,7 @@ import { getUserClient, setUserClientStateToPending, registerUserClientChallenge, - getUserClientChallenge, - markUserClientChallengeAsUsed, + consumeUserClientChallenge, } from "$lib/server/db/client"; import { IntegrityError } from "$lib/server/db/error"; import { verifyPubKey, verifySignature, generateChallenge } from "$lib/server/modules/crypto"; @@ -81,15 +80,11 @@ export const verifyUserClient = async ( answer: string, answerSig: string, ) => { - const challenge = await getUserClientChallenge(answer, ip); + const challenge = await consumeUserClientChallenge(userId, answer, ip); if (!challenge) { error(403, "Invalid challenge answer"); - } else if (challenge.userId !== userId) { - error(403, "Forbidden"); } - await markUserClientChallengeAsUsed(challenge.id); - const client = await getClient(challenge.clientId); if (!client) { error(500, "Invalid challenge answer"); @@ -97,7 +92,7 @@ export const verifyUserClient = async ( error(403, "Invalid challenge answer signature"); } - await setUserClientStateToPending(userId, challenge.clientId); + await setUserClientStateToPending(userId, client.id); }; export const getUserClientStatus = async (userId: number, clientId: number) => {