From 0bdf990dae5b9eabaf5eb5796fa7ef4372ad3943 Mon Sep 17 00:00:00 2001 From: static Date: Sat, 11 Jan 2025 03:55:19 +0900 Subject: [PATCH] =?UTF-8?q?DB=EC=97=90=20=EB=8F=99=EC=8B=9C=EC=A0=81?= =?UTF-8?q?=EC=9C=BC=EB=A1=9C=20=EC=A0=91=EA=B7=BC=ED=95=98=EB=8D=94?= =?UTF-8?q?=EB=9D=BC=EB=8F=84=20=EB=8D=B0=EC=9D=B4=ED=84=B0=20=EB=AC=B4?= =?UTF-8?q?=EA=B2=B0=EC=84=B1=EC=9D=B4=20=EA=B9=A8=EC=A7=80=EC=A7=80=20?= =?UTF-8?q?=EC=95=8A=EB=8F=84=EB=A1=9D=20DB=20=EC=A0=91=EA=B7=BC=20?= =?UTF-8?q?=EC=BD=94=EB=93=9C=20=EC=88=98=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/lib/server/db/client.ts | 103 ++++----- src/lib/server/db/error.ts | 21 ++ src/lib/server/db/file.ts | 227 +++++++++++-------- src/lib/server/db/mek.ts | 58 ++--- src/lib/server/db/token.ts | 120 +++++----- src/lib/server/db/user.ts | 2 +- src/lib/server/services/auth.ts | 62 +++-- src/lib/server/services/client.ts | 90 ++++---- src/lib/server/services/directory.ts | 80 +++---- src/lib/server/services/file.ts | 141 +++++------- src/lib/server/services/mek.ts | 18 +- src/routes/api/file/[id]/download/+server.ts | 2 +- 12 files changed, 486 insertions(+), 438 deletions(-) create mode 100644 src/lib/server/db/error.ts diff --git a/src/lib/server/db/client.ts b/src/lib/server/db/client.ts index bdf2404..7a7b58b 100644 --- a/src/lib/server/db/client.ts +++ b/src/lib/server/db/client.ts @@ -1,30 +1,36 @@ -import { and, or, eq, gt, lte, count } from "drizzle-orm"; +import { SqliteError } from "better-sqlite3"; +import { and, or, eq, gt, lte } from "drizzle-orm"; import db from "./drizzle"; +import { IntegrityError } from "./error"; import { client, userClient, userClientChallenge } from "./schema"; export const createClient = async (encPubKey: string, sigPubKey: string, userId: number) => { - return await db.transaction(async (tx) => { - const clients = await tx - .select() - .from(client) - .where(or(eq(client.encPubKey, sigPubKey), eq(client.sigPubKey, encPubKey))); - if (clients.length > 0) { - throw new Error("Already used public key(s)"); - } + return await db.transaction( + async (tx) => { + const clients = await tx + .select({ id: client.id }) + .from(client) + .where(or(eq(client.encPubKey, sigPubKey), eq(client.sigPubKey, encPubKey))) + .limit(1); + if (clients.length !== 0) { + throw new IntegrityError("Public key(s) already registered"); + } - const insertRes = await tx - .insert(client) - .values({ encPubKey, sigPubKey }) - .returning({ id: client.id }); - const { id: clientId } = insertRes[0]!; - await tx.insert(userClient).values({ userId, clientId }); + const newClients = await tx + .insert(client) + .values({ encPubKey, sigPubKey }) + .returning({ id: client.id }); + const { id: clientId } = newClients[0]!; + await tx.insert(userClient).values({ userId, clientId }); - return clientId; - }); + return clientId; + }, + { behavior: "exclusive" }, + ); }; export const getClient = async (clientId: number) => { - const clients = await db.select().from(client).where(eq(client.id, clientId)).execute(); + const clients = await db.select().from(client).where(eq(client.id, clientId)).limit(1); return clients[0] ?? null; }; @@ -33,24 +39,23 @@ export const getClientByPubKeys = async (encPubKey: string, sigPubKey: string) = .select() .from(client) .where(and(eq(client.encPubKey, encPubKey), eq(client.sigPubKey, sigPubKey))) - .execute(); + .limit(1); return clients[0] ?? null; }; -export const countClientByPubKey = async (pubKey: string) => { - const clients = await db - .select({ count: count() }) - .from(client) - .where(or(eq(client.encPubKey, pubKey), eq(client.encPubKey, pubKey))); - return clients[0]?.count ?? 0; -}; - export const createUserClient = async (userId: number, clientId: number) => { - await db.insert(userClient).values({ userId, clientId }).execute(); + try { + await db.insert(userClient).values({ userId, clientId }); + } catch (e) { + if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_PRIMARYKEY") { + throw new IntegrityError("User client already exists"); + } + throw e; + } }; export const getAllUserClients = async (userId: number) => { - return await db.select().from(userClient).where(eq(userClient.userId, userId)).execute(); + return await db.select().from(userClient).where(eq(userClient.userId, userId)); }; export const getUserClient = async (userId: number, clientId: number) => { @@ -58,7 +63,7 @@ export const getUserClient = async (userId: number, clientId: number) => { .select() .from(userClient) .where(and(eq(userClient.userId, userId), eq(userClient.clientId, clientId))) - .execute(); + .limit(1); return userClients[0] ?? null; }; @@ -68,7 +73,7 @@ export const getUserClientWithDetails = async (userId: number, clientId: number) .from(userClient) .innerJoin(client, eq(userClient.clientId, client.id)) .where(and(eq(userClient.userId, userId), eq(userClient.clientId, clientId))) - .execute(); + .limit(1); return userClients[0] ?? null; }; @@ -82,8 +87,7 @@ export const setUserClientStateToPending = async (userId: number, clientId: numb eq(userClient.clientId, clientId), eq(userClient.state, "challenging"), ), - ) - .execute(); + ); }; export const setUserClientStateToActive = async (userId: number, clientId: number) => { @@ -96,8 +100,7 @@ export const setUserClientStateToActive = async (userId: number, clientId: numbe eq(userClient.clientId, clientId), eq(userClient.state, "pending"), ), - ) - .execute(); + ); }; export const registerUserClientChallenge = async ( @@ -107,16 +110,13 @@ export const registerUserClientChallenge = async ( allowedIp: string, expiresAt: Date, ) => { - await db - .insert(userClientChallenge) - .values({ - userId, - clientId, - answer, - allowedIp, - expiresAt, - }) - .execute(); + await db.insert(userClientChallenge).values({ + userId, + clientId, + answer, + allowedIp, + expiresAt, + }); }; export const getUserClientChallenge = async (answer: string, ip: string) => { @@ -131,21 +131,14 @@ export const getUserClientChallenge = async (answer: string, ip: string) => { eq(userClientChallenge.isUsed, false), ), ) - .execute(); + .limit(1); return challenges[0] ?? null; }; export const markUserClientChallengeAsUsed = async (id: number) => { - await db - .update(userClientChallenge) - .set({ isUsed: true }) - .where(eq(userClientChallenge.id, id)) - .execute(); + await db.update(userClientChallenge).set({ isUsed: true }).where(eq(userClientChallenge.id, id)); }; export const cleanupExpiredUserClientChallenges = async () => { - await db - .delete(userClientChallenge) - .where(lte(userClientChallenge.expiresAt, new Date())) - .execute(); + await db.delete(userClientChallenge).where(lte(userClientChallenge.expiresAt, new Date())); }; diff --git a/src/lib/server/db/error.ts b/src/lib/server/db/error.ts new file mode 100644 index 0000000..7644800 --- /dev/null +++ b/src/lib/server/db/error.ts @@ -0,0 +1,21 @@ +type IntegrityErrorMessages = + // Client + | "Public key(s) already registered" + | "User client already exists" + // File + | "Directory not found" + | "File not found" + | "Invalid DEK version" + // MEK + | "MEK already registered" + | "Inactive MEK version" + // Token + | "Refresh token not found" + | "Refresh token already registered"; + +export class IntegrityError extends Error { + constructor(public message: IntegrityErrorMessages) { + super(message); + this.name = "IntegrityError"; + } +} diff --git a/src/lib/server/db/file.ts b/src/lib/server/db/file.ts index 2fe4b53..270add3 100644 --- a/src/lib/server/db/file.ts +++ b/src/lib/server/db/file.ts @@ -1,5 +1,6 @@ import { and, eq, isNull } from "drizzle-orm"; import db from "./drizzle"; +import { IntegrityError } from "./error"; import { directory, file, mek } from "./schema"; type DirectoryId = "root" | number; @@ -27,40 +28,42 @@ export interface NewFileParams { encNameIv: string; } -export const registerNewDirectory = async (params: NewDirectoryParams) => { - return await db.transaction(async (tx) => { - const meks = await tx - .select() - .from(mek) - .where(and(eq(mek.userId, params.userId), eq(mek.state, "active"))); - if (meks[0]?.version !== params.mekVersion) { - throw new Error("Invalid MEK version"); - } +export const registerDirectory = async (params: NewDirectoryParams) => { + await db.transaction( + async (tx) => { + const meks = await tx + .select({ version: mek.version }) + .from(mek) + .where(and(eq(mek.userId, params.userId), eq(mek.state, "active"))) + .limit(1); + if (meks[0]?.version !== params.mekVersion) { + throw new IntegrityError("Inactive MEK version"); + } - const now = new Date(); - await tx.insert(directory).values({ - createdAt: now, - parentId: params.parentId === "root" ? null : params.parentId, - userId: params.userId, - mekVersion: params.mekVersion, - encDek: params.encDek, - dekVersion: params.dekVersion, - encName: { ciphertext: params.encName, iv: params.encNameIv }, - }); - }); + await tx.insert(directory).values({ + createdAt: new Date(), + parentId: params.parentId === "root" ? null : params.parentId, + userId: params.userId, + mekVersion: params.mekVersion, + encDek: params.encDek, + dekVersion: params.dekVersion, + encName: { ciphertext: params.encName, iv: params.encNameIv }, + }); + }, + { behavior: "exclusive" }, + ); }; -export const getAllDirectoriesByParent = async (userId: number, directoryId: DirectoryId) => { +export const getAllDirectoriesByParent = async (userId: number, parentId: DirectoryId) => { return await db .select() .from(directory) .where( and( eq(directory.userId, userId), - directoryId === "root" ? isNull(directory.parentId) : eq(directory.parentId, directoryId), + parentId === "root" ? isNull(directory.parentId) : eq(directory.parentId, parentId), ), - ) - .execute(); + ); }; export const getDirectory = async (userId: number, directoryId: number) => { @@ -68,7 +71,7 @@ export const getDirectory = async (userId: number, directoryId: number) => { .select() .from(directory) .where(and(eq(directory.userId, userId), eq(directory.id, directoryId))) - .execute(); + .limit(1); return res[0] ?? null; }; @@ -79,72 +82,87 @@ export const setDirectoryEncName = async ( encName: string, encNameIv: string, ) => { - const res = await db - .update(directory) - .set({ encName: { ciphertext: encName, iv: encNameIv } }) - .where( - and( - eq(directory.userId, userId), - eq(directory.id, directoryId), - eq(directory.dekVersion, dekVersion), - ), - ) - .execute(); - return res.changes > 0; + await db.transaction( + async (tx) => { + const directories = await tx + .select({ version: directory.dekVersion }) + .from(directory) + .where(and(eq(directory.userId, userId), eq(directory.id, directoryId))) + .limit(1); + if (!directories[0]) { + throw new IntegrityError("Directory not found"); + } else if (directories[0].version.getTime() !== dekVersion.getTime()) { + throw new IntegrityError("Invalid DEK version"); + } + + await tx + .update(directory) + .set({ encName: { ciphertext: encName, iv: encNameIv } }) + .where(and(eq(directory.userId, userId), eq(directory.id, directoryId))); + }, + { behavior: "exclusive" }, + ); }; export const unregisterDirectory = async (userId: number, directoryId: number) => { - return await db.transaction(async (tx) => { - const getFilePaths = async (parentId: number) => { - const files = await tx - .select({ path: file.path }) - .from(file) - .where(and(eq(file.userId, userId), eq(file.parentId, parentId))); - return files.map(({ path }) => path); - }; - const unregisterSubDirectoriesRecursively = async (directoryId: number): Promise => { - const subDirectories = await tx - .select({ id: directory.id }) - .from(directory) - .where(and(eq(directory.userId, userId), eq(directory.parentId, directoryId))); - const subDirectoryFilePaths = await Promise.all( - subDirectories.map(async ({ id }) => await unregisterSubDirectoriesRecursively(id)), - ); - const filePaths = await getFilePaths(directoryId); + return await db.transaction( + async (tx) => { + const unregisterFiles = async (parentId: number) => { + const files = await tx + .delete(file) + .where(and(eq(file.userId, userId), eq(file.parentId, parentId))) + .returning({ path: file.path }); + return files.map(({ path }) => path); + }; + const unregisterDirectoryRecursively = async (directoryId: number): Promise => { + const filePaths = await unregisterFiles(directoryId); + const subDirectories = await tx + .select({ id: directory.id }) + .from(directory) + .where(and(eq(directory.userId, userId), eq(directory.parentId, directoryId))); + const subDirectoryFilePaths = await Promise.all( + subDirectories.map(async ({ id }) => await unregisterDirectoryRecursively(id)), + ); - await tx.delete(file).where(eq(file.parentId, directoryId)); - await tx.delete(directory).where(eq(directory.id, directoryId)); - - return filePaths.concat(...subDirectoryFilePaths); - }; - return await unregisterSubDirectoriesRecursively(directoryId); - }); + const deleteRes = await tx.delete(directory).where(eq(directory.id, directoryId)); + if (deleteRes.changes === 0) { + throw new IntegrityError("Directory not found"); + } + return filePaths.concat(...subDirectoryFilePaths); + }; + return await unregisterDirectoryRecursively(directoryId); + }, + { behavior: "exclusive" }, + ); }; -export const registerNewFile = async (params: NewFileParams) => { - await db.transaction(async (tx) => { - const meks = await tx - .select() - .from(mek) - .where(and(eq(mek.userId, params.userId), eq(mek.state, "active"))); - if (meks[0]?.version !== params.mekVersion) { - throw new Error("Invalid MEK version"); - } +export const registerFile = async (params: NewFileParams) => { + await db.transaction( + async (tx) => { + const meks = await tx + .select({ version: mek.version }) + .from(mek) + .where(and(eq(mek.userId, params.userId), eq(mek.state, "active"))) + .limit(1); + if (meks[0]?.version !== params.mekVersion) { + throw new IntegrityError("Inactive MEK version"); + } - const now = new Date(); - await tx.insert(file).values({ - path: params.path, - parentId: params.parentId === "root" ? null : params.parentId, - createdAt: now, - userId: params.userId, - mekVersion: params.mekVersion, - contentType: params.contentType, - encDek: params.encDek, - dekVersion: params.dekVersion, - encContentIv: params.encContentIv, - encName: { ciphertext: params.encName, iv: params.encNameIv }, - }); - }); + await tx.insert(file).values({ + path: params.path, + parentId: params.parentId === "root" ? null : params.parentId, + createdAt: new Date(), + userId: params.userId, + mekVersion: params.mekVersion, + contentType: params.contentType, + encDek: params.encDek, + dekVersion: params.dekVersion, + encContentIv: params.encContentIv, + encName: { ciphertext: params.encName, iv: params.encNameIv }, + }); + }, + { behavior: "exclusive" }, + ); }; export const getAllFilesByParent = async (userId: number, parentId: DirectoryId) => { @@ -156,8 +174,7 @@ export const getAllFilesByParent = async (userId: number, parentId: DirectoryId) eq(file.userId, userId), parentId === "root" ? isNull(file.parentId) : eq(file.parentId, parentId), ), - ) - .execute(); + ); }; export const getFile = async (userId: number, fileId: number) => { @@ -165,7 +182,7 @@ export const getFile = async (userId: number, fileId: number) => { .select() .from(file) .where(and(eq(file.userId, userId), eq(file.id, fileId))) - .execute(); + .limit(1); return res[0] ?? null; }; @@ -176,19 +193,35 @@ export const setFileEncName = async ( encName: string, encNameIv: string, ) => { - const res = await db - .update(file) - .set({ encName: { ciphertext: encName, iv: encNameIv } }) - .where(and(eq(file.userId, userId), eq(file.id, fileId), eq(file.dekVersion, dekVersion))) - .execute(); - return res.changes > 0; + await db.transaction( + async (tx) => { + const files = await tx + .select({ version: file.dekVersion }) + .from(file) + .where(and(eq(file.userId, userId), eq(file.id, fileId))) + .limit(1); + if (!files[0]) { + throw new IntegrityError("File not found"); + } else if (files[0].version.getTime() !== dekVersion.getTime()) { + throw new IntegrityError("Invalid DEK version"); + } + + await tx + .update(file) + .set({ encName: { ciphertext: encName, iv: encNameIv } }) + .where(and(eq(file.userId, userId), eq(file.id, fileId))); + }, + { behavior: "exclusive" }, + ); }; export const unregisterFile = async (userId: number, fileId: number) => { - const res = await db + const files = await db .delete(file) .where(and(eq(file.userId, userId), eq(file.id, fileId))) - .returning({ path: file.path }) - .execute(); - return res[0]?.path ?? null; + .returning({ path: file.path }); + if (!files[0]) { + throw new IntegrityError("File not found"); + } + return files[0].path; }; diff --git a/src/lib/server/db/mek.ts b/src/lib/server/db/mek.ts index 7215ce0..237ef59 100644 --- a/src/lib/server/db/mek.ts +++ b/src/lib/server/db/mek.ts @@ -1,5 +1,7 @@ +import { SqliteError } from "better-sqlite3"; import { and, or, eq } from "drizzle-orm"; import db from "./drizzle"; +import { IntegrityError } from "./error"; import { mek, clientMek } from "./schema"; export const registerInitialMek = async ( @@ -8,22 +10,32 @@ export const registerInitialMek = async ( encMek: string, encMekSig: string, ) => { - await db.transaction(async (tx) => { - await tx.insert(mek).values({ - userId, - version: 1, - createdBy, - createdAt: new Date(), - state: "active", - }); - await tx.insert(clientMek).values({ - userId, - clientId: createdBy, - mekVersion: 1, - encMek, - encMekSig, - }); - }); + await db.transaction( + async (tx) => { + try { + await tx.insert(mek).values({ + userId, + version: 1, + createdBy, + createdAt: new Date(), + state: "active", + }); + await tx.insert(clientMek).values({ + userId, + clientId: createdBy, + mekVersion: 1, + encMek, + encMekSig, + }); + } catch (e) { + if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_PRIMARYKEY") { + throw new IntegrityError("MEK already registered"); + } + throw e; + } + }, + { behavior: "exclusive" }, + ); }; export const getInitialMek = async (userId: number) => { @@ -31,19 +43,10 @@ export const getInitialMek = async (userId: number) => { .select() .from(mek) .where(and(eq(mek.userId, userId), eq(mek.version, 1))) - .execute(); + .limit(1); return meks[0] ?? null; }; -export const getActiveMekVersion = async (userId: number) => { - const meks = await db - .select({ version: mek.version }) - .from(mek) - .where(and(eq(mek.userId, userId), eq(mek.state, "active"))) - .execute(); - return meks[0]?.version ?? null; -}; - export const getAllValidClientMeks = async (userId: number, clientId: number) => { return await db .select() @@ -55,6 +58,5 @@ export const getAllValidClientMeks = async (userId: number, clientId: number) => eq(clientMek.clientId, clientId), or(eq(mek.state, "active"), eq(mek.state, "retired")), ), - ) - .execute(); + ); }; diff --git a/src/lib/server/db/token.ts b/src/lib/server/db/token.ts index e26a8ef..25bf1de 100644 --- a/src/lib/server/db/token.ts +++ b/src/lib/server/db/token.ts @@ -2,6 +2,7 @@ import { SqliteError } from "better-sqlite3"; import { and, eq, gt, lte } from "drizzle-orm"; import env from "$lib/server/loadenv"; import db from "./drizzle"; +import { IntegrityError } from "./error"; import { refreshToken, tokenUpgradeChallenge } from "./schema"; const expiresAt = () => new Date(Date.now() + env.jwt.refreshExp); @@ -12,44 +13,45 @@ export const registerRefreshToken = async ( tokenId: string, ) => { try { - await db - .insert(refreshToken) - .values({ - id: tokenId, - userId, - clientId, - expiresAt: expiresAt(), - }) - .execute(); - return true; + await db.insert(refreshToken).values({ + id: tokenId, + userId, + clientId, + expiresAt: expiresAt(), + }); } catch (e) { if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") { - return false; + throw new IntegrityError("Refresh token already registered"); } throw e; } }; export const getRefreshToken = async (tokenId: string) => { - const tokens = await db.select().from(refreshToken).where(eq(refreshToken.id, tokenId)).execute(); + const tokens = await db.select().from(refreshToken).where(eq(refreshToken.id, tokenId)).limit(1); return tokens[0] ?? null; }; export const rotateRefreshToken = async (oldTokenId: string, newTokenId: string) => { - return await db.transaction(async (tx) => { - await tx - .delete(tokenUpgradeChallenge) - .where(eq(tokenUpgradeChallenge.refreshTokenId, oldTokenId)); - const res = await db - .update(refreshToken) - .set({ - id: newTokenId, - expiresAt: expiresAt(), - }) - .where(eq(refreshToken.id, oldTokenId)) - .execute(); - return res.changes > 0; - }); + await db.transaction( + async (tx) => { + await tx + .delete(tokenUpgradeChallenge) + .where(eq(tokenUpgradeChallenge.refreshTokenId, oldTokenId)); + + const res = await tx + .update(refreshToken) + .set({ + id: newTokenId, + expiresAt: expiresAt(), + }) + .where(eq(refreshToken.id, oldTokenId)); + if (res.changes === 0) { + throw new IntegrityError("Refresh token not found"); + } + }, + { behavior: "exclusive" }, + ); }; export const upgradeRefreshToken = async ( @@ -57,29 +59,34 @@ export const upgradeRefreshToken = async ( newTokenId: string, clientId: number, ) => { - return await db.transaction(async (tx) => { - await tx - .delete(tokenUpgradeChallenge) - .where(eq(tokenUpgradeChallenge.refreshTokenId, oldTokenId)); - const res = await tx - .update(refreshToken) - .set({ - id: newTokenId, - clientId, - expiresAt: expiresAt(), - }) - .where(eq(refreshToken.id, oldTokenId)) - .execute(); - return res.changes > 0; - }); + await db.transaction( + async (tx) => { + await tx + .delete(tokenUpgradeChallenge) + .where(eq(tokenUpgradeChallenge.refreshTokenId, oldTokenId)); + + const res = await tx + .update(refreshToken) + .set({ + id: newTokenId, + clientId, + expiresAt: expiresAt(), + }) + .where(eq(refreshToken.id, oldTokenId)); + if (res.changes === 0) { + throw new IntegrityError("Refresh token not found"); + } + }, + { behavior: "exclusive" }, + ); }; export const revokeRefreshToken = async (tokenId: string) => { - await db.delete(refreshToken).where(eq(refreshToken.id, tokenId)).execute(); + await db.delete(refreshToken).where(eq(refreshToken.id, tokenId)); }; export const cleanupExpiredRefreshTokens = async () => { - await db.delete(refreshToken).where(lte(refreshToken.expiresAt, new Date())).execute(); + await db.delete(refreshToken).where(lte(refreshToken.expiresAt, new Date())); }; export const registerTokenUpgradeChallenge = async ( @@ -89,16 +96,13 @@ export const registerTokenUpgradeChallenge = async ( allowedIp: string, expiresAt: Date, ) => { - await db - .insert(tokenUpgradeChallenge) - .values({ - refreshTokenId: tokenId, - clientId, - answer, - allowedIp, - expiresAt, - }) - .execute(); + await db.insert(tokenUpgradeChallenge).values({ + refreshTokenId: tokenId, + clientId, + answer, + allowedIp, + expiresAt, + }); }; export const getTokenUpgradeChallenge = async (answer: string, ip: string) => { @@ -113,7 +117,7 @@ export const getTokenUpgradeChallenge = async (answer: string, ip: string) => { eq(tokenUpgradeChallenge.isUsed, false), ), ) - .execute(); + .limit(1); return challenges[0] ?? null; }; @@ -121,13 +125,9 @@ export const markTokenUpgradeChallengeAsUsed = async (id: number) => { await db .update(tokenUpgradeChallenge) .set({ isUsed: true }) - .where(eq(tokenUpgradeChallenge.id, id)) - .execute(); + .where(eq(tokenUpgradeChallenge.id, id)); }; export const cleanupExpiredTokenUpgradeChallenges = async () => { - await db - .delete(tokenUpgradeChallenge) - .where(lte(tokenUpgradeChallenge.expiresAt, new Date())) - .execute(); + await db.delete(tokenUpgradeChallenge).where(lte(tokenUpgradeChallenge.expiresAt, new Date())); }; diff --git a/src/lib/server/db/user.ts b/src/lib/server/db/user.ts index 38a53f0..1efe43a 100644 --- a/src/lib/server/db/user.ts +++ b/src/lib/server/db/user.ts @@ -3,6 +3,6 @@ import db from "./drizzle"; import { user } from "./schema"; export const getUserByEmail = async (email: string) => { - const users = await db.select().from(user).where(eq(user.email, email)).execute(); + const users = await db.select().from(user).where(eq(user.email, email)).limit(1); return users[0] ?? null; }; diff --git a/src/lib/server/services/auth.ts b/src/lib/server/services/auth.ts index 53c2e51..36a3c5a 100644 --- a/src/lib/server/services/auth.ts +++ b/src/lib/server/services/auth.ts @@ -4,9 +4,10 @@ import { v4 as uuidv4 } from "uuid"; import { getClient, getClientByPubKeys, getUserClient } from "$lib/server/db/client"; import { getUserByEmail } from "$lib/server/db/user"; import env from "$lib/server/loadenv"; +import { IntegrityError } from "$lib/server/db/error"; import { - getRefreshToken, registerRefreshToken, + getRefreshToken, rotateRefreshToken, upgradeRefreshToken, revokeRefreshToken, @@ -29,10 +30,15 @@ const issueRefreshToken = async (userId: number, clientId?: number) => { const jti = uuidv4(); const token = issueToken({ type: "refresh", jti }); - if (!(await registerRefreshToken(userId, clientId ?? null, jti))) { - error(403, "Already logged in"); + try { + await registerRefreshToken(userId, clientId ?? null, jti); + return token; + } catch (e) { + if (e instanceof IntegrityError && e.message === "Refresh token already registered") { + error(409, "Already logged in"); + } + throw e; } - return token; }; export const login = async (email: string, password: string) => { @@ -57,7 +63,7 @@ const verifyRefreshToken = async (refreshToken: string) => { const tokenData = await getRefreshToken(tokenPayload.jti); if (!tokenData) { - error(500, "Refresh token not found"); + error(500, "Invalid refresh token"); } return { @@ -76,13 +82,18 @@ export const refreshToken = async (refreshToken: string) => { const { jti: oldJti, userId, clientId } = await verifyRefreshToken(refreshToken); const newJti = uuidv4(); - if (!(await rotateRefreshToken(oldJti, newJti))) { - error(500, "Refresh token not found"); + try { + await rotateRefreshToken(oldJti, newJti); + return { + accessToken: issueAccessToken(userId, clientId), + refreshToken: issueToken({ type: "refresh", jti: newJti }), + }; + } catch (e) { + if (e instanceof IntegrityError && e.message === "Refresh token not found") { + error(500, "Invalid refresh token"); + } + throw e; } - return { - accessToken: issueAccessToken(userId, clientId), - refreshToken: issueToken({ type: "refresh", jti: newJti }), - }; }; const expiresAt = () => new Date(Date.now() + env.challenge.tokenUpgradeExp); @@ -120,7 +131,7 @@ export const createTokenUpgradeChallenge = async ( if (!client) { error(401, "Invalid public key(s)"); } else if (!userClient || userClient.state === "challenging") { - error(401, "Unregistered client"); + error(403, "Unregistered client"); } return { challenge: await createChallenge(ip, jti, client.id, encPubKey) }; @@ -139,26 +150,31 @@ export const upgradeToken = async ( const challenge = await getTokenUpgradeChallenge(answer, ip); if (!challenge) { - error(401, "Invalid challenge answer"); + error(403, "Invalid challenge answer"); } else if (challenge.refreshTokenId !== oldJti) { error(403, "Forbidden"); } + await markTokenUpgradeChallengeAsUsed(challenge.id); + const client = await getClient(challenge.clientId); if (!client) { error(500, "Invalid challenge answer"); } else if (!verifySignature(Buffer.from(answer, "base64"), answerSig, client.sigPubKey)) { - error(401, "Invalid challenge answer signature"); + error(403, "Invalid challenge answer signature"); } - await markTokenUpgradeChallengeAsUsed(challenge.id); - - const newJti = uuidv4(); - if (!(await upgradeRefreshToken(oldJti, newJti, client.id))) { - error(500, "Refresh token not found"); + try { + const newJti = uuidv4(); + await upgradeRefreshToken(oldJti, newJti, client.id); + return { + accessToken: issueAccessToken(userId, client.id), + refreshToken: issueToken({ type: "refresh", jti: newJti }), + }; + } catch (e) { + if (e instanceof IntegrityError && e.message === "Refresh token not found") { + error(500, "Invalid refresh token"); + } + throw e; } - return { - accessToken: issueAccessToken(userId, client.id), - refreshToken: issueToken({ type: "refresh", jti: newJti }), - }; }; diff --git a/src/lib/server/services/client.ts b/src/lib/server/services/client.ts index 1f99d3a..73973bb 100644 --- a/src/lib/server/services/client.ts +++ b/src/lib/server/services/client.ts @@ -3,7 +3,6 @@ import { createClient, getClient, getClientByPubKeys, - countClientByPubKey, createUserClient, getAllUserClients, getUserClient, @@ -12,6 +11,7 @@ import { getUserClientChallenge, markUserClientChallengeAsUsed, } from "$lib/server/db/client"; +import { IntegrityError } from "$lib/server/db/error"; import { verifyPubKey, verifySignature, generateChallenge } from "$lib/server/modules/crypto"; import { isInitialMekNeeded } from "$lib/server/modules/mek"; import env from "$lib/server/loadenv"; @@ -29,8 +29,8 @@ export const getUserClientList = async (userId: number) => { const expiresAt = () => new Date(Date.now() + env.challenge.userClientExp); const createUserClientChallenge = async ( - userId: number, ip: string, + userId: number, clientId: number, encPubKey: string, ) => { @@ -45,33 +45,59 @@ export const registerUserClient = async ( encPubKey: string, sigPubKey: string, ) => { - let clientId; - const client = await getClientByPubKeys(encPubKey, sigPubKey); if (client) { - const userClient = await getUserClient(userId, client.id); - if (userClient) { - error(409, "Client already registered"); + try { + await createUserClient(userId, client.id); + return { challenge: await createUserClientChallenge(ip, userId, client.id, encPubKey) }; + } catch (e) { + if (e instanceof IntegrityError && e.message === "User client already exists") { + error(409, "Client already registered"); + } + throw e; } - - await createUserClient(userId, client.id); - clientId = client.id; } else { - if (!verifyPubKey(encPubKey) || !verifyPubKey(sigPubKey)) { + if (encPubKey === sigPubKey) { + error(400, "Same public keys"); + } else if (!verifyPubKey(encPubKey) || !verifyPubKey(sigPubKey)) { error(400, "Invalid public key(s)"); - } else if (encPubKey === sigPubKey) { - error(400, "Public keys must be different"); - } else if ( - (await countClientByPubKey(encPubKey)) > 0 || - (await countClientByPubKey(sigPubKey)) > 0 - ) { - error(409, "Public key(s) already registered"); } - clientId = await createClient(encPubKey, sigPubKey, userId); + try { + const clientId = await createClient(encPubKey, sigPubKey, userId); + return { challenge: await createUserClientChallenge(ip, userId, clientId, encPubKey) }; + } catch (e) { + if (e instanceof IntegrityError && e.message === "Public key(s) already registered") { + error(409, "Public key(s) already used"); + } + throw e; + } + } +}; + +export const verifyUserClient = async ( + userId: number, + ip: string, + answer: string, + answerSig: string, +) => { + const challenge = await getUserClientChallenge(answer, ip); + if (!challenge) { + error(403, "Invalid challenge answer"); + } else if (challenge.userId !== userId) { + error(403, "Forbidden"); } - return { challenge: await createUserClientChallenge(userId, ip, clientId, encPubKey) }; + await markUserClientChallengeAsUsed(challenge.id); + + const client = await getClient(challenge.clientId); + if (!client) { + error(500, "Invalid challenge answer"); + } else if (!verifySignature(Buffer.from(answer, "base64"), answerSig, client.sigPubKey)) { + error(403, "Invalid challenge answer signature"); + } + + await setUserClientStateToPending(userId, challenge.clientId); }; export const getUserClientStatus = async (userId: number, clientId: number) => { @@ -85,27 +111,3 @@ export const getUserClientStatus = async (userId: number, clientId: number) => { isInitialMekNeeded: await isInitialMekNeeded(userId), }; }; - -export const verifyUserClient = async ( - userId: number, - ip: string, - answer: string, - answerSig: string, -) => { - const challenge = await getUserClientChallenge(answer, ip); - if (!challenge) { - error(401, "Invalid challenge answer"); - } else if (challenge.userId !== userId) { - error(403, "Forbidden"); - } - - const client = await getClient(challenge.clientId); - if (!client) { - error(500, "Invalid challenge answer"); - } else if (!verifySignature(Buffer.from(answer, "base64"), answerSig, client.sigPubKey)) { - error(401, "Invalid challenge answer signature"); - } - - await markUserClientChallengeAsUsed(challenge.id); - await setUserClientStateToPending(userId, challenge.clientId); -}; diff --git a/src/lib/server/services/directory.ts b/src/lib/server/services/directory.ts index 01d39d5..5dc408f 100644 --- a/src/lib/server/services/directory.ts +++ b/src/lib/server/services/directory.ts @@ -1,44 +1,15 @@ import { error } from "@sveltejs/kit"; import { unlink } from "fs/promises"; +import { IntegrityError } from "$lib/server/db/error"; import { + registerDirectory, getAllDirectoriesByParent, - registerNewDirectory, getDirectory, setDirectoryEncName, unregisterDirectory, getAllFilesByParent, type NewDirectoryParams, } from "$lib/server/db/file"; -import { getActiveMekVersion } from "$lib/server/db/mek"; - -export const deleteDirectory = async (userId: number, directoryId: number) => { - const directory = await getDirectory(userId, directoryId); - if (!directory) { - error(404, "Invalid directory id"); - } - - const filePaths = await unregisterDirectory(userId, directoryId); - filePaths.map((path) => unlink(path)); // Intended -}; - -export const renameDirectory = async ( - userId: number, - directoryId: number, - dekVersion: Date, - newEncName: string, - newEncNameIv: string, -) => { - const directory = await getDirectory(userId, directoryId); - if (!directory) { - error(404, "Invalid directory id"); - } else if (directory.dekVersion.getTime() !== dekVersion.getTime()) { - error(400, "Invalid DEK version"); - } - - if (!(await setDirectoryEncName(userId, directoryId, dekVersion, newEncName, newEncNameIv))) { - error(500, "Invalid directory id or DEK version"); - } -}; export const getDirectoryInformation = async (userId: number, directoryId: "root" | number) => { const directory = directoryId !== "root" ? await getDirectory(userId, directoryId) : undefined; @@ -62,19 +33,52 @@ export const getDirectoryInformation = async (userId: number, directoryId: "root }; }; -export const createDirectory = async (params: NewDirectoryParams) => { - const activeMekVersion = await getActiveMekVersion(params.userId); - if (activeMekVersion === null) { - error(500, "Invalid MEK version"); - } else if (activeMekVersion !== params.mekVersion) { - error(400, "Invalid MEK version"); +export const deleteDirectory = async (userId: number, directoryId: number) => { + try { + const filePaths = await unregisterDirectory(userId, directoryId); + filePaths.map((path) => unlink(path)); // Intended + } catch (e) { + if (e instanceof IntegrityError && e.message === "Directory not found") { + error(404, "Invalid directory id"); + } + throw e; } +}; +export const renameDirectory = async ( + userId: number, + directoryId: number, + dekVersion: Date, + newEncName: string, + newEncNameIv: string, +) => { + try { + await setDirectoryEncName(userId, directoryId, dekVersion, newEncName, newEncNameIv); + } catch (e) { + if (e instanceof IntegrityError) { + if (e.message === "Directory not found") { + error(404, "Invalid directory id"); + } else if (e.message === "Invalid DEK version") { + error(400, "Invalid DEK version"); + } + } + throw e; + } +}; + +export const createDirectory = async (params: NewDirectoryParams) => { const oneMinuteAgo = new Date(Date.now() - 60 * 1000); const oneMinuteLater = new Date(Date.now() + 60 * 1000); if (params.dekVersion <= oneMinuteAgo || params.dekVersion >= oneMinuteLater) { error(400, "Invalid DEK version"); } - await registerNewDirectory(params); + try { + await registerDirectory(params); + } catch (e) { + if (e instanceof IntegrityError && e.message === "Inactive MEK version") { + error(400, "Invalid MEK version"); + } + throw e; + } }; diff --git a/src/lib/server/services/file.ts b/src/lib/server/services/file.ts index 7bf9b72..c87414c 100644 --- a/src/lib/server/services/file.ts +++ b/src/lib/server/services/file.ts @@ -1,77 +1,19 @@ import { error } from "@sveltejs/kit"; -import { createReadStream, createWriteStream, ReadStream, WriteStream } from "fs"; +import { createReadStream, createWriteStream } from "fs"; import { mkdir, stat, unlink } from "fs/promises"; import { dirname } from "path"; +import { Readable, Writable } from "stream"; import { v4 as uuidv4 } from "uuid"; +import { IntegrityError } from "$lib/server/db/error"; import { - registerNewFile, + registerFile, getFile, setFileEncName, unregisterFile, type NewFileParams, } from "$lib/server/db/file"; -import { getActiveMekVersion } from "$lib/server/db/mek"; import env from "$lib/server/loadenv"; -export const deleteFile = async (userId: number, fileId: number) => { - const file = await getFile(userId, fileId); - if (!file) { - error(404, "Invalid file id"); - } - - const path = await unregisterFile(userId, fileId); - if (!path) { - error(500, "Invalid file id"); - } - - unlink(path); // Intended -}; - -const convertToReadableStream = (readStream: ReadStream) => { - return new ReadableStream({ - start: (controller) => { - readStream.on("data", (chunk) => controller.enqueue(new Uint8Array(chunk as Buffer))); - readStream.on("end", () => controller.close()); - readStream.on("error", (e) => controller.error(e)); - }, - cancel: () => { - readStream.destroy(); - }, - }); -}; - -export const getFileStream = async (userId: number, fileId: number) => { - const file = await getFile(userId, fileId); - if (!file) { - error(404, "Invalid file id"); - } - - const { size } = await stat(file.path); - return { - encContentStream: convertToReadableStream(createReadStream(file.path)), - encContentSize: size, - }; -}; - -export const renameFile = async ( - userId: number, - fileId: number, - dekVersion: Date, - newEncName: string, - newEncNameIv: string, -) => { - const file = await getFile(userId, fileId); - if (!file) { - error(404, "Invalid file id"); - } else if (file.dekVersion.getTime() !== dekVersion.getTime()) { - error(400, "Invalid DEK version"); - } - - if (!(await setFileEncName(userId, fileId, dekVersion, newEncName, newEncNameIv))) { - error(500, "Invalid file id or DEK version"); - } -}; - export const getFileInformation = async (userId: number, fileId: number) => { const file = await getFile(userId, fileId); if (!file) { @@ -89,20 +31,50 @@ export const getFileInformation = async (userId: number, fileId: number) => { }; }; -const convertToWritableStream = (writeStream: WriteStream) => { - return new WritableStream({ - write: (chunk) => - new Promise((resolve, reject) => { - writeStream.write(chunk, (e) => { - if (e) { - reject(e); - } else { - resolve(); - } - }); - }), - close: () => new Promise((resolve) => writeStream.end(resolve)), - }); +export const deleteFile = async (userId: number, fileId: number) => { + try { + const filePath = await unregisterFile(userId, fileId); + unlink(filePath); // Intended + } catch (e) { + if (e instanceof IntegrityError && e.message === "File not found") { + error(404, "Invalid file id"); + } + throw e; + } +}; + +export const getFileStream = async (userId: number, fileId: number) => { + const file = await getFile(userId, fileId); + if (!file) { + error(404, "Invalid file id"); + } + + const { size } = await stat(file.path); + return { + encContentStream: Readable.toWeb(createReadStream(file.path)), + encContentSize: size, + }; +}; + +export const renameFile = async ( + userId: number, + fileId: number, + dekVersion: Date, + newEncName: string, + newEncNameIv: string, +) => { + try { + await setFileEncName(userId, fileId, dekVersion, newEncName, newEncNameIv); + } catch (e) { + if (e instanceof IntegrityError) { + if (e.message === "File not found") { + error(404, "Invalid file id"); + } else if (e.message === "Invalid DEK version") { + error(400, "Invalid DEK version"); + } + } + throw e; + } }; const safeUnlink = async (path: string) => { @@ -113,13 +85,6 @@ export const uploadFile = async ( params: Omit, encContentStream: ReadableStream, ) => { - const activeMekVersion = await getActiveMekVersion(params.userId); - if (activeMekVersion === null) { - error(500, "Invalid MEK version"); - } else if (activeMekVersion !== params.mekVersion) { - error(400, "Invalid MEK version"); - } - const oneMinuteAgo = new Date(Date.now() - 60 * 1000); const oneMinuteLater = new Date(Date.now() + 60 * 1000); if (params.dekVersion <= oneMinuteAgo || params.dekVersion >= oneMinuteLater) { @@ -131,14 +96,20 @@ export const uploadFile = async ( try { await encContentStream.pipeTo( - convertToWritableStream(createWriteStream(path, { flags: "wx", mode: 0o600 })), + Writable.toWeb(createWriteStream(path, { flags: "wx", mode: 0o600 })), ); - await registerNewFile({ + await registerFile({ ...params, path, }); } catch (e) { await safeUnlink(path); + + if (e instanceof IntegrityError) { + if (e.message === "Inactive MEK version") { + error(400, "Invalid MEK version"); + } + } throw e; } }; diff --git a/src/lib/server/services/mek.ts b/src/lib/server/services/mek.ts index 95caef9..e0deeb0 100644 --- a/src/lib/server/services/mek.ts +++ b/src/lib/server/services/mek.ts @@ -1,7 +1,8 @@ import { error } from "@sveltejs/kit"; import { setUserClientStateToActive } from "$lib/server/db/client"; +import { IntegrityError } from "$lib/server/db/error"; import { registerInitialMek, getAllValidClientMeks } from "$lib/server/db/mek"; -import { isInitialMekNeeded, verifyClientEncMekSig } from "$lib/server/modules/mek"; +import { verifyClientEncMekSig } from "$lib/server/modules/mek"; export const getClientMekList = async (userId: number, clientId: number) => { const clientMeks = await getAllValidClientMeks(userId, clientId); @@ -21,12 +22,17 @@ export const registerInitialActiveMek = async ( encMek: string, encMekSig: string, ) => { - if (!(await isInitialMekNeeded(userId))) { - error(409, "Initial MEK already registered"); - } else if (!(await verifyClientEncMekSig(userId, createdBy, 1, encMek, encMekSig))) { + if (!(await verifyClientEncMekSig(userId, createdBy, 1, encMek, encMekSig))) { error(400, "Invalid signature"); } - await registerInitialMek(userId, createdBy, encMek, encMekSig); - await setUserClientStateToActive(userId, createdBy); + try { + await registerInitialMek(userId, createdBy, encMek, encMekSig); + await setUserClientStateToActive(userId, createdBy); + } catch (e) { + if (e instanceof IntegrityError && e.message === "MEK already registered") { + error(409, "Initial MEK already registered"); + } + throw e; + } }; diff --git a/src/routes/api/file/[id]/download/+server.ts b/src/routes/api/file/[id]/download/+server.ts index 42b832f..58f915d 100644 --- a/src/routes/api/file/[id]/download/+server.ts +++ b/src/routes/api/file/[id]/download/+server.ts @@ -16,7 +16,7 @@ export const GET: RequestHandler = async ({ cookies, params }) => { const { id } = zodRes.data; const { encContentStream, encContentSize } = await getFileStream(userId, id); - return new Response(encContentStream, { + return new Response(encContentStream as ReadableStream, { headers: { "Content-Type": "application/octet-stream", "Content-Length": encContentSize.toString(),