DB에 동시적으로 접근하더라도 데이터 무결성이 깨지지 않도록 DB 접근 코드 수정

This commit is contained in:
static
2025-01-11 03:55:19 +09:00
parent 045eb69487
commit 0bdf990dae
12 changed files with 486 additions and 438 deletions

View File

@@ -1,30 +1,36 @@
import { and, or, eq, gt, lte, count } from "drizzle-orm"; import { SqliteError } from "better-sqlite3";
import { and, or, eq, gt, lte } from "drizzle-orm";
import db from "./drizzle"; import db from "./drizzle";
import { IntegrityError } from "./error";
import { client, userClient, userClientChallenge } from "./schema"; import { client, userClient, userClientChallenge } from "./schema";
export const createClient = async (encPubKey: string, sigPubKey: string, userId: number) => { export const createClient = async (encPubKey: string, sigPubKey: string, userId: number) => {
return await db.transaction(async (tx) => { return await db.transaction(
const clients = await tx async (tx) => {
.select() const clients = await tx
.from(client) .select({ id: client.id })
.where(or(eq(client.encPubKey, sigPubKey), eq(client.sigPubKey, encPubKey))); .from(client)
if (clients.length > 0) { .where(or(eq(client.encPubKey, sigPubKey), eq(client.sigPubKey, encPubKey)))
throw new Error("Already used public key(s)"); .limit(1);
} if (clients.length !== 0) {
throw new IntegrityError("Public key(s) already registered");
}
const insertRes = await tx const newClients = await tx
.insert(client) .insert(client)
.values({ encPubKey, sigPubKey }) .values({ encPubKey, sigPubKey })
.returning({ id: client.id }); .returning({ id: client.id });
const { id: clientId } = insertRes[0]!; const { id: clientId } = newClients[0]!;
await tx.insert(userClient).values({ userId, clientId }); await tx.insert(userClient).values({ userId, clientId });
return clientId; return clientId;
}); },
{ behavior: "exclusive" },
);
}; };
export const getClient = async (clientId: number) => { export const getClient = async (clientId: number) => {
const clients = await db.select().from(client).where(eq(client.id, clientId)).execute(); const clients = await db.select().from(client).where(eq(client.id, clientId)).limit(1);
return clients[0] ?? null; return clients[0] ?? null;
}; };
@@ -33,24 +39,23 @@ export const getClientByPubKeys = async (encPubKey: string, sigPubKey: string) =
.select() .select()
.from(client) .from(client)
.where(and(eq(client.encPubKey, encPubKey), eq(client.sigPubKey, sigPubKey))) .where(and(eq(client.encPubKey, encPubKey), eq(client.sigPubKey, sigPubKey)))
.execute(); .limit(1);
return clients[0] ?? null; return clients[0] ?? null;
}; };
export const countClientByPubKey = async (pubKey: string) => {
const clients = await db
.select({ count: count() })
.from(client)
.where(or(eq(client.encPubKey, pubKey), eq(client.encPubKey, pubKey)));
return clients[0]?.count ?? 0;
};
export const createUserClient = async (userId: number, clientId: number) => { export const createUserClient = async (userId: number, clientId: number) => {
await db.insert(userClient).values({ userId, clientId }).execute(); try {
await db.insert(userClient).values({ userId, clientId });
} catch (e) {
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_PRIMARYKEY") {
throw new IntegrityError("User client already exists");
}
throw e;
}
}; };
export const getAllUserClients = async (userId: number) => { export const getAllUserClients = async (userId: number) => {
return await db.select().from(userClient).where(eq(userClient.userId, userId)).execute(); return await db.select().from(userClient).where(eq(userClient.userId, userId));
}; };
export const getUserClient = async (userId: number, clientId: number) => { export const getUserClient = async (userId: number, clientId: number) => {
@@ -58,7 +63,7 @@ export const getUserClient = async (userId: number, clientId: number) => {
.select() .select()
.from(userClient) .from(userClient)
.where(and(eq(userClient.userId, userId), eq(userClient.clientId, clientId))) .where(and(eq(userClient.userId, userId), eq(userClient.clientId, clientId)))
.execute(); .limit(1);
return userClients[0] ?? null; return userClients[0] ?? null;
}; };
@@ -68,7 +73,7 @@ export const getUserClientWithDetails = async (userId: number, clientId: number)
.from(userClient) .from(userClient)
.innerJoin(client, eq(userClient.clientId, client.id)) .innerJoin(client, eq(userClient.clientId, client.id))
.where(and(eq(userClient.userId, userId), eq(userClient.clientId, clientId))) .where(and(eq(userClient.userId, userId), eq(userClient.clientId, clientId)))
.execute(); .limit(1);
return userClients[0] ?? null; return userClients[0] ?? null;
}; };
@@ -82,8 +87,7 @@ export const setUserClientStateToPending = async (userId: number, clientId: numb
eq(userClient.clientId, clientId), eq(userClient.clientId, clientId),
eq(userClient.state, "challenging"), eq(userClient.state, "challenging"),
), ),
) );
.execute();
}; };
export const setUserClientStateToActive = async (userId: number, clientId: number) => { export const setUserClientStateToActive = async (userId: number, clientId: number) => {
@@ -96,8 +100,7 @@ export const setUserClientStateToActive = async (userId: number, clientId: numbe
eq(userClient.clientId, clientId), eq(userClient.clientId, clientId),
eq(userClient.state, "pending"), eq(userClient.state, "pending"),
), ),
) );
.execute();
}; };
export const registerUserClientChallenge = async ( export const registerUserClientChallenge = async (
@@ -107,16 +110,13 @@ export const registerUserClientChallenge = async (
allowedIp: string, allowedIp: string,
expiresAt: Date, expiresAt: Date,
) => { ) => {
await db await db.insert(userClientChallenge).values({
.insert(userClientChallenge) userId,
.values({ clientId,
userId, answer,
clientId, allowedIp,
answer, expiresAt,
allowedIp, });
expiresAt,
})
.execute();
}; };
export const getUserClientChallenge = async (answer: string, ip: string) => { export const getUserClientChallenge = async (answer: string, ip: string) => {
@@ -131,21 +131,14 @@ export const getUserClientChallenge = async (answer: string, ip: string) => {
eq(userClientChallenge.isUsed, false), eq(userClientChallenge.isUsed, false),
), ),
) )
.execute(); .limit(1);
return challenges[0] ?? null; return challenges[0] ?? null;
}; };
export const markUserClientChallengeAsUsed = async (id: number) => { export const markUserClientChallengeAsUsed = async (id: number) => {
await db await db.update(userClientChallenge).set({ isUsed: true }).where(eq(userClientChallenge.id, id));
.update(userClientChallenge)
.set({ isUsed: true })
.where(eq(userClientChallenge.id, id))
.execute();
}; };
export const cleanupExpiredUserClientChallenges = async () => { export const cleanupExpiredUserClientChallenges = async () => {
await db await db.delete(userClientChallenge).where(lte(userClientChallenge.expiresAt, new Date()));
.delete(userClientChallenge)
.where(lte(userClientChallenge.expiresAt, new Date()))
.execute();
}; };

View File

@@ -0,0 +1,21 @@
type IntegrityErrorMessages =
// Client
| "Public key(s) already registered"
| "User client already exists"
// File
| "Directory not found"
| "File not found"
| "Invalid DEK version"
// MEK
| "MEK already registered"
| "Inactive MEK version"
// Token
| "Refresh token not found"
| "Refresh token already registered";
export class IntegrityError extends Error {
constructor(public message: IntegrityErrorMessages) {
super(message);
this.name = "IntegrityError";
}
}

View File

@@ -1,5 +1,6 @@
import { and, eq, isNull } from "drizzle-orm"; import { and, eq, isNull } from "drizzle-orm";
import db from "./drizzle"; import db from "./drizzle";
import { IntegrityError } from "./error";
import { directory, file, mek } from "./schema"; import { directory, file, mek } from "./schema";
type DirectoryId = "root" | number; type DirectoryId = "root" | number;
@@ -27,40 +28,42 @@ export interface NewFileParams {
encNameIv: string; encNameIv: string;
} }
export const registerNewDirectory = async (params: NewDirectoryParams) => { export const registerDirectory = async (params: NewDirectoryParams) => {
return await db.transaction(async (tx) => { await db.transaction(
const meks = await tx async (tx) => {
.select() const meks = await tx
.from(mek) .select({ version: mek.version })
.where(and(eq(mek.userId, params.userId), eq(mek.state, "active"))); .from(mek)
if (meks[0]?.version !== params.mekVersion) { .where(and(eq(mek.userId, params.userId), eq(mek.state, "active")))
throw new Error("Invalid MEK version"); .limit(1);
} if (meks[0]?.version !== params.mekVersion) {
throw new IntegrityError("Inactive MEK version");
}
const now = new Date(); await tx.insert(directory).values({
await tx.insert(directory).values({ createdAt: new Date(),
createdAt: now, parentId: params.parentId === "root" ? null : params.parentId,
parentId: params.parentId === "root" ? null : params.parentId, userId: params.userId,
userId: params.userId, mekVersion: params.mekVersion,
mekVersion: params.mekVersion, encDek: params.encDek,
encDek: params.encDek, dekVersion: params.dekVersion,
dekVersion: params.dekVersion, encName: { ciphertext: params.encName, iv: params.encNameIv },
encName: { ciphertext: params.encName, iv: params.encNameIv }, });
}); },
}); { behavior: "exclusive" },
);
}; };
export const getAllDirectoriesByParent = async (userId: number, directoryId: DirectoryId) => { export const getAllDirectoriesByParent = async (userId: number, parentId: DirectoryId) => {
return await db return await db
.select() .select()
.from(directory) .from(directory)
.where( .where(
and( and(
eq(directory.userId, userId), eq(directory.userId, userId),
directoryId === "root" ? isNull(directory.parentId) : eq(directory.parentId, directoryId), parentId === "root" ? isNull(directory.parentId) : eq(directory.parentId, parentId),
), ),
) );
.execute();
}; };
export const getDirectory = async (userId: number, directoryId: number) => { export const getDirectory = async (userId: number, directoryId: number) => {
@@ -68,7 +71,7 @@ export const getDirectory = async (userId: number, directoryId: number) => {
.select() .select()
.from(directory) .from(directory)
.where(and(eq(directory.userId, userId), eq(directory.id, directoryId))) .where(and(eq(directory.userId, userId), eq(directory.id, directoryId)))
.execute(); .limit(1);
return res[0] ?? null; return res[0] ?? null;
}; };
@@ -79,72 +82,87 @@ export const setDirectoryEncName = async (
encName: string, encName: string,
encNameIv: string, encNameIv: string,
) => { ) => {
const res = await db await db.transaction(
.update(directory) async (tx) => {
.set({ encName: { ciphertext: encName, iv: encNameIv } }) const directories = await tx
.where( .select({ version: directory.dekVersion })
and( .from(directory)
eq(directory.userId, userId), .where(and(eq(directory.userId, userId), eq(directory.id, directoryId)))
eq(directory.id, directoryId), .limit(1);
eq(directory.dekVersion, dekVersion), if (!directories[0]) {
), throw new IntegrityError("Directory not found");
) } else if (directories[0].version.getTime() !== dekVersion.getTime()) {
.execute(); throw new IntegrityError("Invalid DEK version");
return res.changes > 0; }
await tx
.update(directory)
.set({ encName: { ciphertext: encName, iv: encNameIv } })
.where(and(eq(directory.userId, userId), eq(directory.id, directoryId)));
},
{ behavior: "exclusive" },
);
}; };
export const unregisterDirectory = async (userId: number, directoryId: number) => { export const unregisterDirectory = async (userId: number, directoryId: number) => {
return await db.transaction(async (tx) => { return await db.transaction(
const getFilePaths = async (parentId: number) => { async (tx) => {
const files = await tx const unregisterFiles = async (parentId: number) => {
.select({ path: file.path }) const files = await tx
.from(file) .delete(file)
.where(and(eq(file.userId, userId), eq(file.parentId, parentId))); .where(and(eq(file.userId, userId), eq(file.parentId, parentId)))
return files.map(({ path }) => path); .returning({ path: file.path });
}; return files.map(({ path }) => path);
const unregisterSubDirectoriesRecursively = async (directoryId: number): Promise<string[]> => { };
const subDirectories = await tx const unregisterDirectoryRecursively = async (directoryId: number): Promise<string[]> => {
.select({ id: directory.id }) const filePaths = await unregisterFiles(directoryId);
.from(directory) const subDirectories = await tx
.where(and(eq(directory.userId, userId), eq(directory.parentId, directoryId))); .select({ id: directory.id })
const subDirectoryFilePaths = await Promise.all( .from(directory)
subDirectories.map(async ({ id }) => await unregisterSubDirectoriesRecursively(id)), .where(and(eq(directory.userId, userId), eq(directory.parentId, directoryId)));
); const subDirectoryFilePaths = await Promise.all(
const filePaths = await getFilePaths(directoryId); subDirectories.map(async ({ id }) => await unregisterDirectoryRecursively(id)),
);
await tx.delete(file).where(eq(file.parentId, directoryId)); const deleteRes = await tx.delete(directory).where(eq(directory.id, directoryId));
await tx.delete(directory).where(eq(directory.id, directoryId)); if (deleteRes.changes === 0) {
throw new IntegrityError("Directory not found");
return filePaths.concat(...subDirectoryFilePaths); }
}; return filePaths.concat(...subDirectoryFilePaths);
return await unregisterSubDirectoriesRecursively(directoryId); };
}); return await unregisterDirectoryRecursively(directoryId);
},
{ behavior: "exclusive" },
);
}; };
export const registerNewFile = async (params: NewFileParams) => { export const registerFile = async (params: NewFileParams) => {
await db.transaction(async (tx) => { await db.transaction(
const meks = await tx async (tx) => {
.select() const meks = await tx
.from(mek) .select({ version: mek.version })
.where(and(eq(mek.userId, params.userId), eq(mek.state, "active"))); .from(mek)
if (meks[0]?.version !== params.mekVersion) { .where(and(eq(mek.userId, params.userId), eq(mek.state, "active")))
throw new Error("Invalid MEK version"); .limit(1);
} if (meks[0]?.version !== params.mekVersion) {
throw new IntegrityError("Inactive MEK version");
}
const now = new Date(); await tx.insert(file).values({
await tx.insert(file).values({ path: params.path,
path: params.path, parentId: params.parentId === "root" ? null : params.parentId,
parentId: params.parentId === "root" ? null : params.parentId, createdAt: new Date(),
createdAt: now, userId: params.userId,
userId: params.userId, mekVersion: params.mekVersion,
mekVersion: params.mekVersion, contentType: params.contentType,
contentType: params.contentType, encDek: params.encDek,
encDek: params.encDek, dekVersion: params.dekVersion,
dekVersion: params.dekVersion, encContentIv: params.encContentIv,
encContentIv: params.encContentIv, encName: { ciphertext: params.encName, iv: params.encNameIv },
encName: { ciphertext: params.encName, iv: params.encNameIv }, });
}); },
}); { behavior: "exclusive" },
);
}; };
export const getAllFilesByParent = async (userId: number, parentId: DirectoryId) => { export const getAllFilesByParent = async (userId: number, parentId: DirectoryId) => {
@@ -156,8 +174,7 @@ export const getAllFilesByParent = async (userId: number, parentId: DirectoryId)
eq(file.userId, userId), eq(file.userId, userId),
parentId === "root" ? isNull(file.parentId) : eq(file.parentId, parentId), parentId === "root" ? isNull(file.parentId) : eq(file.parentId, parentId),
), ),
) );
.execute();
}; };
export const getFile = async (userId: number, fileId: number) => { export const getFile = async (userId: number, fileId: number) => {
@@ -165,7 +182,7 @@ export const getFile = async (userId: number, fileId: number) => {
.select() .select()
.from(file) .from(file)
.where(and(eq(file.userId, userId), eq(file.id, fileId))) .where(and(eq(file.userId, userId), eq(file.id, fileId)))
.execute(); .limit(1);
return res[0] ?? null; return res[0] ?? null;
}; };
@@ -176,19 +193,35 @@ export const setFileEncName = async (
encName: string, encName: string,
encNameIv: string, encNameIv: string,
) => { ) => {
const res = await db await db.transaction(
.update(file) async (tx) => {
.set({ encName: { ciphertext: encName, iv: encNameIv } }) const files = await tx
.where(and(eq(file.userId, userId), eq(file.id, fileId), eq(file.dekVersion, dekVersion))) .select({ version: file.dekVersion })
.execute(); .from(file)
return res.changes > 0; .where(and(eq(file.userId, userId), eq(file.id, fileId)))
.limit(1);
if (!files[0]) {
throw new IntegrityError("File not found");
} else if (files[0].version.getTime() !== dekVersion.getTime()) {
throw new IntegrityError("Invalid DEK version");
}
await tx
.update(file)
.set({ encName: { ciphertext: encName, iv: encNameIv } })
.where(and(eq(file.userId, userId), eq(file.id, fileId)));
},
{ behavior: "exclusive" },
);
}; };
export const unregisterFile = async (userId: number, fileId: number) => { export const unregisterFile = async (userId: number, fileId: number) => {
const res = await db const files = await db
.delete(file) .delete(file)
.where(and(eq(file.userId, userId), eq(file.id, fileId))) .where(and(eq(file.userId, userId), eq(file.id, fileId)))
.returning({ path: file.path }) .returning({ path: file.path });
.execute(); if (!files[0]) {
return res[0]?.path ?? null; throw new IntegrityError("File not found");
}
return files[0].path;
}; };

View File

@@ -1,5 +1,7 @@
import { SqliteError } from "better-sqlite3";
import { and, or, eq } from "drizzle-orm"; import { and, or, eq } from "drizzle-orm";
import db from "./drizzle"; import db from "./drizzle";
import { IntegrityError } from "./error";
import { mek, clientMek } from "./schema"; import { mek, clientMek } from "./schema";
export const registerInitialMek = async ( export const registerInitialMek = async (
@@ -8,22 +10,32 @@ export const registerInitialMek = async (
encMek: string, encMek: string,
encMekSig: string, encMekSig: string,
) => { ) => {
await db.transaction(async (tx) => { await db.transaction(
await tx.insert(mek).values({ async (tx) => {
userId, try {
version: 1, await tx.insert(mek).values({
createdBy, userId,
createdAt: new Date(), version: 1,
state: "active", createdBy,
}); createdAt: new Date(),
await tx.insert(clientMek).values({ state: "active",
userId, });
clientId: createdBy, await tx.insert(clientMek).values({
mekVersion: 1, userId,
encMek, clientId: createdBy,
encMekSig, mekVersion: 1,
}); encMek,
}); encMekSig,
});
} catch (e) {
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_PRIMARYKEY") {
throw new IntegrityError("MEK already registered");
}
throw e;
}
},
{ behavior: "exclusive" },
);
}; };
export const getInitialMek = async (userId: number) => { export const getInitialMek = async (userId: number) => {
@@ -31,19 +43,10 @@ export const getInitialMek = async (userId: number) => {
.select() .select()
.from(mek) .from(mek)
.where(and(eq(mek.userId, userId), eq(mek.version, 1))) .where(and(eq(mek.userId, userId), eq(mek.version, 1)))
.execute(); .limit(1);
return meks[0] ?? null; return meks[0] ?? null;
}; };
export const getActiveMekVersion = async (userId: number) => {
const meks = await db
.select({ version: mek.version })
.from(mek)
.where(and(eq(mek.userId, userId), eq(mek.state, "active")))
.execute();
return meks[0]?.version ?? null;
};
export const getAllValidClientMeks = async (userId: number, clientId: number) => { export const getAllValidClientMeks = async (userId: number, clientId: number) => {
return await db return await db
.select() .select()
@@ -55,6 +58,5 @@ export const getAllValidClientMeks = async (userId: number, clientId: number) =>
eq(clientMek.clientId, clientId), eq(clientMek.clientId, clientId),
or(eq(mek.state, "active"), eq(mek.state, "retired")), or(eq(mek.state, "active"), eq(mek.state, "retired")),
), ),
) );
.execute();
}; };

View File

@@ -2,6 +2,7 @@ import { SqliteError } from "better-sqlite3";
import { and, eq, gt, lte } from "drizzle-orm"; import { and, eq, gt, lte } from "drizzle-orm";
import env from "$lib/server/loadenv"; import env from "$lib/server/loadenv";
import db from "./drizzle"; import db from "./drizzle";
import { IntegrityError } from "./error";
import { refreshToken, tokenUpgradeChallenge } from "./schema"; import { refreshToken, tokenUpgradeChallenge } from "./schema";
const expiresAt = () => new Date(Date.now() + env.jwt.refreshExp); const expiresAt = () => new Date(Date.now() + env.jwt.refreshExp);
@@ -12,44 +13,45 @@ export const registerRefreshToken = async (
tokenId: string, tokenId: string,
) => { ) => {
try { try {
await db await db.insert(refreshToken).values({
.insert(refreshToken) id: tokenId,
.values({ userId,
id: tokenId, clientId,
userId, expiresAt: expiresAt(),
clientId, });
expiresAt: expiresAt(),
})
.execute();
return true;
} catch (e) { } catch (e) {
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") { if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
return false; throw new IntegrityError("Refresh token already registered");
} }
throw e; throw e;
} }
}; };
export const getRefreshToken = async (tokenId: string) => { export const getRefreshToken = async (tokenId: string) => {
const tokens = await db.select().from(refreshToken).where(eq(refreshToken.id, tokenId)).execute(); const tokens = await db.select().from(refreshToken).where(eq(refreshToken.id, tokenId)).limit(1);
return tokens[0] ?? null; return tokens[0] ?? null;
}; };
export const rotateRefreshToken = async (oldTokenId: string, newTokenId: string) => { export const rotateRefreshToken = async (oldTokenId: string, newTokenId: string) => {
return await db.transaction(async (tx) => { await db.transaction(
await tx async (tx) => {
.delete(tokenUpgradeChallenge) await tx
.where(eq(tokenUpgradeChallenge.refreshTokenId, oldTokenId)); .delete(tokenUpgradeChallenge)
const res = await db .where(eq(tokenUpgradeChallenge.refreshTokenId, oldTokenId));
.update(refreshToken)
.set({ const res = await tx
id: newTokenId, .update(refreshToken)
expiresAt: expiresAt(), .set({
}) id: newTokenId,
.where(eq(refreshToken.id, oldTokenId)) expiresAt: expiresAt(),
.execute(); })
return res.changes > 0; .where(eq(refreshToken.id, oldTokenId));
}); if (res.changes === 0) {
throw new IntegrityError("Refresh token not found");
}
},
{ behavior: "exclusive" },
);
}; };
export const upgradeRefreshToken = async ( export const upgradeRefreshToken = async (
@@ -57,29 +59,34 @@ export const upgradeRefreshToken = async (
newTokenId: string, newTokenId: string,
clientId: number, clientId: number,
) => { ) => {
return await db.transaction(async (tx) => { await db.transaction(
await tx async (tx) => {
.delete(tokenUpgradeChallenge) await tx
.where(eq(tokenUpgradeChallenge.refreshTokenId, oldTokenId)); .delete(tokenUpgradeChallenge)
const res = await tx .where(eq(tokenUpgradeChallenge.refreshTokenId, oldTokenId));
.update(refreshToken)
.set({ const res = await tx
id: newTokenId, .update(refreshToken)
clientId, .set({
expiresAt: expiresAt(), id: newTokenId,
}) clientId,
.where(eq(refreshToken.id, oldTokenId)) expiresAt: expiresAt(),
.execute(); })
return res.changes > 0; .where(eq(refreshToken.id, oldTokenId));
}); if (res.changes === 0) {
throw new IntegrityError("Refresh token not found");
}
},
{ behavior: "exclusive" },
);
}; };
export const revokeRefreshToken = async (tokenId: string) => { export const revokeRefreshToken = async (tokenId: string) => {
await db.delete(refreshToken).where(eq(refreshToken.id, tokenId)).execute(); await db.delete(refreshToken).where(eq(refreshToken.id, tokenId));
}; };
export const cleanupExpiredRefreshTokens = async () => { export const cleanupExpiredRefreshTokens = async () => {
await db.delete(refreshToken).where(lte(refreshToken.expiresAt, new Date())).execute(); await db.delete(refreshToken).where(lte(refreshToken.expiresAt, new Date()));
}; };
export const registerTokenUpgradeChallenge = async ( export const registerTokenUpgradeChallenge = async (
@@ -89,16 +96,13 @@ export const registerTokenUpgradeChallenge = async (
allowedIp: string, allowedIp: string,
expiresAt: Date, expiresAt: Date,
) => { ) => {
await db await db.insert(tokenUpgradeChallenge).values({
.insert(tokenUpgradeChallenge) refreshTokenId: tokenId,
.values({ clientId,
refreshTokenId: tokenId, answer,
clientId, allowedIp,
answer, expiresAt,
allowedIp, });
expiresAt,
})
.execute();
}; };
export const getTokenUpgradeChallenge = async (answer: string, ip: string) => { export const getTokenUpgradeChallenge = async (answer: string, ip: string) => {
@@ -113,7 +117,7 @@ export const getTokenUpgradeChallenge = async (answer: string, ip: string) => {
eq(tokenUpgradeChallenge.isUsed, false), eq(tokenUpgradeChallenge.isUsed, false),
), ),
) )
.execute(); .limit(1);
return challenges[0] ?? null; return challenges[0] ?? null;
}; };
@@ -121,13 +125,9 @@ export const markTokenUpgradeChallengeAsUsed = async (id: number) => {
await db await db
.update(tokenUpgradeChallenge) .update(tokenUpgradeChallenge)
.set({ isUsed: true }) .set({ isUsed: true })
.where(eq(tokenUpgradeChallenge.id, id)) .where(eq(tokenUpgradeChallenge.id, id));
.execute();
}; };
export const cleanupExpiredTokenUpgradeChallenges = async () => { export const cleanupExpiredTokenUpgradeChallenges = async () => {
await db await db.delete(tokenUpgradeChallenge).where(lte(tokenUpgradeChallenge.expiresAt, new Date()));
.delete(tokenUpgradeChallenge)
.where(lte(tokenUpgradeChallenge.expiresAt, new Date()))
.execute();
}; };

View File

@@ -3,6 +3,6 @@ import db from "./drizzle";
import { user } from "./schema"; import { user } from "./schema";
export const getUserByEmail = async (email: string) => { export const getUserByEmail = async (email: string) => {
const users = await db.select().from(user).where(eq(user.email, email)).execute(); const users = await db.select().from(user).where(eq(user.email, email)).limit(1);
return users[0] ?? null; return users[0] ?? null;
}; };

View File

@@ -4,9 +4,10 @@ import { v4 as uuidv4 } from "uuid";
import { getClient, getClientByPubKeys, getUserClient } from "$lib/server/db/client"; import { getClient, getClientByPubKeys, getUserClient } from "$lib/server/db/client";
import { getUserByEmail } from "$lib/server/db/user"; import { getUserByEmail } from "$lib/server/db/user";
import env from "$lib/server/loadenv"; import env from "$lib/server/loadenv";
import { IntegrityError } from "$lib/server/db/error";
import { import {
getRefreshToken,
registerRefreshToken, registerRefreshToken,
getRefreshToken,
rotateRefreshToken, rotateRefreshToken,
upgradeRefreshToken, upgradeRefreshToken,
revokeRefreshToken, revokeRefreshToken,
@@ -29,10 +30,15 @@ const issueRefreshToken = async (userId: number, clientId?: number) => {
const jti = uuidv4(); const jti = uuidv4();
const token = issueToken({ type: "refresh", jti }); const token = issueToken({ type: "refresh", jti });
if (!(await registerRefreshToken(userId, clientId ?? null, jti))) { try {
error(403, "Already logged in"); await registerRefreshToken(userId, clientId ?? null, jti);
return token;
} catch (e) {
if (e instanceof IntegrityError && e.message === "Refresh token already registered") {
error(409, "Already logged in");
}
throw e;
} }
return token;
}; };
export const login = async (email: string, password: string) => { export const login = async (email: string, password: string) => {
@@ -57,7 +63,7 @@ const verifyRefreshToken = async (refreshToken: string) => {
const tokenData = await getRefreshToken(tokenPayload.jti); const tokenData = await getRefreshToken(tokenPayload.jti);
if (!tokenData) { if (!tokenData) {
error(500, "Refresh token not found"); error(500, "Invalid refresh token");
} }
return { return {
@@ -76,13 +82,18 @@ export const refreshToken = async (refreshToken: string) => {
const { jti: oldJti, userId, clientId } = await verifyRefreshToken(refreshToken); const { jti: oldJti, userId, clientId } = await verifyRefreshToken(refreshToken);
const newJti = uuidv4(); const newJti = uuidv4();
if (!(await rotateRefreshToken(oldJti, newJti))) { try {
error(500, "Refresh token not found"); await rotateRefreshToken(oldJti, newJti);
return {
accessToken: issueAccessToken(userId, clientId),
refreshToken: issueToken({ type: "refresh", jti: newJti }),
};
} catch (e) {
if (e instanceof IntegrityError && e.message === "Refresh token not found") {
error(500, "Invalid refresh token");
}
throw e;
} }
return {
accessToken: issueAccessToken(userId, clientId),
refreshToken: issueToken({ type: "refresh", jti: newJti }),
};
}; };
const expiresAt = () => new Date(Date.now() + env.challenge.tokenUpgradeExp); const expiresAt = () => new Date(Date.now() + env.challenge.tokenUpgradeExp);
@@ -120,7 +131,7 @@ export const createTokenUpgradeChallenge = async (
if (!client) { if (!client) {
error(401, "Invalid public key(s)"); error(401, "Invalid public key(s)");
} else if (!userClient || userClient.state === "challenging") { } else if (!userClient || userClient.state === "challenging") {
error(401, "Unregistered client"); error(403, "Unregistered client");
} }
return { challenge: await createChallenge(ip, jti, client.id, encPubKey) }; return { challenge: await createChallenge(ip, jti, client.id, encPubKey) };
@@ -139,26 +150,31 @@ export const upgradeToken = async (
const challenge = await getTokenUpgradeChallenge(answer, ip); const challenge = await getTokenUpgradeChallenge(answer, ip);
if (!challenge) { if (!challenge) {
error(401, "Invalid challenge answer"); error(403, "Invalid challenge answer");
} else if (challenge.refreshTokenId !== oldJti) { } else if (challenge.refreshTokenId !== oldJti) {
error(403, "Forbidden"); error(403, "Forbidden");
} }
await markTokenUpgradeChallengeAsUsed(challenge.id);
const client = await getClient(challenge.clientId); const client = await getClient(challenge.clientId);
if (!client) { if (!client) {
error(500, "Invalid challenge answer"); error(500, "Invalid challenge answer");
} else if (!verifySignature(Buffer.from(answer, "base64"), answerSig, client.sigPubKey)) { } else if (!verifySignature(Buffer.from(answer, "base64"), answerSig, client.sigPubKey)) {
error(401, "Invalid challenge answer signature"); error(403, "Invalid challenge answer signature");
} }
await markTokenUpgradeChallengeAsUsed(challenge.id); try {
const newJti = uuidv4();
const newJti = uuidv4(); await upgradeRefreshToken(oldJti, newJti, client.id);
if (!(await upgradeRefreshToken(oldJti, newJti, client.id))) { return {
error(500, "Refresh token not found"); accessToken: issueAccessToken(userId, client.id),
refreshToken: issueToken({ type: "refresh", jti: newJti }),
};
} catch (e) {
if (e instanceof IntegrityError && e.message === "Refresh token not found") {
error(500, "Invalid refresh token");
}
throw e;
} }
return {
accessToken: issueAccessToken(userId, client.id),
refreshToken: issueToken({ type: "refresh", jti: newJti }),
};
}; };

View File

@@ -3,7 +3,6 @@ import {
createClient, createClient,
getClient, getClient,
getClientByPubKeys, getClientByPubKeys,
countClientByPubKey,
createUserClient, createUserClient,
getAllUserClients, getAllUserClients,
getUserClient, getUserClient,
@@ -12,6 +11,7 @@ import {
getUserClientChallenge, getUserClientChallenge,
markUserClientChallengeAsUsed, markUserClientChallengeAsUsed,
} from "$lib/server/db/client"; } from "$lib/server/db/client";
import { IntegrityError } from "$lib/server/db/error";
import { verifyPubKey, verifySignature, generateChallenge } from "$lib/server/modules/crypto"; import { verifyPubKey, verifySignature, generateChallenge } from "$lib/server/modules/crypto";
import { isInitialMekNeeded } from "$lib/server/modules/mek"; import { isInitialMekNeeded } from "$lib/server/modules/mek";
import env from "$lib/server/loadenv"; import env from "$lib/server/loadenv";
@@ -29,8 +29,8 @@ export const getUserClientList = async (userId: number) => {
const expiresAt = () => new Date(Date.now() + env.challenge.userClientExp); const expiresAt = () => new Date(Date.now() + env.challenge.userClientExp);
const createUserClientChallenge = async ( const createUserClientChallenge = async (
userId: number,
ip: string, ip: string,
userId: number,
clientId: number, clientId: number,
encPubKey: string, encPubKey: string,
) => { ) => {
@@ -45,33 +45,59 @@ export const registerUserClient = async (
encPubKey: string, encPubKey: string,
sigPubKey: string, sigPubKey: string,
) => { ) => {
let clientId;
const client = await getClientByPubKeys(encPubKey, sigPubKey); const client = await getClientByPubKeys(encPubKey, sigPubKey);
if (client) { if (client) {
const userClient = await getUserClient(userId, client.id); try {
if (userClient) { await createUserClient(userId, client.id);
error(409, "Client already registered"); return { challenge: await createUserClientChallenge(ip, userId, client.id, encPubKey) };
} catch (e) {
if (e instanceof IntegrityError && e.message === "User client already exists") {
error(409, "Client already registered");
}
throw e;
} }
await createUserClient(userId, client.id);
clientId = client.id;
} else { } else {
if (!verifyPubKey(encPubKey) || !verifyPubKey(sigPubKey)) { if (encPubKey === sigPubKey) {
error(400, "Same public keys");
} else if (!verifyPubKey(encPubKey) || !verifyPubKey(sigPubKey)) {
error(400, "Invalid public key(s)"); error(400, "Invalid public key(s)");
} else if (encPubKey === sigPubKey) {
error(400, "Public keys must be different");
} else if (
(await countClientByPubKey(encPubKey)) > 0 ||
(await countClientByPubKey(sigPubKey)) > 0
) {
error(409, "Public key(s) already registered");
} }
clientId = await createClient(encPubKey, sigPubKey, userId); try {
const clientId = await createClient(encPubKey, sigPubKey, userId);
return { challenge: await createUserClientChallenge(ip, userId, clientId, encPubKey) };
} catch (e) {
if (e instanceof IntegrityError && e.message === "Public key(s) already registered") {
error(409, "Public key(s) already used");
}
throw e;
}
}
};
export const verifyUserClient = async (
userId: number,
ip: string,
answer: string,
answerSig: string,
) => {
const challenge = await getUserClientChallenge(answer, ip);
if (!challenge) {
error(403, "Invalid challenge answer");
} else if (challenge.userId !== userId) {
error(403, "Forbidden");
} }
return { challenge: await createUserClientChallenge(userId, ip, clientId, encPubKey) }; await markUserClientChallengeAsUsed(challenge.id);
const client = await getClient(challenge.clientId);
if (!client) {
error(500, "Invalid challenge answer");
} else if (!verifySignature(Buffer.from(answer, "base64"), answerSig, client.sigPubKey)) {
error(403, "Invalid challenge answer signature");
}
await setUserClientStateToPending(userId, challenge.clientId);
}; };
export const getUserClientStatus = async (userId: number, clientId: number) => { export const getUserClientStatus = async (userId: number, clientId: number) => {
@@ -85,27 +111,3 @@ export const getUserClientStatus = async (userId: number, clientId: number) => {
isInitialMekNeeded: await isInitialMekNeeded(userId), isInitialMekNeeded: await isInitialMekNeeded(userId),
}; };
}; };
export const verifyUserClient = async (
userId: number,
ip: string,
answer: string,
answerSig: string,
) => {
const challenge = await getUserClientChallenge(answer, ip);
if (!challenge) {
error(401, "Invalid challenge answer");
} else if (challenge.userId !== userId) {
error(403, "Forbidden");
}
const client = await getClient(challenge.clientId);
if (!client) {
error(500, "Invalid challenge answer");
} else if (!verifySignature(Buffer.from(answer, "base64"), answerSig, client.sigPubKey)) {
error(401, "Invalid challenge answer signature");
}
await markUserClientChallengeAsUsed(challenge.id);
await setUserClientStateToPending(userId, challenge.clientId);
};

View File

@@ -1,44 +1,15 @@
import { error } from "@sveltejs/kit"; import { error } from "@sveltejs/kit";
import { unlink } from "fs/promises"; import { unlink } from "fs/promises";
import { IntegrityError } from "$lib/server/db/error";
import { import {
registerDirectory,
getAllDirectoriesByParent, getAllDirectoriesByParent,
registerNewDirectory,
getDirectory, getDirectory,
setDirectoryEncName, setDirectoryEncName,
unregisterDirectory, unregisterDirectory,
getAllFilesByParent, getAllFilesByParent,
type NewDirectoryParams, type NewDirectoryParams,
} from "$lib/server/db/file"; } from "$lib/server/db/file";
import { getActiveMekVersion } from "$lib/server/db/mek";
export const deleteDirectory = async (userId: number, directoryId: number) => {
const directory = await getDirectory(userId, directoryId);
if (!directory) {
error(404, "Invalid directory id");
}
const filePaths = await unregisterDirectory(userId, directoryId);
filePaths.map((path) => unlink(path)); // Intended
};
export const renameDirectory = async (
userId: number,
directoryId: number,
dekVersion: Date,
newEncName: string,
newEncNameIv: string,
) => {
const directory = await getDirectory(userId, directoryId);
if (!directory) {
error(404, "Invalid directory id");
} else if (directory.dekVersion.getTime() !== dekVersion.getTime()) {
error(400, "Invalid DEK version");
}
if (!(await setDirectoryEncName(userId, directoryId, dekVersion, newEncName, newEncNameIv))) {
error(500, "Invalid directory id or DEK version");
}
};
export const getDirectoryInformation = async (userId: number, directoryId: "root" | number) => { export const getDirectoryInformation = async (userId: number, directoryId: "root" | number) => {
const directory = directoryId !== "root" ? await getDirectory(userId, directoryId) : undefined; const directory = directoryId !== "root" ? await getDirectory(userId, directoryId) : undefined;
@@ -62,19 +33,52 @@ export const getDirectoryInformation = async (userId: number, directoryId: "root
}; };
}; };
export const createDirectory = async (params: NewDirectoryParams) => { export const deleteDirectory = async (userId: number, directoryId: number) => {
const activeMekVersion = await getActiveMekVersion(params.userId); try {
if (activeMekVersion === null) { const filePaths = await unregisterDirectory(userId, directoryId);
error(500, "Invalid MEK version"); filePaths.map((path) => unlink(path)); // Intended
} else if (activeMekVersion !== params.mekVersion) { } catch (e) {
error(400, "Invalid MEK version"); if (e instanceof IntegrityError && e.message === "Directory not found") {
error(404, "Invalid directory id");
}
throw e;
} }
};
export const renameDirectory = async (
userId: number,
directoryId: number,
dekVersion: Date,
newEncName: string,
newEncNameIv: string,
) => {
try {
await setDirectoryEncName(userId, directoryId, dekVersion, newEncName, newEncNameIv);
} catch (e) {
if (e instanceof IntegrityError) {
if (e.message === "Directory not found") {
error(404, "Invalid directory id");
} else if (e.message === "Invalid DEK version") {
error(400, "Invalid DEK version");
}
}
throw e;
}
};
export const createDirectory = async (params: NewDirectoryParams) => {
const oneMinuteAgo = new Date(Date.now() - 60 * 1000); const oneMinuteAgo = new Date(Date.now() - 60 * 1000);
const oneMinuteLater = new Date(Date.now() + 60 * 1000); const oneMinuteLater = new Date(Date.now() + 60 * 1000);
if (params.dekVersion <= oneMinuteAgo || params.dekVersion >= oneMinuteLater) { if (params.dekVersion <= oneMinuteAgo || params.dekVersion >= oneMinuteLater) {
error(400, "Invalid DEK version"); error(400, "Invalid DEK version");
} }
await registerNewDirectory(params); try {
await registerDirectory(params);
} catch (e) {
if (e instanceof IntegrityError && e.message === "Inactive MEK version") {
error(400, "Invalid MEK version");
}
throw e;
}
}; };

View File

@@ -1,77 +1,19 @@
import { error } from "@sveltejs/kit"; import { error } from "@sveltejs/kit";
import { createReadStream, createWriteStream, ReadStream, WriteStream } from "fs"; import { createReadStream, createWriteStream } from "fs";
import { mkdir, stat, unlink } from "fs/promises"; import { mkdir, stat, unlink } from "fs/promises";
import { dirname } from "path"; import { dirname } from "path";
import { Readable, Writable } from "stream";
import { v4 as uuidv4 } from "uuid"; import { v4 as uuidv4 } from "uuid";
import { IntegrityError } from "$lib/server/db/error";
import { import {
registerNewFile, registerFile,
getFile, getFile,
setFileEncName, setFileEncName,
unregisterFile, unregisterFile,
type NewFileParams, type NewFileParams,
} from "$lib/server/db/file"; } from "$lib/server/db/file";
import { getActiveMekVersion } from "$lib/server/db/mek";
import env from "$lib/server/loadenv"; import env from "$lib/server/loadenv";
export const deleteFile = async (userId: number, fileId: number) => {
const file = await getFile(userId, fileId);
if (!file) {
error(404, "Invalid file id");
}
const path = await unregisterFile(userId, fileId);
if (!path) {
error(500, "Invalid file id");
}
unlink(path); // Intended
};
const convertToReadableStream = (readStream: ReadStream) => {
return new ReadableStream<Uint8Array>({
start: (controller) => {
readStream.on("data", (chunk) => controller.enqueue(new Uint8Array(chunk as Buffer)));
readStream.on("end", () => controller.close());
readStream.on("error", (e) => controller.error(e));
},
cancel: () => {
readStream.destroy();
},
});
};
export const getFileStream = async (userId: number, fileId: number) => {
const file = await getFile(userId, fileId);
if (!file) {
error(404, "Invalid file id");
}
const { size } = await stat(file.path);
return {
encContentStream: convertToReadableStream(createReadStream(file.path)),
encContentSize: size,
};
};
export const renameFile = async (
userId: number,
fileId: number,
dekVersion: Date,
newEncName: string,
newEncNameIv: string,
) => {
const file = await getFile(userId, fileId);
if (!file) {
error(404, "Invalid file id");
} else if (file.dekVersion.getTime() !== dekVersion.getTime()) {
error(400, "Invalid DEK version");
}
if (!(await setFileEncName(userId, fileId, dekVersion, newEncName, newEncNameIv))) {
error(500, "Invalid file id or DEK version");
}
};
export const getFileInformation = async (userId: number, fileId: number) => { export const getFileInformation = async (userId: number, fileId: number) => {
const file = await getFile(userId, fileId); const file = await getFile(userId, fileId);
if (!file) { if (!file) {
@@ -89,20 +31,50 @@ export const getFileInformation = async (userId: number, fileId: number) => {
}; };
}; };
const convertToWritableStream = (writeStream: WriteStream) => { export const deleteFile = async (userId: number, fileId: number) => {
return new WritableStream<Uint8Array>({ try {
write: (chunk) => const filePath = await unregisterFile(userId, fileId);
new Promise((resolve, reject) => { unlink(filePath); // Intended
writeStream.write(chunk, (e) => { } catch (e) {
if (e) { if (e instanceof IntegrityError && e.message === "File not found") {
reject(e); error(404, "Invalid file id");
} else { }
resolve(); throw e;
} }
}); };
}),
close: () => new Promise((resolve) => writeStream.end(resolve)), export const getFileStream = async (userId: number, fileId: number) => {
}); const file = await getFile(userId, fileId);
if (!file) {
error(404, "Invalid file id");
}
const { size } = await stat(file.path);
return {
encContentStream: Readable.toWeb(createReadStream(file.path)),
encContentSize: size,
};
};
export const renameFile = async (
userId: number,
fileId: number,
dekVersion: Date,
newEncName: string,
newEncNameIv: string,
) => {
try {
await setFileEncName(userId, fileId, dekVersion, newEncName, newEncNameIv);
} catch (e) {
if (e instanceof IntegrityError) {
if (e.message === "File not found") {
error(404, "Invalid file id");
} else if (e.message === "Invalid DEK version") {
error(400, "Invalid DEK version");
}
}
throw e;
}
}; };
const safeUnlink = async (path: string) => { const safeUnlink = async (path: string) => {
@@ -113,13 +85,6 @@ export const uploadFile = async (
params: Omit<NewFileParams, "path">, params: Omit<NewFileParams, "path">,
encContentStream: ReadableStream<Uint8Array>, encContentStream: ReadableStream<Uint8Array>,
) => { ) => {
const activeMekVersion = await getActiveMekVersion(params.userId);
if (activeMekVersion === null) {
error(500, "Invalid MEK version");
} else if (activeMekVersion !== params.mekVersion) {
error(400, "Invalid MEK version");
}
const oneMinuteAgo = new Date(Date.now() - 60 * 1000); const oneMinuteAgo = new Date(Date.now() - 60 * 1000);
const oneMinuteLater = new Date(Date.now() + 60 * 1000); const oneMinuteLater = new Date(Date.now() + 60 * 1000);
if (params.dekVersion <= oneMinuteAgo || params.dekVersion >= oneMinuteLater) { if (params.dekVersion <= oneMinuteAgo || params.dekVersion >= oneMinuteLater) {
@@ -131,14 +96,20 @@ export const uploadFile = async (
try { try {
await encContentStream.pipeTo( await encContentStream.pipeTo(
convertToWritableStream(createWriteStream(path, { flags: "wx", mode: 0o600 })), Writable.toWeb(createWriteStream(path, { flags: "wx", mode: 0o600 })),
); );
await registerNewFile({ await registerFile({
...params, ...params,
path, path,
}); });
} catch (e) { } catch (e) {
await safeUnlink(path); await safeUnlink(path);
if (e instanceof IntegrityError) {
if (e.message === "Inactive MEK version") {
error(400, "Invalid MEK version");
}
}
throw e; throw e;
} }
}; };

View File

@@ -1,7 +1,8 @@
import { error } from "@sveltejs/kit"; import { error } from "@sveltejs/kit";
import { setUserClientStateToActive } from "$lib/server/db/client"; import { setUserClientStateToActive } from "$lib/server/db/client";
import { IntegrityError } from "$lib/server/db/error";
import { registerInitialMek, getAllValidClientMeks } from "$lib/server/db/mek"; import { registerInitialMek, getAllValidClientMeks } from "$lib/server/db/mek";
import { isInitialMekNeeded, verifyClientEncMekSig } from "$lib/server/modules/mek"; import { verifyClientEncMekSig } from "$lib/server/modules/mek";
export const getClientMekList = async (userId: number, clientId: number) => { export const getClientMekList = async (userId: number, clientId: number) => {
const clientMeks = await getAllValidClientMeks(userId, clientId); const clientMeks = await getAllValidClientMeks(userId, clientId);
@@ -21,12 +22,17 @@ export const registerInitialActiveMek = async (
encMek: string, encMek: string,
encMekSig: string, encMekSig: string,
) => { ) => {
if (!(await isInitialMekNeeded(userId))) { if (!(await verifyClientEncMekSig(userId, createdBy, 1, encMek, encMekSig))) {
error(409, "Initial MEK already registered");
} else if (!(await verifyClientEncMekSig(userId, createdBy, 1, encMek, encMekSig))) {
error(400, "Invalid signature"); error(400, "Invalid signature");
} }
await registerInitialMek(userId, createdBy, encMek, encMekSig); try {
await setUserClientStateToActive(userId, createdBy); await registerInitialMek(userId, createdBy, encMek, encMekSig);
await setUserClientStateToActive(userId, createdBy);
} catch (e) {
if (e instanceof IntegrityError && e.message === "MEK already registered") {
error(409, "Initial MEK already registered");
}
throw e;
}
}; };

View File

@@ -16,7 +16,7 @@ export const GET: RequestHandler = async ({ cookies, params }) => {
const { id } = zodRes.data; const { id } = zodRes.data;
const { encContentStream, encContentSize } = await getFileStream(userId, id); const { encContentStream, encContentSize } = await getFileStream(userId, id);
return new Response(encContentStream, { return new Response(encContentStream as ReadableStream, {
headers: { headers: {
"Content-Type": "application/octet-stream", "Content-Type": "application/octet-stream",
"Content-Length": encContentSize.toString(), "Content-Length": encContentSize.toString(),