mirror of
https://github.com/kmc7468/arkvault.git
synced 2025-12-16 15:08:46 +00:00
DB에 동시적으로 접근하더라도 데이터 무결성이 깨지지 않도록 DB 접근 코드 수정
This commit is contained in:
@@ -1,30 +1,36 @@
|
||||
import { and, or, eq, gt, lte, count } from "drizzle-orm";
|
||||
import { SqliteError } from "better-sqlite3";
|
||||
import { and, or, eq, gt, lte } from "drizzle-orm";
|
||||
import db from "./drizzle";
|
||||
import { IntegrityError } from "./error";
|
||||
import { client, userClient, userClientChallenge } from "./schema";
|
||||
|
||||
export const createClient = async (encPubKey: string, sigPubKey: string, userId: number) => {
|
||||
return await db.transaction(async (tx) => {
|
||||
const clients = await tx
|
||||
.select()
|
||||
.from(client)
|
||||
.where(or(eq(client.encPubKey, sigPubKey), eq(client.sigPubKey, encPubKey)));
|
||||
if (clients.length > 0) {
|
||||
throw new Error("Already used public key(s)");
|
||||
}
|
||||
return await db.transaction(
|
||||
async (tx) => {
|
||||
const clients = await tx
|
||||
.select({ id: client.id })
|
||||
.from(client)
|
||||
.where(or(eq(client.encPubKey, sigPubKey), eq(client.sigPubKey, encPubKey)))
|
||||
.limit(1);
|
||||
if (clients.length !== 0) {
|
||||
throw new IntegrityError("Public key(s) already registered");
|
||||
}
|
||||
|
||||
const insertRes = await tx
|
||||
.insert(client)
|
||||
.values({ encPubKey, sigPubKey })
|
||||
.returning({ id: client.id });
|
||||
const { id: clientId } = insertRes[0]!;
|
||||
await tx.insert(userClient).values({ userId, clientId });
|
||||
const newClients = await tx
|
||||
.insert(client)
|
||||
.values({ encPubKey, sigPubKey })
|
||||
.returning({ id: client.id });
|
||||
const { id: clientId } = newClients[0]!;
|
||||
await tx.insert(userClient).values({ userId, clientId });
|
||||
|
||||
return clientId;
|
||||
});
|
||||
return clientId;
|
||||
},
|
||||
{ behavior: "exclusive" },
|
||||
);
|
||||
};
|
||||
|
||||
export const getClient = async (clientId: number) => {
|
||||
const clients = await db.select().from(client).where(eq(client.id, clientId)).execute();
|
||||
const clients = await db.select().from(client).where(eq(client.id, clientId)).limit(1);
|
||||
return clients[0] ?? null;
|
||||
};
|
||||
|
||||
@@ -33,24 +39,23 @@ export const getClientByPubKeys = async (encPubKey: string, sigPubKey: string) =
|
||||
.select()
|
||||
.from(client)
|
||||
.where(and(eq(client.encPubKey, encPubKey), eq(client.sigPubKey, sigPubKey)))
|
||||
.execute();
|
||||
.limit(1);
|
||||
return clients[0] ?? null;
|
||||
};
|
||||
|
||||
export const countClientByPubKey = async (pubKey: string) => {
|
||||
const clients = await db
|
||||
.select({ count: count() })
|
||||
.from(client)
|
||||
.where(or(eq(client.encPubKey, pubKey), eq(client.encPubKey, pubKey)));
|
||||
return clients[0]?.count ?? 0;
|
||||
};
|
||||
|
||||
export const createUserClient = async (userId: number, clientId: number) => {
|
||||
await db.insert(userClient).values({ userId, clientId }).execute();
|
||||
try {
|
||||
await db.insert(userClient).values({ userId, clientId });
|
||||
} catch (e) {
|
||||
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_PRIMARYKEY") {
|
||||
throw new IntegrityError("User client already exists");
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
};
|
||||
|
||||
export const getAllUserClients = async (userId: number) => {
|
||||
return await db.select().from(userClient).where(eq(userClient.userId, userId)).execute();
|
||||
return await db.select().from(userClient).where(eq(userClient.userId, userId));
|
||||
};
|
||||
|
||||
export const getUserClient = async (userId: number, clientId: number) => {
|
||||
@@ -58,7 +63,7 @@ export const getUserClient = async (userId: number, clientId: number) => {
|
||||
.select()
|
||||
.from(userClient)
|
||||
.where(and(eq(userClient.userId, userId), eq(userClient.clientId, clientId)))
|
||||
.execute();
|
||||
.limit(1);
|
||||
return userClients[0] ?? null;
|
||||
};
|
||||
|
||||
@@ -68,7 +73,7 @@ export const getUserClientWithDetails = async (userId: number, clientId: number)
|
||||
.from(userClient)
|
||||
.innerJoin(client, eq(userClient.clientId, client.id))
|
||||
.where(and(eq(userClient.userId, userId), eq(userClient.clientId, clientId)))
|
||||
.execute();
|
||||
.limit(1);
|
||||
return userClients[0] ?? null;
|
||||
};
|
||||
|
||||
@@ -82,8 +87,7 @@ export const setUserClientStateToPending = async (userId: number, clientId: numb
|
||||
eq(userClient.clientId, clientId),
|
||||
eq(userClient.state, "challenging"),
|
||||
),
|
||||
)
|
||||
.execute();
|
||||
);
|
||||
};
|
||||
|
||||
export const setUserClientStateToActive = async (userId: number, clientId: number) => {
|
||||
@@ -96,8 +100,7 @@ export const setUserClientStateToActive = async (userId: number, clientId: numbe
|
||||
eq(userClient.clientId, clientId),
|
||||
eq(userClient.state, "pending"),
|
||||
),
|
||||
)
|
||||
.execute();
|
||||
);
|
||||
};
|
||||
|
||||
export const registerUserClientChallenge = async (
|
||||
@@ -107,16 +110,13 @@ export const registerUserClientChallenge = async (
|
||||
allowedIp: string,
|
||||
expiresAt: Date,
|
||||
) => {
|
||||
await db
|
||||
.insert(userClientChallenge)
|
||||
.values({
|
||||
userId,
|
||||
clientId,
|
||||
answer,
|
||||
allowedIp,
|
||||
expiresAt,
|
||||
})
|
||||
.execute();
|
||||
await db.insert(userClientChallenge).values({
|
||||
userId,
|
||||
clientId,
|
||||
answer,
|
||||
allowedIp,
|
||||
expiresAt,
|
||||
});
|
||||
};
|
||||
|
||||
export const getUserClientChallenge = async (answer: string, ip: string) => {
|
||||
@@ -131,21 +131,14 @@ export const getUserClientChallenge = async (answer: string, ip: string) => {
|
||||
eq(userClientChallenge.isUsed, false),
|
||||
),
|
||||
)
|
||||
.execute();
|
||||
.limit(1);
|
||||
return challenges[0] ?? null;
|
||||
};
|
||||
|
||||
export const markUserClientChallengeAsUsed = async (id: number) => {
|
||||
await db
|
||||
.update(userClientChallenge)
|
||||
.set({ isUsed: true })
|
||||
.where(eq(userClientChallenge.id, id))
|
||||
.execute();
|
||||
await db.update(userClientChallenge).set({ isUsed: true }).where(eq(userClientChallenge.id, id));
|
||||
};
|
||||
|
||||
export const cleanupExpiredUserClientChallenges = async () => {
|
||||
await db
|
||||
.delete(userClientChallenge)
|
||||
.where(lte(userClientChallenge.expiresAt, new Date()))
|
||||
.execute();
|
||||
await db.delete(userClientChallenge).where(lte(userClientChallenge.expiresAt, new Date()));
|
||||
};
|
||||
|
||||
21
src/lib/server/db/error.ts
Normal file
21
src/lib/server/db/error.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
type IntegrityErrorMessages =
|
||||
// Client
|
||||
| "Public key(s) already registered"
|
||||
| "User client already exists"
|
||||
// File
|
||||
| "Directory not found"
|
||||
| "File not found"
|
||||
| "Invalid DEK version"
|
||||
// MEK
|
||||
| "MEK already registered"
|
||||
| "Inactive MEK version"
|
||||
// Token
|
||||
| "Refresh token not found"
|
||||
| "Refresh token already registered";
|
||||
|
||||
export class IntegrityError extends Error {
|
||||
constructor(public message: IntegrityErrorMessages) {
|
||||
super(message);
|
||||
this.name = "IntegrityError";
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
import { and, eq, isNull } from "drizzle-orm";
|
||||
import db from "./drizzle";
|
||||
import { IntegrityError } from "./error";
|
||||
import { directory, file, mek } from "./schema";
|
||||
|
||||
type DirectoryId = "root" | number;
|
||||
@@ -27,40 +28,42 @@ export interface NewFileParams {
|
||||
encNameIv: string;
|
||||
}
|
||||
|
||||
export const registerNewDirectory = async (params: NewDirectoryParams) => {
|
||||
return await db.transaction(async (tx) => {
|
||||
const meks = await tx
|
||||
.select()
|
||||
.from(mek)
|
||||
.where(and(eq(mek.userId, params.userId), eq(mek.state, "active")));
|
||||
if (meks[0]?.version !== params.mekVersion) {
|
||||
throw new Error("Invalid MEK version");
|
||||
}
|
||||
export const registerDirectory = async (params: NewDirectoryParams) => {
|
||||
await db.transaction(
|
||||
async (tx) => {
|
||||
const meks = await tx
|
||||
.select({ version: mek.version })
|
||||
.from(mek)
|
||||
.where(and(eq(mek.userId, params.userId), eq(mek.state, "active")))
|
||||
.limit(1);
|
||||
if (meks[0]?.version !== params.mekVersion) {
|
||||
throw new IntegrityError("Inactive MEK version");
|
||||
}
|
||||
|
||||
const now = new Date();
|
||||
await tx.insert(directory).values({
|
||||
createdAt: now,
|
||||
parentId: params.parentId === "root" ? null : params.parentId,
|
||||
userId: params.userId,
|
||||
mekVersion: params.mekVersion,
|
||||
encDek: params.encDek,
|
||||
dekVersion: params.dekVersion,
|
||||
encName: { ciphertext: params.encName, iv: params.encNameIv },
|
||||
});
|
||||
});
|
||||
await tx.insert(directory).values({
|
||||
createdAt: new Date(),
|
||||
parentId: params.parentId === "root" ? null : params.parentId,
|
||||
userId: params.userId,
|
||||
mekVersion: params.mekVersion,
|
||||
encDek: params.encDek,
|
||||
dekVersion: params.dekVersion,
|
||||
encName: { ciphertext: params.encName, iv: params.encNameIv },
|
||||
});
|
||||
},
|
||||
{ behavior: "exclusive" },
|
||||
);
|
||||
};
|
||||
|
||||
export const getAllDirectoriesByParent = async (userId: number, directoryId: DirectoryId) => {
|
||||
export const getAllDirectoriesByParent = async (userId: number, parentId: DirectoryId) => {
|
||||
return await db
|
||||
.select()
|
||||
.from(directory)
|
||||
.where(
|
||||
and(
|
||||
eq(directory.userId, userId),
|
||||
directoryId === "root" ? isNull(directory.parentId) : eq(directory.parentId, directoryId),
|
||||
parentId === "root" ? isNull(directory.parentId) : eq(directory.parentId, parentId),
|
||||
),
|
||||
)
|
||||
.execute();
|
||||
);
|
||||
};
|
||||
|
||||
export const getDirectory = async (userId: number, directoryId: number) => {
|
||||
@@ -68,7 +71,7 @@ export const getDirectory = async (userId: number, directoryId: number) => {
|
||||
.select()
|
||||
.from(directory)
|
||||
.where(and(eq(directory.userId, userId), eq(directory.id, directoryId)))
|
||||
.execute();
|
||||
.limit(1);
|
||||
return res[0] ?? null;
|
||||
};
|
||||
|
||||
@@ -79,72 +82,87 @@ export const setDirectoryEncName = async (
|
||||
encName: string,
|
||||
encNameIv: string,
|
||||
) => {
|
||||
const res = await db
|
||||
.update(directory)
|
||||
.set({ encName: { ciphertext: encName, iv: encNameIv } })
|
||||
.where(
|
||||
and(
|
||||
eq(directory.userId, userId),
|
||||
eq(directory.id, directoryId),
|
||||
eq(directory.dekVersion, dekVersion),
|
||||
),
|
||||
)
|
||||
.execute();
|
||||
return res.changes > 0;
|
||||
await db.transaction(
|
||||
async (tx) => {
|
||||
const directories = await tx
|
||||
.select({ version: directory.dekVersion })
|
||||
.from(directory)
|
||||
.where(and(eq(directory.userId, userId), eq(directory.id, directoryId)))
|
||||
.limit(1);
|
||||
if (!directories[0]) {
|
||||
throw new IntegrityError("Directory not found");
|
||||
} else if (directories[0].version.getTime() !== dekVersion.getTime()) {
|
||||
throw new IntegrityError("Invalid DEK version");
|
||||
}
|
||||
|
||||
await tx
|
||||
.update(directory)
|
||||
.set({ encName: { ciphertext: encName, iv: encNameIv } })
|
||||
.where(and(eq(directory.userId, userId), eq(directory.id, directoryId)));
|
||||
},
|
||||
{ behavior: "exclusive" },
|
||||
);
|
||||
};
|
||||
|
||||
export const unregisterDirectory = async (userId: number, directoryId: number) => {
|
||||
return await db.transaction(async (tx) => {
|
||||
const getFilePaths = async (parentId: number) => {
|
||||
const files = await tx
|
||||
.select({ path: file.path })
|
||||
.from(file)
|
||||
.where(and(eq(file.userId, userId), eq(file.parentId, parentId)));
|
||||
return files.map(({ path }) => path);
|
||||
};
|
||||
const unregisterSubDirectoriesRecursively = async (directoryId: number): Promise<string[]> => {
|
||||
const subDirectories = await tx
|
||||
.select({ id: directory.id })
|
||||
.from(directory)
|
||||
.where(and(eq(directory.userId, userId), eq(directory.parentId, directoryId)));
|
||||
const subDirectoryFilePaths = await Promise.all(
|
||||
subDirectories.map(async ({ id }) => await unregisterSubDirectoriesRecursively(id)),
|
||||
);
|
||||
const filePaths = await getFilePaths(directoryId);
|
||||
return await db.transaction(
|
||||
async (tx) => {
|
||||
const unregisterFiles = async (parentId: number) => {
|
||||
const files = await tx
|
||||
.delete(file)
|
||||
.where(and(eq(file.userId, userId), eq(file.parentId, parentId)))
|
||||
.returning({ path: file.path });
|
||||
return files.map(({ path }) => path);
|
||||
};
|
||||
const unregisterDirectoryRecursively = async (directoryId: number): Promise<string[]> => {
|
||||
const filePaths = await unregisterFiles(directoryId);
|
||||
const subDirectories = await tx
|
||||
.select({ id: directory.id })
|
||||
.from(directory)
|
||||
.where(and(eq(directory.userId, userId), eq(directory.parentId, directoryId)));
|
||||
const subDirectoryFilePaths = await Promise.all(
|
||||
subDirectories.map(async ({ id }) => await unregisterDirectoryRecursively(id)),
|
||||
);
|
||||
|
||||
await tx.delete(file).where(eq(file.parentId, directoryId));
|
||||
await tx.delete(directory).where(eq(directory.id, directoryId));
|
||||
|
||||
return filePaths.concat(...subDirectoryFilePaths);
|
||||
};
|
||||
return await unregisterSubDirectoriesRecursively(directoryId);
|
||||
});
|
||||
const deleteRes = await tx.delete(directory).where(eq(directory.id, directoryId));
|
||||
if (deleteRes.changes === 0) {
|
||||
throw new IntegrityError("Directory not found");
|
||||
}
|
||||
return filePaths.concat(...subDirectoryFilePaths);
|
||||
};
|
||||
return await unregisterDirectoryRecursively(directoryId);
|
||||
},
|
||||
{ behavior: "exclusive" },
|
||||
);
|
||||
};
|
||||
|
||||
export const registerNewFile = async (params: NewFileParams) => {
|
||||
await db.transaction(async (tx) => {
|
||||
const meks = await tx
|
||||
.select()
|
||||
.from(mek)
|
||||
.where(and(eq(mek.userId, params.userId), eq(mek.state, "active")));
|
||||
if (meks[0]?.version !== params.mekVersion) {
|
||||
throw new Error("Invalid MEK version");
|
||||
}
|
||||
export const registerFile = async (params: NewFileParams) => {
|
||||
await db.transaction(
|
||||
async (tx) => {
|
||||
const meks = await tx
|
||||
.select({ version: mek.version })
|
||||
.from(mek)
|
||||
.where(and(eq(mek.userId, params.userId), eq(mek.state, "active")))
|
||||
.limit(1);
|
||||
if (meks[0]?.version !== params.mekVersion) {
|
||||
throw new IntegrityError("Inactive MEK version");
|
||||
}
|
||||
|
||||
const now = new Date();
|
||||
await tx.insert(file).values({
|
||||
path: params.path,
|
||||
parentId: params.parentId === "root" ? null : params.parentId,
|
||||
createdAt: now,
|
||||
userId: params.userId,
|
||||
mekVersion: params.mekVersion,
|
||||
contentType: params.contentType,
|
||||
encDek: params.encDek,
|
||||
dekVersion: params.dekVersion,
|
||||
encContentIv: params.encContentIv,
|
||||
encName: { ciphertext: params.encName, iv: params.encNameIv },
|
||||
});
|
||||
});
|
||||
await tx.insert(file).values({
|
||||
path: params.path,
|
||||
parentId: params.parentId === "root" ? null : params.parentId,
|
||||
createdAt: new Date(),
|
||||
userId: params.userId,
|
||||
mekVersion: params.mekVersion,
|
||||
contentType: params.contentType,
|
||||
encDek: params.encDek,
|
||||
dekVersion: params.dekVersion,
|
||||
encContentIv: params.encContentIv,
|
||||
encName: { ciphertext: params.encName, iv: params.encNameIv },
|
||||
});
|
||||
},
|
||||
{ behavior: "exclusive" },
|
||||
);
|
||||
};
|
||||
|
||||
export const getAllFilesByParent = async (userId: number, parentId: DirectoryId) => {
|
||||
@@ -156,8 +174,7 @@ export const getAllFilesByParent = async (userId: number, parentId: DirectoryId)
|
||||
eq(file.userId, userId),
|
||||
parentId === "root" ? isNull(file.parentId) : eq(file.parentId, parentId),
|
||||
),
|
||||
)
|
||||
.execute();
|
||||
);
|
||||
};
|
||||
|
||||
export const getFile = async (userId: number, fileId: number) => {
|
||||
@@ -165,7 +182,7 @@ export const getFile = async (userId: number, fileId: number) => {
|
||||
.select()
|
||||
.from(file)
|
||||
.where(and(eq(file.userId, userId), eq(file.id, fileId)))
|
||||
.execute();
|
||||
.limit(1);
|
||||
return res[0] ?? null;
|
||||
};
|
||||
|
||||
@@ -176,19 +193,35 @@ export const setFileEncName = async (
|
||||
encName: string,
|
||||
encNameIv: string,
|
||||
) => {
|
||||
const res = await db
|
||||
.update(file)
|
||||
.set({ encName: { ciphertext: encName, iv: encNameIv } })
|
||||
.where(and(eq(file.userId, userId), eq(file.id, fileId), eq(file.dekVersion, dekVersion)))
|
||||
.execute();
|
||||
return res.changes > 0;
|
||||
await db.transaction(
|
||||
async (tx) => {
|
||||
const files = await tx
|
||||
.select({ version: file.dekVersion })
|
||||
.from(file)
|
||||
.where(and(eq(file.userId, userId), eq(file.id, fileId)))
|
||||
.limit(1);
|
||||
if (!files[0]) {
|
||||
throw new IntegrityError("File not found");
|
||||
} else if (files[0].version.getTime() !== dekVersion.getTime()) {
|
||||
throw new IntegrityError("Invalid DEK version");
|
||||
}
|
||||
|
||||
await tx
|
||||
.update(file)
|
||||
.set({ encName: { ciphertext: encName, iv: encNameIv } })
|
||||
.where(and(eq(file.userId, userId), eq(file.id, fileId)));
|
||||
},
|
||||
{ behavior: "exclusive" },
|
||||
);
|
||||
};
|
||||
|
||||
export const unregisterFile = async (userId: number, fileId: number) => {
|
||||
const res = await db
|
||||
const files = await db
|
||||
.delete(file)
|
||||
.where(and(eq(file.userId, userId), eq(file.id, fileId)))
|
||||
.returning({ path: file.path })
|
||||
.execute();
|
||||
return res[0]?.path ?? null;
|
||||
.returning({ path: file.path });
|
||||
if (!files[0]) {
|
||||
throw new IntegrityError("File not found");
|
||||
}
|
||||
return files[0].path;
|
||||
};
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import { SqliteError } from "better-sqlite3";
|
||||
import { and, or, eq } from "drizzle-orm";
|
||||
import db from "./drizzle";
|
||||
import { IntegrityError } from "./error";
|
||||
import { mek, clientMek } from "./schema";
|
||||
|
||||
export const registerInitialMek = async (
|
||||
@@ -8,22 +10,32 @@ export const registerInitialMek = async (
|
||||
encMek: string,
|
||||
encMekSig: string,
|
||||
) => {
|
||||
await db.transaction(async (tx) => {
|
||||
await tx.insert(mek).values({
|
||||
userId,
|
||||
version: 1,
|
||||
createdBy,
|
||||
createdAt: new Date(),
|
||||
state: "active",
|
||||
});
|
||||
await tx.insert(clientMek).values({
|
||||
userId,
|
||||
clientId: createdBy,
|
||||
mekVersion: 1,
|
||||
encMek,
|
||||
encMekSig,
|
||||
});
|
||||
});
|
||||
await db.transaction(
|
||||
async (tx) => {
|
||||
try {
|
||||
await tx.insert(mek).values({
|
||||
userId,
|
||||
version: 1,
|
||||
createdBy,
|
||||
createdAt: new Date(),
|
||||
state: "active",
|
||||
});
|
||||
await tx.insert(clientMek).values({
|
||||
userId,
|
||||
clientId: createdBy,
|
||||
mekVersion: 1,
|
||||
encMek,
|
||||
encMekSig,
|
||||
});
|
||||
} catch (e) {
|
||||
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_PRIMARYKEY") {
|
||||
throw new IntegrityError("MEK already registered");
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
},
|
||||
{ behavior: "exclusive" },
|
||||
);
|
||||
};
|
||||
|
||||
export const getInitialMek = async (userId: number) => {
|
||||
@@ -31,19 +43,10 @@ export const getInitialMek = async (userId: number) => {
|
||||
.select()
|
||||
.from(mek)
|
||||
.where(and(eq(mek.userId, userId), eq(mek.version, 1)))
|
||||
.execute();
|
||||
.limit(1);
|
||||
return meks[0] ?? null;
|
||||
};
|
||||
|
||||
export const getActiveMekVersion = async (userId: number) => {
|
||||
const meks = await db
|
||||
.select({ version: mek.version })
|
||||
.from(mek)
|
||||
.where(and(eq(mek.userId, userId), eq(mek.state, "active")))
|
||||
.execute();
|
||||
return meks[0]?.version ?? null;
|
||||
};
|
||||
|
||||
export const getAllValidClientMeks = async (userId: number, clientId: number) => {
|
||||
return await db
|
||||
.select()
|
||||
@@ -55,6 +58,5 @@ export const getAllValidClientMeks = async (userId: number, clientId: number) =>
|
||||
eq(clientMek.clientId, clientId),
|
||||
or(eq(mek.state, "active"), eq(mek.state, "retired")),
|
||||
),
|
||||
)
|
||||
.execute();
|
||||
);
|
||||
};
|
||||
|
||||
@@ -2,6 +2,7 @@ import { SqliteError } from "better-sqlite3";
|
||||
import { and, eq, gt, lte } from "drizzle-orm";
|
||||
import env from "$lib/server/loadenv";
|
||||
import db from "./drizzle";
|
||||
import { IntegrityError } from "./error";
|
||||
import { refreshToken, tokenUpgradeChallenge } from "./schema";
|
||||
|
||||
const expiresAt = () => new Date(Date.now() + env.jwt.refreshExp);
|
||||
@@ -12,44 +13,45 @@ export const registerRefreshToken = async (
|
||||
tokenId: string,
|
||||
) => {
|
||||
try {
|
||||
await db
|
||||
.insert(refreshToken)
|
||||
.values({
|
||||
id: tokenId,
|
||||
userId,
|
||||
clientId,
|
||||
expiresAt: expiresAt(),
|
||||
})
|
||||
.execute();
|
||||
return true;
|
||||
await db.insert(refreshToken).values({
|
||||
id: tokenId,
|
||||
userId,
|
||||
clientId,
|
||||
expiresAt: expiresAt(),
|
||||
});
|
||||
} catch (e) {
|
||||
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
|
||||
return false;
|
||||
throw new IntegrityError("Refresh token already registered");
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
};
|
||||
|
||||
export const getRefreshToken = async (tokenId: string) => {
|
||||
const tokens = await db.select().from(refreshToken).where(eq(refreshToken.id, tokenId)).execute();
|
||||
const tokens = await db.select().from(refreshToken).where(eq(refreshToken.id, tokenId)).limit(1);
|
||||
return tokens[0] ?? null;
|
||||
};
|
||||
|
||||
export const rotateRefreshToken = async (oldTokenId: string, newTokenId: string) => {
|
||||
return await db.transaction(async (tx) => {
|
||||
await tx
|
||||
.delete(tokenUpgradeChallenge)
|
||||
.where(eq(tokenUpgradeChallenge.refreshTokenId, oldTokenId));
|
||||
const res = await db
|
||||
.update(refreshToken)
|
||||
.set({
|
||||
id: newTokenId,
|
||||
expiresAt: expiresAt(),
|
||||
})
|
||||
.where(eq(refreshToken.id, oldTokenId))
|
||||
.execute();
|
||||
return res.changes > 0;
|
||||
});
|
||||
await db.transaction(
|
||||
async (tx) => {
|
||||
await tx
|
||||
.delete(tokenUpgradeChallenge)
|
||||
.where(eq(tokenUpgradeChallenge.refreshTokenId, oldTokenId));
|
||||
|
||||
const res = await tx
|
||||
.update(refreshToken)
|
||||
.set({
|
||||
id: newTokenId,
|
||||
expiresAt: expiresAt(),
|
||||
})
|
||||
.where(eq(refreshToken.id, oldTokenId));
|
||||
if (res.changes === 0) {
|
||||
throw new IntegrityError("Refresh token not found");
|
||||
}
|
||||
},
|
||||
{ behavior: "exclusive" },
|
||||
);
|
||||
};
|
||||
|
||||
export const upgradeRefreshToken = async (
|
||||
@@ -57,29 +59,34 @@ export const upgradeRefreshToken = async (
|
||||
newTokenId: string,
|
||||
clientId: number,
|
||||
) => {
|
||||
return await db.transaction(async (tx) => {
|
||||
await tx
|
||||
.delete(tokenUpgradeChallenge)
|
||||
.where(eq(tokenUpgradeChallenge.refreshTokenId, oldTokenId));
|
||||
const res = await tx
|
||||
.update(refreshToken)
|
||||
.set({
|
||||
id: newTokenId,
|
||||
clientId,
|
||||
expiresAt: expiresAt(),
|
||||
})
|
||||
.where(eq(refreshToken.id, oldTokenId))
|
||||
.execute();
|
||||
return res.changes > 0;
|
||||
});
|
||||
await db.transaction(
|
||||
async (tx) => {
|
||||
await tx
|
||||
.delete(tokenUpgradeChallenge)
|
||||
.where(eq(tokenUpgradeChallenge.refreshTokenId, oldTokenId));
|
||||
|
||||
const res = await tx
|
||||
.update(refreshToken)
|
||||
.set({
|
||||
id: newTokenId,
|
||||
clientId,
|
||||
expiresAt: expiresAt(),
|
||||
})
|
||||
.where(eq(refreshToken.id, oldTokenId));
|
||||
if (res.changes === 0) {
|
||||
throw new IntegrityError("Refresh token not found");
|
||||
}
|
||||
},
|
||||
{ behavior: "exclusive" },
|
||||
);
|
||||
};
|
||||
|
||||
export const revokeRefreshToken = async (tokenId: string) => {
|
||||
await db.delete(refreshToken).where(eq(refreshToken.id, tokenId)).execute();
|
||||
await db.delete(refreshToken).where(eq(refreshToken.id, tokenId));
|
||||
};
|
||||
|
||||
export const cleanupExpiredRefreshTokens = async () => {
|
||||
await db.delete(refreshToken).where(lte(refreshToken.expiresAt, new Date())).execute();
|
||||
await db.delete(refreshToken).where(lte(refreshToken.expiresAt, new Date()));
|
||||
};
|
||||
|
||||
export const registerTokenUpgradeChallenge = async (
|
||||
@@ -89,16 +96,13 @@ export const registerTokenUpgradeChallenge = async (
|
||||
allowedIp: string,
|
||||
expiresAt: Date,
|
||||
) => {
|
||||
await db
|
||||
.insert(tokenUpgradeChallenge)
|
||||
.values({
|
||||
refreshTokenId: tokenId,
|
||||
clientId,
|
||||
answer,
|
||||
allowedIp,
|
||||
expiresAt,
|
||||
})
|
||||
.execute();
|
||||
await db.insert(tokenUpgradeChallenge).values({
|
||||
refreshTokenId: tokenId,
|
||||
clientId,
|
||||
answer,
|
||||
allowedIp,
|
||||
expiresAt,
|
||||
});
|
||||
};
|
||||
|
||||
export const getTokenUpgradeChallenge = async (answer: string, ip: string) => {
|
||||
@@ -113,7 +117,7 @@ export const getTokenUpgradeChallenge = async (answer: string, ip: string) => {
|
||||
eq(tokenUpgradeChallenge.isUsed, false),
|
||||
),
|
||||
)
|
||||
.execute();
|
||||
.limit(1);
|
||||
return challenges[0] ?? null;
|
||||
};
|
||||
|
||||
@@ -121,13 +125,9 @@ export const markTokenUpgradeChallengeAsUsed = async (id: number) => {
|
||||
await db
|
||||
.update(tokenUpgradeChallenge)
|
||||
.set({ isUsed: true })
|
||||
.where(eq(tokenUpgradeChallenge.id, id))
|
||||
.execute();
|
||||
.where(eq(tokenUpgradeChallenge.id, id));
|
||||
};
|
||||
|
||||
export const cleanupExpiredTokenUpgradeChallenges = async () => {
|
||||
await db
|
||||
.delete(tokenUpgradeChallenge)
|
||||
.where(lte(tokenUpgradeChallenge.expiresAt, new Date()))
|
||||
.execute();
|
||||
await db.delete(tokenUpgradeChallenge).where(lte(tokenUpgradeChallenge.expiresAt, new Date()));
|
||||
};
|
||||
|
||||
@@ -3,6 +3,6 @@ import db from "./drizzle";
|
||||
import { user } from "./schema";
|
||||
|
||||
export const getUserByEmail = async (email: string) => {
|
||||
const users = await db.select().from(user).where(eq(user.email, email)).execute();
|
||||
const users = await db.select().from(user).where(eq(user.email, email)).limit(1);
|
||||
return users[0] ?? null;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user