Compare commits

...

11 Commits

Author SHA1 Message Date
Owen Schwartz
10f95896aa Merge pull request #3030 from fosrl/dev
1.18.3-s.2 fix
2026-05-07 20:08:05 -07:00
Owen
5b8994d143 Change to use primaryDb 2026-05-07 20:07:06 -07:00
Owen
c46ef2fe9c Fix ts type issue 2026-05-07 20:03:48 -07:00
Owen Schwartz
4cd025dd91 Merge pull request #3029 from fosrl/dev
1.18.3-s.2
2026-05-07 17:44:35 -07:00
Owen
ce04ea9720 Fix not including today
Fixes #3028
2026-05-07 16:15:13 -07:00
Owen
a3ce382725 Pick up other domains in the sans field 2026-05-07 15:49:12 -07:00
Owen
4eb49e3e60 Make the rebuild long running function background 2026-05-07 15:40:34 -07:00
Owen
2a9481023a Dont show link when wildcard 2026-05-07 15:15:03 -07:00
Owen
8ed01372b8 Add org to logs 2026-05-07 15:14:44 -07:00
Owen Schwartz
6a7d4fd385 Merge pull request #3021 from fosrl/dev
If not exists on trial table
2026-05-06 20:00:55 -07:00
Owen
7bc08c0425 If not exists on trial table 2026-05-06 20:00:23 -07:00
26 changed files with 480 additions and 301 deletions

View File

@@ -87,7 +87,7 @@ function createDb() {
export const db = createDb(); export const db = createDb();
export default db; export default db;
export const primaryDb = db.$primary; export const primaryDb = db.$primary as typeof db; // is this typeof a problem - technically they are different types
export type Transaction = Parameters< export type Transaction = Parameters<
Parameters<(typeof db)["transaction"]>[0] Parameters<(typeof db)["transaction"]>[0]
>[0]; >[0];

View File

@@ -25,9 +25,9 @@ import { tierMatrix } from "./billing/tierMatrix";
export async function calculateUserClientsForOrgs( export async function calculateUserClientsForOrgs(
userId: string, userId: string,
trx?: Transaction trx: Transaction | typeof db = db
): Promise<void> { ): Promise<void> {
const execute = async (transaction: Transaction) => { const execute = async (transaction: Transaction | typeof db) => {
const orgCache = new Map<string, typeof orgs.$inferSelect | null>(); const orgCache = new Map<string, typeof orgs.$inferSelect | null>();
const adminRoleCache = new Map< const adminRoleCache = new Map<
string, string,
@@ -437,7 +437,7 @@ export async function calculateUserClientsForOrgs(
async function cleanupOrphanedClients( async function cleanupOrphanedClients(
userId: string, userId: string,
trx: Transaction, trx: Transaction | typeof db,
userOrgIds: string[] = [] userOrgIds: string[] = []
): Promise<void> { ): Promise<void> {
// Find all OLM clients for this user that should be deleted // Find all OLM clients for this user that should be deleted

View File

@@ -124,7 +124,7 @@ export function computeBuckets(
let totalDowntime = 0; let totalDowntime = 0;
for (let d = 0; d < days; d++) { for (let d = 0; d < days; d++) {
const dayStartSec = todayMidnightSec - (days - d) * 86400; const dayStartSec = todayMidnightSec - (days - 1 - d) * 86400;
const dayEndSec = dayStartSec + 86400; const dayEndSec = dayStartSec + 86400;
const dayEvents = events.filter( const dayEvents = events.filter(

View File

@@ -485,6 +485,133 @@ async function syncAcmeCertsFromHttp(endpoint: string): Promise<void> {
} }
} }
/**
 * Persist (or refresh) the certificates-table row for a single domain
 * covered by an ACME certificate, then notify affected newts.
 *
 * Skips the write entirely when the stored cert PEM and wildcard flag are
 * unchanged. On update, the previously stored cert/key are decrypted and
 * forwarded to pushCertUpdateToAffectedNewts so receivers can diff old vs
 * new; on insert, nulls are forwarded instead.
 *
 * @param domain        Hostname the cert row is keyed by (may be "*."-prefixed).
 * @param certPem       Full PEM certificate chain to store (encrypted at rest).
 * @param keyPem        PEM private key to store (encrypted at rest).
 * @param validatedX509 Already-validated X.509 object; only validTo is read here.
 */
async function storeCertForDomain(
    domain: string,
    certPem: string,
    keyPem: string,
    validatedX509: crypto.X509Certificate
): Promise<void> {
    // A leading "*." marks a wildcard cert; the flag is stored on the row.
    const wildcard = domain.startsWith("*.");

    // Look up any existing row for this exact domain string.
    const existing = await db
        .select()
        .from(certificates)
        .where(eq(certificates.domain, domain))
        .limit(1);

    // Previous cert/key (decrypted), captured only when the cert changed,
    // so the post-write push can report what was replaced.
    let oldCertPem: string | null = null;
    let oldKeyPem: string | null = null;
    if (existing.length > 0 && existing[0].certFile) {
        try {
            const storedCertPem = decrypt(
                existing[0].certFile,
                config.getRawConfig().server.secret!
            );
            const wildcardUnchanged = existing[0].wildcard === wildcard;
            // Nothing changed: avoid a redundant DB write and newt push.
            if (storedCertPem === certPem && wildcardUnchanged) {
                return;
            }
            oldCertPem = storedCertPem;
            if (existing[0].keyFile) {
                try {
                    oldKeyPem = decrypt(
                        existing[0].keyFile,
                        config.getRawConfig().server.secret!
                    );
                } catch (keyErr) {
                    // Old key is best-effort; proceed with the update anyway.
                    logger.debug(
                        `acmeCertSync: could not decrypt stored key for ${domain}: ${keyErr}`
                    );
                }
            }
        } catch (err) {
            // Decryption failure means we can't compare — fall through and
            // overwrite the stored cert with the new one.
            logger.debug(
                `acmeCertSync: could not decrypt stored cert for ${domain}, will update: ${err}`
            );
        }
    }

    // Cert expiry in epoch seconds; null when validTo cannot be parsed.
    let expiresAt: number | null = null;
    try {
        expiresAt = Math.floor(
            new Date(validatedX509.validTo).getTime() / 1000
        );
    } catch (err) {
        logger.debug(
            `acmeCertSync: could not parse cert expiry for ${domain}: ${err}`
        );
    }

    // Cert and key are encrypted at rest with the server secret.
    const encryptedCert = encrypt(
        certPem,
        config.getRawConfig().server.secret!
    );
    const encryptedKey = encrypt(keyPem, config.getRawConfig().server.secret!);
    const now = Math.floor(Date.now() / 1000);

    // Best-effort link to a domains-table record; may legitimately be null
    // when no matching domain record exists for this hostname.
    const domainId = await findDomainId(domain);
    if (domainId) {
        logger.debug(
            `acmeCertSync: resolved domainId "${domainId}" for cert domain "${domain}"`
        );
    } else {
        logger.debug(
            `acmeCertSync: no matching domain record found for cert domain "${domain}"`
        );
    }

    if (existing.length > 0) {
        // Update path: refresh the existing row. domainId is only written
        // when resolved (conditional spread), preserving any prior link.
        logger.debug(
            `acmeCertSync: updating existing certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
        );
        await db
            .update(certificates)
            .set({
                certFile: encryptedCert,
                keyFile: encryptedKey,
                status: "valid",
                expiresAt,
                updatedAt: now,
                wildcard,
                ...(domainId !== null && { domainId })
            })
            .where(eq(certificates.domain, domain));
        logger.debug(
            `acmeCertSync: updated certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
        );
        // Push AFTER the DB write, with the captured old material for diffing.
        await pushCertUpdateToAffectedNewts(
            domain,
            domainId,
            oldCertPem,
            oldKeyPem
        );
    } else {
        // Insert path: brand-new row for this domain (domainId may be null).
        logger.debug(
            `acmeCertSync: inserting new certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
        );
        await db.insert(certificates).values({
            domain,
            domainId,
            certFile: encryptedCert,
            keyFile: encryptedKey,
            status: "valid",
            expiresAt,
            createdAt: now,
            updatedAt: now,
            wildcard
        });
        logger.debug(
            `acmeCertSync: inserted new certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
        );
        // New cert: nothing old to report, push nulls for the prior material.
        await pushCertUpdateToAffectedNewts(domain, domainId, null, null);
    }
}
function findAcmeJsonFiles(dirPath: string): string[] { function findAcmeJsonFiles(dirPath: string): string[] {
const results: string[] = []; const results: string[] = [];
let entries: fs.Dirent[]; let entries: fs.Dirent[];
@@ -575,18 +702,16 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
} }
for (const cert of allCerts) { for (const cert of allCerts) {
const domain = cert?.domain?.main; const mainDomain = cert?.domain?.main;
if (!domain || typeof domain !== "string") { if (!mainDomain || typeof mainDomain !== "string") {
logger.debug(`acmeCertSync: skipping cert with missing domain`); logger.debug(`acmeCertSync: skipping cert with missing domain`);
continue; continue;
} }
const { wildcard } = detectWildcard(domain, cert.domain?.sans);
if (!cert.certificate || !cert.key) { if (!cert.certificate || !cert.key) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${domain} - empty certificate or key field` `acmeCertSync: skipping cert for ${mainDomain} - empty certificate or key field`
); );
continue; continue;
} }
@@ -598,14 +723,14 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
keyPem = Buffer.from(cert.key, "base64").toString("utf8"); keyPem = Buffer.from(cert.key, "base64").toString("utf8");
} catch (err) { } catch (err) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${domain} - failed to base64-decode cert/key: ${err}` `acmeCertSync: skipping cert for ${mainDomain} - failed to base64-decode cert/key: ${err}`
); );
continue; continue;
} }
if (!certPem.trim() || !keyPem.trim()) { if (!certPem.trim() || !keyPem.trim()) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${domain} - blank PEM after base64 decode` `acmeCertSync: skipping cert for ${mainDomain} - blank PEM after base64 decode`
); );
continue; continue;
} }
@@ -616,7 +741,7 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
const firstCertPemForValidation = extractFirstCert(certPem); const firstCertPemForValidation = extractFirstCert(certPem);
if (!firstCertPemForValidation) { if (!firstCertPemForValidation) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${domain} - no PEM certificate block found` `acmeCertSync: skipping cert for ${mainDomain} - no PEM certificate block found`
); );
continue; continue;
} }
@@ -628,7 +753,7 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
); );
} catch (err) { } catch (err) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${domain} - invalid X.509 certificate: ${err}` `acmeCertSync: skipping cert for ${mainDomain} - invalid X.509 certificate: ${err}`
); );
continue; continue;
} }
@@ -638,139 +763,40 @@ async function syncAcmeCerts(acmeJsonPath: string): Promise<void> {
crypto.createPrivateKey(keyPem); crypto.createPrivateKey(keyPem);
} catch (err) { } catch (err) {
logger.debug( logger.debug(
`acmeCertSync: skipping cert for ${domain} - invalid private key: ${err}` `acmeCertSync: skipping cert for ${mainDomain} - invalid private key: ${err}`
); );
continue; continue;
} }
// Check if cert already exists in DB // Collect all domains covered by this cert: main + every SAN.
const existing = await db // Each domain gets its own row in the certificates table so that
.select() // lookups by any hostname on the cert succeed independently.
.from(certificates) const allDomains = new Set<string>([mainDomain]);
.where(and(eq(certificates.domain, domain))) if (Array.isArray(cert.domain?.sans)) {
.limit(1); for (const san of cert.domain.sans) {
if (typeof san === "string" && san.trim()) {
let oldCertPem: string | null = null; allDomains.add(san.trim());
let oldKeyPem: string | null = null;
if (existing.length > 0 && existing[0].certFile) {
try {
const storedCertPem = decrypt(
existing[0].certFile,
config.getRawConfig().server.secret!
);
const wildcardUnchanged = existing[0].wildcard === wildcard;
if (storedCertPem === certPem && wildcardUnchanged) {
// logger.debug(
// `acmeCertSync: cert for ${domain} is unchanged, skipping`
// );
continue;
} }
// Cert has changed; capture old values so we can send a correct
// update message to the newt after the DB write.
oldCertPem = storedCertPem;
if (existing[0].keyFile) {
try {
oldKeyPem = decrypt(
existing[0].keyFile,
config.getRawConfig().server.secret!
);
} catch (keyErr) {
logger.debug(
`acmeCertSync: could not decrypt stored key for ${domain}: ${keyErr}`
);
}
}
} catch (err) {
// Decryption failure means we should proceed with the update
logger.debug(
`acmeCertSync: could not decrypt stored cert for ${domain}, will update: ${err}`
);
} }
} }
// Parse cert expiry from the validated X.509 certificate logger.debug(
let expiresAt: number | null = null; `acmeCertSync: cert for ${mainDomain} covers ${allDomains.size} domain(s): ${[...allDomains].join(", ")}`
try {
expiresAt = Math.floor(
new Date(validatedX509.validTo).getTime() / 1000
);
} catch (err) {
logger.debug(
`acmeCertSync: could not parse cert expiry for ${domain}: ${err}`
);
}
const encryptedCert = encrypt(
certPem,
config.getRawConfig().server.secret!
); );
const encryptedKey = encrypt(
keyPem,
config.getRawConfig().server.secret!
);
const now = Math.floor(Date.now() / 1000);
const domainId = await findDomainId(domain); for (const domain of allDomains) {
if (domainId) { try {
logger.debug( await storeCertForDomain(
`acmeCertSync: resolved domainId "${domainId}" for cert domain "${domain}"` domain,
); certPem,
} else { keyPem,
logger.debug( validatedX509
`acmeCertSync: no matching domain record found for cert domain "${domain}"` );
); } catch (err) {
} logger.error(
`acmeCertSync: error storing cert for domain "${domain}": ${err}`
if (existing.length > 0) { );
logger.debug( }
`acmeCertSync: updating existing certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await db
.update(certificates)
.set({
certFile: encryptedCert,
keyFile: encryptedKey,
status: "valid",
expiresAt,
updatedAt: now,
wildcard,
...(domainId !== null && { domainId })
})
.where(eq(certificates.domain, domain));
logger.debug(
`acmeCertSync: updated certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await pushCertUpdateToAffectedNewts(
domain,
domainId,
oldCertPem,
oldKeyPem
);
} else {
logger.debug(
`acmeCertSync: inserting new certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
await db.insert(certificates).values({
domain,
domainId,
certFile: encryptedCert,
keyFile: encryptedKey,
status: "valid",
expiresAt,
createdAt: now,
updatedAt: now,
wildcard
});
logger.debug(
`acmeCertSync: inserted new certificate for ${domain} (expires ${expiresAt ? new Date(expiresAt * 1000).toISOString() : "unknown"})`
);
// For a brand-new cert, push to any SSL resources that were waiting for it
await pushCertUpdateToAffectedNewts(domain, domainId, null, null);
} }
} }
} }

View File

@@ -14,7 +14,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import stoi from "@server/lib/stoi"; import stoi from "@server/lib/stoi";
import { clients, db } from "@server/db"; import { clients, db, primaryDb, Client } from "@server/db";
import { userOrgRoles, userOrgs, roles } from "@server/db"; import { userOrgRoles, userOrgs, roles } from "@server/db";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -122,8 +122,12 @@ export async function addUserRole(
); );
} }
let newUserRole: { userId: string; orgId: string; roleId: number } | null = let newUserRole: {
null; userId: string;
orgId: string;
roleId: number;
} | null = null;
let orgClientsToRebuild: Client[] = [];
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
const inserted = await trx const inserted = await trx
.insert(userOrgRoles) .insert(userOrgRoles)
@@ -149,11 +153,19 @@ export async function addUserRole(
) )
); );
for (const orgClient of orgClients) { orgClientsToRebuild = orgClients;
await rebuildClientAssociationsFromClient(orgClient, trx);
}
}); });
for (const orgClient of orgClientsToRebuild) {
rebuildClientAssociationsFromClient(orgClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations for client ${orgClient.clientId} after adding role: ${e}`
);
}
);
}
return response(res, { return response(res, {
data: newUserRole ?? { userId, orgId: role.orgId, roleId }, data: newUserRole ?? { userId, orgId: role.orgId, roleId },
success: true, success: true,

View File

@@ -14,7 +14,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import stoi from "@server/lib/stoi"; import stoi from "@server/lib/stoi";
import { db } from "@server/db"; import { db, primaryDb, Client } from "@server/db";
import { userOrgRoles, userOrgs, roles, clients } from "@server/db"; import { userOrgRoles, userOrgs, roles, clients } from "@server/db";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -129,6 +129,7 @@ export async function removeUserRole(
} }
} }
let orgClientsToRebuild: Client[] = [];
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx await trx
.delete(userOrgRoles) .delete(userOrgRoles)
@@ -150,11 +151,19 @@ export async function removeUserRole(
) )
); );
for (const orgClient of orgClients) { orgClientsToRebuild = orgClients;
await rebuildClientAssociationsFromClient(orgClient, trx);
}
}); });
for (const orgClient of orgClientsToRebuild) {
rebuildClientAssociationsFromClient(orgClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations for client ${orgClient.clientId} after removing role: ${e}`
);
}
);
}
return response(res, { return response(res, {
data: { userId, orgId: role.orgId, roleId }, data: { userId, orgId: role.orgId, roleId },
success: true, success: true,

View File

@@ -13,7 +13,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { clients, db } from "@server/db"; import { clients, db, primaryDb, Client } from "@server/db";
import { userOrgRoles, userOrgs, roles } from "@server/db"; import { userOrgRoles, userOrgs, roles } from "@server/db";
import { eq, and, inArray } from "drizzle-orm"; import { eq, and, inArray } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -115,6 +115,7 @@ export async function setUserOrgRoles(
); );
} }
let orgClientsToRebuild: Client[] = [];
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx await trx
.delete(userOrgRoles) .delete(userOrgRoles)
@@ -142,11 +143,19 @@ export async function setUserOrgRoles(
and(eq(clients.userId, userId), eq(clients.orgId, orgId)) and(eq(clients.userId, userId), eq(clients.orgId, orgId))
); );
for (const orgClient of orgClients) { orgClientsToRebuild = orgClients;
await rebuildClientAssociationsFromClient(orgClient, trx);
}
}); });
for (const orgClient of orgClientsToRebuild) {
rebuildClientAssociationsFromClient(orgClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations for client ${orgClient.clientId} after setting roles: ${e}`
);
}
);
}
return response(res, { return response(res, {
data: { userId, orgId, roleIds: uniqueRoleIds }, data: { userId, orgId, roleIds: uniqueRoleIds },
success: true, success: true,

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, orgs, userOrgs, users } from "@server/db"; import { db, orgs, userOrgs, users, primaryDb } from "@server/db";
import { eq, and, inArray, not } from "drizzle-orm"; import { eq, and, inArray, not } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -218,13 +218,18 @@ export async function deleteMyAccount(
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx.delete(users).where(eq(users.userId, userId)); await trx.delete(users).where(eq(users.userId, userId));
await calculateUserClientsForOrgs(userId, trx);
// loop through the other orgs and decrement the count // loop through the other orgs and decrement the count
for (const userOrg of otherOrgsTheUserWasIn) { for (const userOrg of otherOrgsTheUserWasIn) {
await usageService.add(userOrg.orgId, FeatureId.USERS, -1, trx); await usageService.add(userOrg.orgId, FeatureId.USERS, -1, trx);
} }
}); });
calculateUserClientsForOrgs(userId, primaryDb).catch((e) => {
logger.error(
`Failed to calculate user clients after deleting account for user ${userId}: ${e}`
);
});
try { try {
await invalidateSession(session.sessionId); await invalidateSession(session.sessionId);
} catch (error) { } catch (error) {

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db } from "@server/db"; import { db, primaryDb } from "@server/db";
import { import {
roles, roles,
Client, Client,
@@ -92,7 +92,10 @@ export async function createClient(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
if (req.user && (!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)) { if (
req.user &&
(!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)
) {
return next( return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role") createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
); );
@@ -198,7 +201,10 @@ export async function createClient(
if (!randomExitNode) { if (!randomExitNode) {
return next( return next(
createHttpError(HttpCode.NOT_FOUND, `No exit nodes available. ${build == "saas" ? "Please contact support." : "You need to install gerbil to use the clients."}`) createHttpError(
HttpCode.NOT_FOUND,
`No exit nodes available. ${build == "saas" ? "Please contact support." : "You need to install gerbil to use the clients."}`
)
); );
} }
@@ -256,10 +262,18 @@ export async function createClient(
clientId: newClient.clientId, clientId: newClient.clientId,
dateCreated: moment().toISOString() dateCreated: moment().toISOString()
}); });
await rebuildClientAssociationsFromClient(newClient, trx);
}); });
if (newClient) {
rebuildClientAssociationsFromClient(newClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations after creating client: ${e}`
);
}
);
}
return response<CreateClientResponse>(res, { return response<CreateClientResponse>(res, {
data: newClient, data: newClient,
success: true, success: true,

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db } from "@server/db"; import { db, primaryDb } from "@server/db";
import { import {
roles, roles,
Client, Client,
@@ -237,10 +237,18 @@ export async function createUserClient(
userId, userId,
clientId: newClient.clientId clientId: newClient.clientId
}); });
await rebuildClientAssociationsFromClient(newClient, trx);
}); });
if (newClient) {
rebuildClientAssociationsFromClient(newClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations after creating user client: ${e}`
);
}
);
}
return response<CreateClientAndOlmResponse>(res, { return response<CreateClientAndOlmResponse>(res, {
data: newClient, data: newClient,
success: true, success: true,

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, olms } from "@server/db"; import { db, olms, primaryDb, Client, Olm } from "@server/db";
import { clients, clientSitesAssociationsCache } from "@server/db"; import { clients, clientSitesAssociationsCache } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -71,14 +71,17 @@ export async function deleteClient(
); );
} }
let deletedClient: Client | undefined;
let olm: Olm | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
// Then delete the client itself // Then delete the client itself
const [deletedClient] = await trx [deletedClient] = await trx
.delete(clients) .delete(clients)
.where(eq(clients.clientId, clientId)) .where(eq(clients.clientId, clientId))
.returning(); .returning();
const [olm] = await trx [olm] = await trx
.select() .select()
.from(olms) .from(olms)
.where(eq(olms.clientId, clientId)) .where(eq(olms.clientId, clientId))
@@ -88,14 +91,29 @@ export async function deleteClient(
if (!client.userId && client.olmId) { if (!client.userId && client.olmId) {
await trx.delete(olms).where(eq(olms.olmId, client.olmId)); await trx.delete(olms).where(eq(olms.olmId, client.olmId));
} }
await rebuildClientAssociationsFromClient(deletedClient, trx);
if (olm) {
await sendTerminateClient(deletedClient.clientId, OlmErrorCodes.TERMINATED_DELETED, olm.olmId); // the olmId needs to be provided because it cant look it up after deletion
}
}); });
if (deletedClient) {
rebuildClientAssociationsFromClient(deletedClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations after deleting client ${clientId}: ${e}`
);
}
);
if (olm) {
sendTerminateClient(
deletedClient.clientId,
OlmErrorCodes.TERMINATED_DELETED,
olm.olmId
).catch((e) => {
logger.error(
`Failed to send terminate message for client ${deletedClient?.clientId} after deleting client ${clientId}: ${e}`
);
});
}
}
return response(res, { return response(res, {
data: null, data: null,
success: true, success: true,

View File

@@ -1,5 +1,5 @@
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import { db, olms } from "@server/db"; import { db, olms, primaryDb } from "@server/db";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { z } from "zod"; import { z } from "zod";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
@@ -81,16 +81,19 @@ export async function createUserOlm(
const secretHash = await hashPassword(secret); const secretHash = await hashPassword(secret);
await db.transaction(async (trx) => { await db.insert(olms).values({
await trx.insert(olms).values({ olmId: olmId,
olmId: olmId, userId,
userId, name,
name, secretHash,
secretHash, dateCreated: moment().toISOString()
dateCreated: moment().toISOString() });
});
await calculateUserClientsForOrgs(userId, trx); calculateUserClientsForOrgs(userId, primaryDb).catch((e) => {
console.error(
"Error calculating user clients after creating olm:",
e
);
}); });
return response<CreateOlmResponse>(res, { return response<CreateOlmResponse>(res, {

View File

@@ -1,5 +1,5 @@
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import { Client, db } from "@server/db"; import { Client, db, Olm, primaryDb } from "@server/db";
import { olms, clients, clientSitesAssociationsCache } from "@server/db"; import { olms, clients, clientSitesAssociationsCache } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -49,6 +49,7 @@ export async function deleteUserOlm(
const { olmId } = parsedParams.data; const { olmId } = parsedParams.data;
let deletedClient: Client | undefined;
// Delete associated clients and the OLM in a transaction // Delete associated clients and the OLM in a transaction
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
// Find all clients associated with this OLM // Find all clients associated with this OLM
@@ -57,7 +58,6 @@ export async function deleteUserOlm(
.from(clients) .from(clients)
.where(eq(clients.olmId, olmId)); .where(eq(clients.olmId, olmId));
let deletedClient: Client | null = null;
// Delete all associated clients // Delete all associated clients
if (associatedClients.length > 0) { if (associatedClients.length > 0) {
[deletedClient] = await trx [deletedClient] = await trx
@@ -67,23 +67,28 @@ export async function deleteUserOlm(
} }
// Finally, delete the OLM itself // Finally, delete the OLM itself
const [olm] = await trx await trx.delete(olms).where(eq(olms.olmId, olmId)).returning();
.delete(olms)
.where(eq(olms.olmId, olmId))
.returning();
if (deletedClient) {
await rebuildClientAssociationsFromClient(deletedClient, trx);
if (olm) {
await sendTerminateClient(
deletedClient.clientId,
OlmErrorCodes.TERMINATED_DELETED,
olm.olmId
); // the olmId needs to be provided because it cant look it up after deletion
}
}
}); });
if (deletedClient) {
rebuildClientAssociationsFromClient(deletedClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client-site associations after deleting OLM ${olmId}: ${e}`
);
}
);
sendTerminateClient(
deletedClient.clientId,
OlmErrorCodes.TERMINATED_DELETED,
olmId
).catch((e) => {
logger.error(
`Failed to send terminate message for client ${deletedClient?.clientId} after deleting OLM ${olmId}: ${e}`
);
});
}
return response(res, { return response(res, {
data: null, data: null,
success: true, success: true,

View File

@@ -22,14 +22,14 @@ import { canCompress } from "@server/lib/clientVersionChecks";
import config from "@server/lib/config"; import config from "@server/lib/config";
export const handleOlmRegisterMessage: MessageHandler = async (context) => { export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!"); logger.info("[handleOlmRegisterMessage] Handling register olm message");
const { message, client: c, sendToClient } = context; const { message, client: c, sendToClient } = context;
const olm = c as Olm; const olm = c as Olm;
const now = Math.floor(Date.now() / 1000); const now = Math.floor(Date.now() / 1000);
if (!olm) { if (!olm) {
logger.warn("Olm not found"); logger.warn("[handleOlmRegisterMessage] Olm not found");
return; return;
} }
@@ -46,16 +46,19 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
} = message.data; } = message.data;
if (!olm.clientId) { if (!olm.clientId) {
logger.warn("Olm client ID not found"); logger.warn("[handleOlmRegisterMessage] Olm client ID not found");
sendOlmError(OlmErrorCodes.CLIENT_ID_NOT_FOUND, olm.olmId); sendOlmError(OlmErrorCodes.CLIENT_ID_NOT_FOUND, olm.olmId);
return; return;
} }
logger.debug("Handling fingerprint insertion for olm register...", { logger.debug(
olmId: olm.olmId, "[handleOlmRegisterMessage] Handling fingerprint insertion for olm register...",
fingerprint, {
postures olmId: olm.olmId,
}); fingerprint,
postures
}
);
const isUserDevice = olm.userId !== null && olm.userId !== undefined; const isUserDevice = olm.userId !== null && olm.userId !== undefined;
@@ -85,14 +88,17 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.limit(1); .limit(1);
if (!client) { if (!client) {
logger.warn("Client ID not found"); logger.warn("[handleOlmRegisterMessage] Client not found", {
clientId: olm.clientId
});
sendOlmError(OlmErrorCodes.CLIENT_NOT_FOUND, olm.olmId); sendOlmError(OlmErrorCodes.CLIENT_NOT_FOUND, olm.olmId);
return; return;
} }
if (client.blocked) { if (client.blocked) {
logger.debug( logger.debug(
`Client ${client.clientId} is blocked. Ignoring register.` `[handleOlmRegisterMessage] Client ${client.clientId} is blocked. Ignoring register.`,
{ orgId: client.orgId }
); );
sendOlmError(OlmErrorCodes.CLIENT_BLOCKED, olm.olmId); sendOlmError(OlmErrorCodes.CLIENT_BLOCKED, olm.olmId);
return; return;
@@ -100,7 +106,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (client.approvalState == "pending") { if (client.approvalState == "pending") {
logger.debug( logger.debug(
`Client ${client.clientId} approval is pending. Ignoring register.` `[handleOlmRegisterMessage] Client ${client.clientId} approval is pending. Ignoring register.`,
{ orgId: client.orgId }
); );
sendOlmError(OlmErrorCodes.CLIENT_PENDING, olm.olmId); sendOlmError(OlmErrorCodes.CLIENT_PENDING, olm.olmId);
return; return;
@@ -128,14 +135,18 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.limit(1); .limit(1);
if (!org) { if (!org) {
logger.warn("Org not found"); logger.warn("[handleOlmRegisterMessage] Org not found", {
orgId: client.orgId
});
sendOlmError(OlmErrorCodes.ORG_NOT_FOUND, olm.olmId); sendOlmError(OlmErrorCodes.ORG_NOT_FOUND, olm.olmId);
return; return;
} }
if (orgId) { if (orgId) {
if (!olm.userId) { if (!olm.userId) {
logger.warn("Olm has no user ID"); logger.warn("[handleOlmRegisterMessage] Olm has no user ID", {
orgId: client.orgId
});
sendOlmError(OlmErrorCodes.USER_ID_NOT_FOUND, olm.olmId); sendOlmError(OlmErrorCodes.USER_ID_NOT_FOUND, olm.olmId);
return; return;
} }
@@ -143,12 +154,18 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
const { session: userSession, user } = const { session: userSession, user } =
await validateSessionToken(userToken); await validateSessionToken(userToken);
if (!userSession || !user) { if (!userSession || !user) {
logger.warn("Invalid user session for olm register"); logger.warn(
"[handleOlmRegisterMessage] Invalid user session for olm register",
{ orgId: client.orgId }
);
sendOlmError(OlmErrorCodes.INVALID_USER_SESSION, olm.olmId); sendOlmError(OlmErrorCodes.INVALID_USER_SESSION, olm.olmId);
return; return;
} }
if (user.userId !== olm.userId) { if (user.userId !== olm.userId) {
logger.warn("User ID mismatch for olm register"); logger.warn(
"[handleOlmRegisterMessage] User ID mismatch for olm register",
{ orgId: client.orgId }
);
sendOlmError(OlmErrorCodes.USER_ID_MISMATCH, olm.olmId); sendOlmError(OlmErrorCodes.USER_ID_MISMATCH, olm.olmId);
return; return;
} }
@@ -163,11 +180,15 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
sessionId // this is the user token passed in the message sessionId // this is the user token passed in the message
}); });
logger.debug("Policy check result:", policyCheck); logger.debug("[handleOlmRegisterMessage] Policy check result", {
orgId: client.orgId,
policyCheck
});
if (policyCheck?.error) { if (policyCheck?.error) {
logger.error( logger.error(
`Error checking access policies for olm user ${olm.userId} in org ${orgId}: ${policyCheck?.error}` `[handleOlmRegisterMessage] Error checking access policies for olm user ${olm.userId} in org ${orgId}: ${policyCheck?.error}`,
{ orgId: client.orgId }
); );
sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId); sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId);
return; return;
@@ -175,7 +196,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (policyCheck.policies?.passwordAge?.compliant === false) { if (policyCheck.policies?.passwordAge?.compliant === false) {
logger.warn( logger.warn(
`Olm user ${olm.userId} has non-compliant password age for org ${orgId}` `[handleOlmRegisterMessage] Olm user ${olm.userId} has non-compliant password age for org ${orgId}`,
{ orgId: client.orgId }
); );
sendOlmError( sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_PASSWORD_EXPIRED, OlmErrorCodes.ORG_ACCESS_POLICY_PASSWORD_EXPIRED,
@@ -186,7 +208,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
policyCheck.policies?.maxSessionLength?.compliant === false policyCheck.policies?.maxSessionLength?.compliant === false
) { ) {
logger.warn( logger.warn(
`Olm user ${olm.userId} has non-compliant session length for org ${orgId}` `[handleOlmRegisterMessage] Olm user ${olm.userId} has non-compliant session length for org ${orgId}`,
{ orgId: client.orgId }
); );
sendOlmError( sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_SESSION_EXPIRED, OlmErrorCodes.ORG_ACCESS_POLICY_SESSION_EXPIRED,
@@ -195,7 +218,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return; return;
} else if (policyCheck.policies?.requiredTwoFactor === false) { } else if (policyCheck.policies?.requiredTwoFactor === false) {
logger.warn( logger.warn(
`Olm user ${olm.userId} does not have 2FA enabled for org ${orgId}` `[handleOlmRegisterMessage] Olm user ${olm.userId} does not have 2FA enabled for org ${orgId}`,
{ orgId: client.orgId }
); );
sendOlmError( sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_2FA_REQUIRED, OlmErrorCodes.ORG_ACCESS_POLICY_2FA_REQUIRED,
@@ -204,7 +228,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return; return;
} else if (!policyCheck.allowed) { } else if (!policyCheck.allowed) {
logger.warn( logger.warn(
`Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}` `[handleOlmRegisterMessage] Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}`,
{ orgId: client.orgId }
); );
sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId); sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId);
return; return;
@@ -226,29 +251,39 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0; sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations // Prepare an array to store site configurations
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`); logger.debug(
`[handleOlmRegisterMessage] Found ${sitesCount} sites for client ${client.clientId}`,
{ orgId: client.orgId }
);
let jitMode = false; let jitMode = false;
if (sitesCount > 250 && build == "saas") { if (sitesCount > 250 && build == "saas") {
// THIS IS THE MAX ON THE BUSINESS TIER // THIS IS THE MAX ON THE BUSINESS TIER
// we have too many sites // we have too many sites
// If we have too many sites we need to drop into fully JIT mode by not sending any of the sites // If we have too many sites we need to drop into fully JIT mode by not sending any of the sites
logger.info("Too many sites (%d), dropping into JIT mode", sitesCount); logger.info(
`[handleOlmRegisterMessage] Too many sites (${sitesCount}), dropping into JIT mode`,
{ orgId: client.orgId }
);
jitMode = true; jitMode = true;
} }
logger.debug( logger.debug(
`Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}` `[handleOlmRegisterMessage] Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`,
{ orgId: client.orgId }
); );
if (!publicKey) { if (!publicKey) {
logger.warn("Public key not provided"); logger.warn("[handleOlmRegisterMessage] Public key not provided", {
orgId: client.orgId
});
return; return;
} }
if (client.pubKey !== publicKey || client.archived) { if (client.pubKey !== publicKey || client.archived) {
logger.info( logger.info(
"Public key mismatch. Updating public key and clearing session info..." "[handleOlmRegisterMessage] Public key mismatch. Updating public key and clearing session info...",
{ orgId: client.orgId }
); );
// Update the client's public key // Update the client's public key
await db await db
@@ -274,12 +309,13 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
// TODO: I still think there is a better way to do this rather than locking it out here but ??? // TODO: I still think there is a better way to do this rather than locking it out here but ???
if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) { if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) {
logger.warn( logger.warn(
`Client last hole punch is too old and we have sites to send; skipping this register. The client is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?` `[handleOlmRegisterMessage] Client last hole punch is too old and we have sites to send; skipping this register. The client is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?`,
{ orgId: client.orgId }
); );
return; return;
} }
// NOTE: its important that the client here is the old client and the public key is the new key // NOTE: its important that the client here is the old client and the public key is the new key
const siteConfigurations = await buildSiteConfigurationForOlmClient( const siteConfigurations = await buildSiteConfigurationForOlmClient(
client, client,
publicKey, publicKey,

View File

@@ -5,7 +5,8 @@ import {
clients, clients,
clientSiteResources, clientSiteResources,
siteResources, siteResources,
apiKeyOrg apiKeyOrg,
primaryDb
} from "@server/db"; } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -220,8 +221,12 @@ export async function batchAddClientToSiteResources(
siteResourceId: siteResource.siteResourceId siteResourceId: siteResource.siteResourceId
}); });
} }
});
await rebuildClientAssociationsFromClient(client, trx); rebuildClientAssociationsFromClient(client, primaryDb).catch((e) => {
logger.error(
`Failed to rebuild client associations after batch adding site resources for client ${clientId}: ${e}`
);
}); });
return response(res, { return response(res, {

View File

@@ -10,7 +10,8 @@ import {
SiteResource, SiteResource,
siteResources, siteResources,
sites, sites,
userSiteResources userSiteResources,
primaryDb
} from "@server/db"; } from "@server/db";
import { getUniqueSiteResourceName } from "@server/db/names"; import { getUniqueSiteResourceName } from "@server/db/names";
import { import {
@@ -519,12 +520,10 @@ export async function createSiteResource(
// own transaction so it always executes on the primary — avoiding any // own transaction so it always executes on the primary — avoiding any
// replica-lag issues while still allowing the HTTP response to return // replica-lag issues while still allowing the HTTP response to return
// early. // early.
db.transaction(async (trx) => { rebuildClientAssociationsFromSiteResource(
await rebuildClientAssociationsFromSiteResource( newSiteResource!,
newSiteResource!, primaryDb
trx ).catch((err) => {
);
}).catch((err) => {
logger.error( logger.error(
`Error rebuilding client associations for site resource ${newSiteResource!.siteResourceId}:`, `Error rebuilding client associations for site resource ${newSiteResource!.siteResourceId}:`,
err err

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, newts, sites } from "@server/db"; import { db, newts, primaryDb, sites } from "@server/db";
import { siteResources } from "@server/db"; import { siteResources } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -73,12 +73,10 @@ export async function deleteSiteResource(
// own transaction so it always executes on the primary — avoiding any // own transaction so it always executes on the primary — avoiding any
// replica-lag issues while still allowing the HTTP response to return // replica-lag issues while still allowing the HTTP response to return
// early. // early.
db.transaction(async (trx) => { rebuildClientAssociationsFromSiteResource(
await rebuildClientAssociationsFromSiteResource( removedSiteResource,
removedSiteResource, primaryDb
trx ).catch((err) => {
);
}).catch((err) => {
logger.error( logger.error(
`Error rebuilding client associations for site resource ${removedSiteResource!.siteResourceId}:`, `Error rebuilding client associations for site resource ${removedSiteResource!.siteResourceId}:`,
err err

View File

@@ -1,7 +1,13 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, orgs } from "@server/db"; import { db, orgs, primaryDb } from "@server/db";
import { roles, userInviteRoles, userInvites, userOrgs, users } from "@server/db"; import {
roles,
userInviteRoles,
userInvites,
userOrgs,
users
} from "@server/db";
import { eq, and, inArray } from "drizzle-orm"; import { eq, and, inArray } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -146,9 +152,7 @@ export async function acceptInvite(
.from(userInviteRoles) .from(userInviteRoles)
.where(eq(userInviteRoles.inviteId, inviteId)); .where(eq(userInviteRoles.inviteId, inviteId));
const inviteRoleIds = [ const inviteRoleIds = [...new Set(inviteRoleRows.map((r) => r.roleId))];
...new Set(inviteRoleRows.map((r) => r.roleId))
];
if (inviteRoleIds.length === 0) { if (inviteRoleIds.length === 0) {
return next( return next(
createHttpError( createHttpError(
@@ -193,13 +197,19 @@ export async function acceptInvite(
.delete(userInvites) .delete(userInvites)
.where(eq(userInvites.inviteId, inviteId)); .where(eq(userInvites.inviteId, inviteId));
await calculateUserClientsForOrgs(existingUser[0].userId, trx);
logger.debug( logger.debug(
`User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}` `User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}`
); );
}); });
calculateUserClientsForOrgs(existingUser[0].userId, primaryDb).catch(
(e) => {
logger.error(
`Failed to calculate user clients after accepting invite for user ${existingUser[0].userId}: ${e}`
);
}
);
return response<AcceptInviteResponse>(res, { return response<AcceptInviteResponse>(res, {
data: { accepted: true, orgId: existingInvite.orgId }, data: { accepted: true, orgId: existingInvite.orgId },
success: true, success: true,

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import stoi from "@server/lib/stoi"; import stoi from "@server/lib/stoi";
import { clients, db } from "@server/db"; import { clients, db, primaryDb, Client } from "@server/db";
import { userOrgRoles, userOrgs, roles } from "@server/db"; import { userOrgRoles, userOrgs, roles } from "@server/db";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -112,6 +112,8 @@ export async function addUserRoleLegacy(
); );
} }
let orgClientsToRebuild: Client[] = [];
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx await trx
.delete(userOrgRoles) .delete(userOrgRoles)
@@ -138,11 +140,19 @@ export async function addUserRoleLegacy(
) )
); );
for (const orgClient of orgClients) { orgClientsToRebuild = orgClients;
await rebuildClientAssociationsFromClient(orgClient, trx);
}
}); });
for (const orgClient of orgClientsToRebuild) {
rebuildClientAssociationsFromClient(orgClient, primaryDb).catch(
(e) => {
logger.error(
`Failed to rebuild client associations for client ${orgClient.clientId} after adding role: ${e}`
);
}
);
}
return response(res, { return response(res, {
data: { ...existingUser, roleId }, data: { ...existingUser, roleId },
success: true, success: true,

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db } from "@server/db"; import { db, primaryDb } from "@server/db";
import { users } from "@server/db"; import { users } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -53,8 +53,12 @@ export async function adminRemoveUser(
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx.delete(users).where(eq(users.userId, userId)); await trx.delete(users).where(eq(users.userId, userId));
});
await calculateUserClientsForOrgs(userId, trx); calculateUserClientsForOrgs(userId, primaryDb).catch((e) => {
logger.error(
`Failed to calculate user clients after removing user ${userId}: ${e}`
);
}); });
return response(res, { return response(res, {

View File

@@ -6,7 +6,7 @@ import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { db, orgs } from "@server/db"; import { db, orgs, primaryDb } from "@server/db";
import { and, eq, inArray } from "drizzle-orm"; import { and, eq, inArray } from "drizzle-orm";
import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db"; import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db";
import { generateId } from "@server/auth/sessions/app"; import { generateId } from "@server/auth/sessions/app";
@@ -34,8 +34,7 @@ const bodySchema = z
roleId: z.number().int().positive().optional() roleId: z.number().int().positive().optional()
}) })
.refine( .refine(
(d) => (d) => (d.roleIds != null && d.roleIds.length > 0) || d.roleId != null,
(d.roleIds != null && d.roleIds.length > 0) || d.roleId != null,
{ message: "roleIds or roleId is required", path: ["roleIds"] } { message: "roleIds or roleId is required", path: ["roleIds"] }
) )
.transform((data) => ({ .transform((data) => ({
@@ -100,8 +99,14 @@ export async function createOrgUser(
} }
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
const { username, email, name, type, idpId, roleIds: uniqueRoleIds } = const {
parsedBody.data; username,
email,
name,
type,
idpId,
roleIds: uniqueRoleIds
} = parsedBody.data;
if (build == "saas") { if (build == "saas") {
const usage = await usageService.getUsage(orgId, FeatureId.USERS); const usage = await usageService.getUsage(orgId, FeatureId.USERS);
@@ -232,6 +237,7 @@ export async function createOrgUser(
); );
} }
let userIdForClients: string | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
const [existingUser] = await trx const [existingUser] = await trx
.select() .select()
@@ -270,7 +276,7 @@ export async function createOrgUser(
{ {
orgId, orgId,
userId: existingUser.userId, userId: existingUser.userId,
autoProvisioned: false, autoProvisioned: false
}, },
uniqueRoleIds, uniqueRoleIds,
trx trx
@@ -292,20 +298,30 @@ export async function createOrgUser(
}) })
.returning(); .returning();
await assignUserToOrg( await assignUserToOrg(
org, org,
{ {
orgId, orgId,
userId: newUser.userId, userId: newUser.userId,
autoProvisioned: false, autoProvisioned: false
}, },
uniqueRoleIds, uniqueRoleIds,
trx trx
); );
} }
await calculateUserClientsForOrgs(userId, trx); userIdForClients = userId;
}); });
if (userIdForClients) {
calculateUserClientsForOrgs(userIdForClients, primaryDb).catch(
(e) => {
logger.error(
`Failed to calculate user clients after creating org user: ${e}`
);
}
);
}
} else { } else {
return next( return next(
createHttpError(HttpCode.BAD_REQUEST, "User type is required") createHttpError(HttpCode.BAD_REQUEST, "User type is required")

View File

@@ -7,7 +7,8 @@ import {
siteResources, siteResources,
sites, sites,
UserOrg, UserOrg,
userSiteResources userSiteResources,
primaryDb
} from "@server/db"; } from "@server/db";
import { userOrgs, userResources, users, userSites } from "@server/db"; import { userOrgs, userResources, users, userSites } from "@server/db";
import { and, count, eq, exists, inArray } from "drizzle-orm"; import { and, count, eq, exists, inArray } from "drizzle-orm";
@@ -91,25 +92,12 @@ export async function removeUserOrg(
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await removeUserFromOrg(org, userId, trx); await removeUserFromOrg(org, userId, trx);
});
// if (build === "saas") { calculateUserClientsForOrgs(userId, primaryDb).catch((e) => {
// const [rootUser] = await trx logger.error(
// .select() `Failed to calculate user clients after removing user ${userId} from org ${orgId}: ${e}`
// .from(users) );
// .where(eq(users.userId, userId));
//
// const [leftInOrgs] = await trx
// .select({ count: count() })
// .from(userOrgs)
// .where(eq(userOrgs.userId, userId));
//
// // if the user is not an internal user and does not belong to any org, delete the entire user
// if (rootUser?.type !== UserType.Internal && !leftInOrgs.count) {
// await trx.delete(users).where(eq(users.userId, userId));
// }
// }
await calculateUserClientsForOrgs(userId, trx);
}); });
return response(res, { return response(res, {

View File

@@ -44,7 +44,7 @@ export default async function migration() {
await db.execute(sql`BEGIN`); await db.execute(sql`BEGIN`);
await db.execute(sql` await db.execute(sql`
CREATE TABLE "trialNotifications" ( CREATE TABLE IF NOT EXISTS "trialNotifications" (
"notificationId" serial PRIMARY KEY NOT NULL, "notificationId" serial PRIMARY KEY NOT NULL,
"subscriptionId" varchar(255) NOT NULL, "subscriptionId" varchar(255) NOT NULL,
"notificationType" varchar(50) NOT NULL, "notificationType" varchar(50) NOT NULL,
@@ -52,10 +52,6 @@ export default async function migration() {
); );
`); `);
await db.execute(sql`
ALTER TABLE "trialNotifications" ADD CONSTRAINT "trialNotifications_subscriptionId_subscriptions_subscriptionId_fk" FOREIGN KEY ("subscriptionId") REFERENCES "public"."subscriptions"("subscriptionId") ON DELETE cascade ON UPDATE no action;
`);
await db.execute(sql`COMMIT`); await db.execute(sql`COMMIT`);
console.log("Migrated database"); console.log("Migrated database");
} catch (e) { } catch (e) {

View File

@@ -16,7 +16,7 @@ export default async function migration() {
db.transaction(() => { db.transaction(() => {
db.prepare( db.prepare(
` `
CREATE TABLE 'trialNotifications' ( CREATE TABLE IF NOT EXISTS 'trialNotifications' (
'notificationId' integer PRIMARY KEY AUTOINCREMENT NOT NULL, 'notificationId' integer PRIMARY KEY AUTOINCREMENT NOT NULL,
'subscriptionId' text NOT NULL, 'subscriptionId' text NOT NULL,
'notificationType' text NOT NULL, 'notificationType' text NOT NULL,

View File

@@ -55,7 +55,9 @@ export default async function ProxyResourcesPage(
pagination = responseData.pagination; pagination = responseData.pagination;
} catch (e) {} } catch (e) {}
const siteIdParam = parsePositiveInt(searchParams.get("siteId") ?? undefined); const siteIdParam = parsePositiveInt(
searchParams.get("siteId") ?? undefined
);
let initialFilterSite: { let initialFilterSite: {
siteId: number; siteId: number;
@@ -122,6 +124,7 @@ export default async function ProxyResourcesPage(
domainId: resource.domainId || undefined, domainId: resource.domainId || undefined,
fullDomain: resource.fullDomain ?? null, fullDomain: resource.fullDomain ?? null,
ssl: resource.ssl, ssl: resource.ssl,
wildcard: resource.wildcard,
targets: resource.targets?.map((target) => ({ targets: resource.targets?.map((target) => ({
targetId: target.targetId, targetId: target.targetId,
ip: target.ip, ip: target.ip,

View File

@@ -96,6 +96,7 @@ export type ResourceRow = {
targets?: TargetHealth[]; targets?: TargetHealth[];
health?: "healthy" | "degraded" | "unhealthy" | "unknown"; health?: "healthy" | "degraded" | "unhealthy" | "unknown";
sites: ResourceSiteRow[]; sites: ResourceSiteRow[];
wildcard?: boolean;
}; };
function StatusIcon({ function StatusIcon({
@@ -570,10 +571,14 @@ export default function ProxyResourcesTable({
/> />
) : null} ) : null}
<div className=""> <div className="">
<CopyToClipboard {!resourceRow.wildcard ? (
text={resourceRow.domain} <CopyToClipboard
isLink={true} text={resourceRow.domain}
/> isLink={true}
/>
) : (
<span>{resourceRow.domain}</span>
)}
</div> </div>
</div> </div>
); );