Compare commits

..

5 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
7c15c428b3 test: add normalized ASN validation coverage 2026-06-16 23:48:28 +00:00
copilot-swe-agent[bot]
f3a52e31d1 refactor: normalize ASN validation value once 2026-06-16 23:46:44 +00:00
copilot-swe-agent[bot]
5e26ceaf02 fix: allow ALL ASN values in policy rule validation 2026-06-16 23:44:35 +00:00
copilot-swe-agent[bot]
d6fe357fcb Initial plan 2026-06-16 23:39:56 +00:00
Owen
f9cc52ece9 Remove NoNewPrivileges
Fixes https://github.com/fosrl/newt/issues/383
2026-06-14 15:02:18 -07:00
27 changed files with 809 additions and 2010 deletions

View File

@@ -1,42 +1,52 @@
version: 2 version: 2
updates: updates:
- package-ecosystem: "npm" - package-ecosystem: "npm"
directory: "/" directory: "/"
schedule: schedule:
interval: "daily" interval: "daily"
open-pull-requests-limit: 1
groups: groups:
npm-dependencies: dev-patch-updates:
patterns: dependency-type: "development"
- "*" update-types:
- "patch"
dev-minor-updates:
dependency-type: "development"
update-types:
- "minor"
prod-patch-updates:
dependency-type: "production"
update-types:
- "patch"
prod-minor-updates:
dependency-type: "production"
update-types:
- "minor"
- package-ecosystem: "docker" - package-ecosystem: "docker"
directory: "/" directory: "/"
schedule: schedule:
interval: "daily" interval: "daily"
open-pull-requests-limit: 1
groups: groups:
docker-dependencies: patch-updates:
patterns: update-types:
- "*" - "patch"
minor-updates:
update-types:
- "minor"
- package-ecosystem: "github-actions" - package-ecosystem: "github-actions"
directory: "/" directory: "/"
schedule: schedule:
interval: "weekly" interval: "weekly"
open-pull-requests-limit: 1
groups:
github-actions-dependencies:
patterns:
- "*"
- package-ecosystem: "gomod" - package-ecosystem: "gomod"
directory: "/install" directory: "/install"
schedule: schedule:
interval: "daily" interval: "daily"
open-pull-requests-limit: 1
groups: groups:
go-install-dependencies: patch-updates:
patterns: update-types:
- "*" - "patch"
minor-updates:
update-types:
- "minor"

View File

@@ -11,7 +11,7 @@ import {
primaryKey, primaryKey,
uniqueIndex uniqueIndex
} from "drizzle-orm/pg-core"; } from "drizzle-orm/pg-core";
import { InferSelectModel, sql } from "drizzle-orm"; import { InferSelectModel } from "drizzle-orm";
import { import {
domains, domains,
orgs, orgs,
@@ -207,28 +207,17 @@ export const remoteExitNodeSessions = pgTable("remoteExitNodeSession", {
expiresAt: bigint("expiresAt", { mode: "number" }).notNull() expiresAt: bigint("expiresAt", { mode: "number" }).notNull()
}); });
export const loginPage = pgTable( export const loginPage = pgTable("loginPage", {
"loginPage",
{
loginPageId: serial("loginPageId").primaryKey(), loginPageId: serial("loginPageId").primaryKey(),
subdomain: varchar("subdomain"), subdomain: varchar("subdomain"),
fullDomain: varchar("fullDomain"), fullDomain: varchar("fullDomain"),
exitNodeId: integer("exitNodeId").references( exitNodeId: integer("exitNodeId").references(() => exitNodes.exitNodeId, {
() => exitNodes.exitNodeId,
{
onDelete: "set null" onDelete: "set null"
} }),
),
domainId: varchar("domainId").references(() => domains.domainId, { domainId: varchar("domainId").references(() => domains.domainId, {
onDelete: "set null" onDelete: "set null"
}) })
}, });
(t) => [
index("idx_loginpage_fulldomain")
.on(t.fullDomain)
.where(sql`${t.fullDomain} IS NOT NULL`)
]
);
export const loginPageOrg = pgTable("loginPageOrg", { export const loginPageOrg = pgTable("loginPageOrg", {
loginPageId: integer("loginPageId") loginPageId: integer("loginPageId")

View File

@@ -1,5 +1,5 @@
import { randomUUID } from "crypto"; import { randomUUID } from "crypto";
import { InferSelectModel, sql } from "drizzle-orm"; import { InferSelectModel } from "drizzle-orm";
import { import {
bigint, bigint,
boolean, boolean,
@@ -82,9 +82,7 @@ export const orgDomains = pgTable("orgDomains", {
.references(() => domains.domainId, { onDelete: "cascade" }) .references(() => domains.domainId, { onDelete: "cascade" })
}); });
export const sites = pgTable( export const sites = pgTable("sites", {
"sites",
{
siteId: serial("siteId").primaryKey(), siteId: serial("siteId").primaryKey(),
orgId: varchar("orgId") orgId: varchar("orgId")
.references(() => orgs.orgId, { .references(() => orgs.orgId, {
@@ -109,32 +107,17 @@ export const sites = pgTable(
publicKey: varchar("publicKey"), publicKey: varchar("publicKey"),
lastHolePunch: bigint("lastHolePunch", { mode: "number" }), lastHolePunch: bigint("lastHolePunch", { mode: "number" }),
listenPort: integer("listenPort"), listenPort: integer("listenPort"),
dockerSocketEnabled: boolean("dockerSocketEnabled") dockerSocketEnabled: boolean("dockerSocketEnabled").notNull().default(true),
.notNull() autoUpdateEnabled: boolean("autoUpdateEnabled").notNull().default(false),
.default(true),
autoUpdateEnabled: boolean("autoUpdateEnabled")
.notNull()
.default(false),
autoUpdateOverrideOrg: boolean("autoUpdateOverrideOrg") autoUpdateOverrideOrg: boolean("autoUpdateOverrideOrg")
.notNull() .notNull()
.default(false), .default(false),
status: varchar("status") status: varchar("status")
.$type<"pending" | "approved">() .$type<"pending" | "approved">()
.default("approved") .default("approved")
}, });
(t) => [
index("idx_sites_exitnodeid").on(t.exitNodeId),
index("idx_sites_exitnode_type_siteid").on(
t.exitNodeId,
t.type,
t.siteId
)
]
);
export const resources = pgTable( export const resources = pgTable("resources", {
"resources",
{
resourceId: serial("resourceId").primaryKey(), resourceId: serial("resourceId").primaryKey(),
resourcePolicyId: integer("resourcePolicyId").references( resourcePolicyId: integer("resourcePolicyId").references(
() => resourcePolicies.resourcePolicyId, () => resourcePolicies.resourcePolicyId,
@@ -199,13 +182,7 @@ export const resources = pgTable(
.$type<"site" | "remote" | "native">() .$type<"site" | "remote" | "native">()
.default("site"), .default("site"),
authDaemonPort: integer("authDaemonPort").default(22123) authDaemonPort: integer("authDaemonPort").default(22123)
}, });
(t) => [
index("idx_resources_fulldomain")
.on(t.fullDomain)
.where(sql`${t.fullDomain} IS NOT NULL`)
]
);
export const labels = pgTable("labels", { export const labels = pgTable("labels", {
labelId: serial("labelId").primaryKey(), labelId: serial("labelId").primaryKey(),
@@ -290,9 +267,7 @@ export const clientLabels = pgTable(
(t) => [unique("client_label_uniq").on(t.clientId, t.labelId)] (t) => [unique("client_label_uniq").on(t.clientId, t.labelId)]
); );
export const targets = pgTable( export const targets = pgTable("targets", {
"targets",
{
targetId: serial("targetId").primaryKey(), targetId: serial("targetId").primaryKey(),
resourceId: integer("resourceId") resourceId: integer("resourceId")
.references(() => resources.resourceId, { .references(() => resources.resourceId, {
@@ -319,18 +294,9 @@ export const targets = pgTable(
.notNull() .notNull()
.default("http"), .default("http"),
authToken: varchar("authToken") authToken: varchar("authToken")
}, });
(t) => [
index("idx_targets_resourceid_siteid").on(t.resourceId, t.siteId),
index("idx_targets_site_enabled_priority_target_resource")
.on(t.siteId, t.priority.desc(), t.targetId, t.resourceId)
.where(sql`${t.enabled} = true`)
]
);
export const targetHealthCheck = pgTable( export const targetHealthCheck = pgTable("targetHealthCheck", {
"targetHealthCheck",
{
targetHealthCheckId: serial("targetHealthCheckId").primaryKey(), targetHealthCheckId: serial("targetHealthCheckId").primaryKey(),
targetId: integer("targetId").references(() => targets.targetId, { targetId: integer("targetId").references(() => targets.targetId, {
onDelete: "cascade" onDelete: "cascade"
@@ -365,9 +331,7 @@ export const targetHealthCheck = pgTable(
hcTlsServerName: text("hcTlsServerName"), hcTlsServerName: text("hcTlsServerName"),
hcHealthyThreshold: integer("hcHealthyThreshold").default(1), hcHealthyThreshold: integer("hcHealthyThreshold").default(1),
hcUnhealthyThreshold: integer("hcUnhealthyThreshold").default(1) hcUnhealthyThreshold: integer("hcUnhealthyThreshold").default(1)
}, });
(t) => [index("idx_targethealthcheck_targetid").on(t.targetId)]
);
export const exitNodes = pgTable("exitNodes", { export const exitNodes = pgTable("exitNodes", {
exitNodeId: serial("exitNodeId").primaryKey(), exitNodeId: serial("exitNodeId").primaryKey(),
@@ -442,9 +406,7 @@ export const networks = pgTable("networks", {
.notNull() .notNull()
}); });
export const siteNetworks = pgTable( export const siteNetworks = pgTable("siteNetworks", {
"siteNetworks",
{
siteId: integer("siteId") siteId: integer("siteId")
.notNull() .notNull()
.references(() => sites.siteId, { .references(() => sites.siteId, {
@@ -453,63 +415,34 @@ export const siteNetworks = pgTable(
networkId: integer("networkId") networkId: integer("networkId")
.notNull() .notNull()
.references(() => networks.networkId, { onDelete: "cascade" }) .references(() => networks.networkId, { onDelete: "cascade" })
}, });
(t) => [
index("idx_sitenetworks_siteid").on(t.siteId),
index("idx_sitenetworks_networkid").on(t.networkId)
]
);
export const clientSiteResources = pgTable( export const clientSiteResources = pgTable("clientSiteResources", {
"clientSiteResources",
{
clientId: integer("clientId") clientId: integer("clientId")
.notNull() .notNull()
.references(() => clients.clientId, { onDelete: "cascade" }), .references(() => clients.clientId, { onDelete: "cascade" }),
siteResourceId: integer("siteResourceId") siteResourceId: integer("siteResourceId")
.notNull() .notNull()
.references(() => siteResources.siteResourceId, { .references(() => siteResources.siteResourceId, { onDelete: "cascade" })
onDelete: "cascade" });
})
},
(t) => [
index("idx_clientsiteresources_clientid").on(t.clientId),
index("idx_clientsiteresources_siteresourceid").on(t.siteResourceId)
]
);
export const roleSiteResources = pgTable( export const roleSiteResources = pgTable("roleSiteResources", {
"roleSiteResources",
{
roleId: integer("roleId") roleId: integer("roleId")
.notNull() .notNull()
.references(() => roles.roleId, { onDelete: "cascade" }), .references(() => roles.roleId, { onDelete: "cascade" }),
siteResourceId: integer("siteResourceId") siteResourceId: integer("siteResourceId")
.notNull() .notNull()
.references(() => siteResources.siteResourceId, { .references(() => siteResources.siteResourceId, { onDelete: "cascade" })
onDelete: "cascade" });
})
},
(t) => [index("idx_rolesiteresources_siteresourceid").on(t.siteResourceId)]
);
export const userSiteResources = pgTable( export const userSiteResources = pgTable("userSiteResources", {
"userSiteResources",
{
userId: varchar("userId") userId: varchar("userId")
.notNull() .notNull()
.references(() => users.userId, { onDelete: "cascade" }), .references(() => users.userId, { onDelete: "cascade" }),
siteResourceId: integer("siteResourceId") siteResourceId: integer("siteResourceId")
.notNull() .notNull()
.references(() => siteResources.siteResourceId, { .references(() => siteResources.siteResourceId, { onDelete: "cascade" })
onDelete: "cascade" });
})
},
(t) => [
index("idx_usersiteresources_userid").on(t.userId),
index("idx_usersiteresources_siteresourceid").on(t.siteResourceId)
]
);
export const users = pgTable("user", { export const users = pgTable("user", {
userId: varchar("id").primaryKey(), userId: varchar("id").primaryKey(),
@@ -534,9 +467,7 @@ export const users = pgTable("user", {
locale: varchar("locale") locale: varchar("locale")
}); });
export const newts = pgTable( export const newts = pgTable("newt", {
"newt",
{
newtId: varchar("id").primaryKey(), newtId: varchar("id").primaryKey(),
secretHash: varchar("secretHash").notNull(), secretHash: varchar("secretHash").notNull(),
dateCreated: varchar("dateCreated").notNull(), dateCreated: varchar("dateCreated").notNull(),
@@ -544,9 +475,7 @@ export const newts = pgTable(
siteId: integer("siteId").references(() => sites.siteId, { siteId: integer("siteId").references(() => sites.siteId, {
onDelete: "cascade" onDelete: "cascade"
}) })
}, });
(t) => [index("idx_newt_siteid").on(t.siteId)]
);
export const twoFactorBackupCodes = pgTable("twoFactorBackupCodes", { export const twoFactorBackupCodes = pgTable("twoFactorBackupCodes", {
codeId: serial("id").primaryKey(), codeId: serial("id").primaryKey(),
@@ -647,9 +576,7 @@ export const userOrgRoles = pgTable(
(t) => [unique().on(t.userId, t.orgId, t.roleId)] (t) => [unique().on(t.userId, t.orgId, t.roleId)]
); );
export const roleActions = pgTable( export const roleActions = pgTable("roleActions", {
"roleActions",
{
roleId: integer("roleId") roleId: integer("roleId")
.notNull() .notNull()
.references(() => roles.roleId, { onDelete: "cascade" }), .references(() => roles.roleId, { onDelete: "cascade" }),
@@ -659,19 +586,9 @@ export const roleActions = pgTable(
orgId: varchar("orgId") orgId: varchar("orgId")
.notNull() .notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }) .references(() => orgs.orgId, { onDelete: "cascade" })
}, });
(t) => [
index("idx_roleActions_roleId_orgId_actionId").on(
t.roleId,
t.orgId,
t.actionId
)
]
);
export const userActions = pgTable( export const userActions = pgTable("userActions", {
"userActions",
{
userId: varchar("userId") userId: varchar("userId")
.notNull() .notNull()
.references(() => users.userId, { onDelete: "cascade" }), .references(() => users.userId, { onDelete: "cascade" }),
@@ -681,15 +598,7 @@ export const userActions = pgTable(
orgId: varchar("orgId") orgId: varchar("orgId")
.notNull() .notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }) .references(() => orgs.orgId, { onDelete: "cascade" })
}, });
(t) => [
index("idx_userActions_userId_orgId_actionId").on(
t.userId,
t.orgId,
t.actionId
)
]
);
export const roleSites = pgTable("roleSites", { export const roleSites = pgTable("roleSites", {
roleId: integer("roleId") roleId: integer("roleId")
@@ -1095,9 +1004,7 @@ export const idpOrg = pgTable("idpOrg", {
orgMapping: varchar("orgMapping") orgMapping: varchar("orgMapping")
}); });
export const clients = pgTable( export const clients = pgTable("clients", {
"clients",
{
clientId: serial("clientId").primaryKey(), clientId: serial("clientId").primaryKey(),
orgId: varchar("orgId") orgId: varchar("orgId")
.references(() => orgs.orgId, { .references(() => orgs.orgId, {
@@ -1130,9 +1037,7 @@ export const clients = pgTable(
approvalState: varchar("approvalState").$type< approvalState: varchar("approvalState").$type<
"pending" | "approved" | "denied" "pending" | "approved" | "denied"
>() >()
}, });
(t) => [index("idx_clients_userid").on(t.userId)]
);
export const clientSitesAssociationsCache = pgTable( export const clientSitesAssociationsCache = pgTable(
"clientSitesAssociationsCache", "clientSitesAssociationsCache",
@@ -1144,11 +1049,7 @@ export const clientSitesAssociationsCache = pgTable(
isJitMode: boolean("isJitMode").notNull().default(false), isJitMode: boolean("isJitMode").notNull().default(false),
endpoint: varchar("endpoint"), endpoint: varchar("endpoint"),
publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
}, }
(t) => [
primaryKey({ columns: [t.clientId, t.siteId] }),
index("idx_clientsitesassociationscache_siteid").on(t.siteId)
]
); );
export const clientSiteResourcesAssociationsCache = pgTable( export const clientSiteResourcesAssociationsCache = pgTable(
@@ -1157,14 +1058,7 @@ export const clientSiteResourcesAssociationsCache = pgTable(
clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message
.notNull(), .notNull(),
siteResourceId: integer("siteResourceId").notNull() siteResourceId: integer("siteResourceId").notNull()
}, }
(t) => [
primaryKey({ columns: [t.clientId, t.siteResourceId] }),
index("idx_clientSiteResourcesAssociationsCache_siteResourceId").on(
t.siteResourceId,
t.clientId
)
]
); );
export const clientPostureSnapshots = pgTable("clientPostureSnapshots", { export const clientPostureSnapshots = pgTable("clientPostureSnapshots", {
@@ -1177,9 +1071,7 @@ export const clientPostureSnapshots = pgTable("clientPostureSnapshots", {
collectedAt: integer("collectedAt").notNull() collectedAt: integer("collectedAt").notNull()
}); });
export const olms = pgTable( export const olms = pgTable("olms", {
"olms",
{
olmId: varchar("id").primaryKey(), olmId: varchar("id").primaryKey(),
secretHash: varchar("secretHash").notNull(), secretHash: varchar("secretHash").notNull(),
dateCreated: varchar("dateCreated").notNull(), dateCreated: varchar("dateCreated").notNull(),
@@ -1195,9 +1087,7 @@ export const olms = pgTable(
onDelete: "cascade" onDelete: "cascade"
}), }),
archived: boolean("archived").notNull().default(false) archived: boolean("archived").notNull().default(false)
}, });
(t) => [index("idx_olms_clientid").on(t.clientId)]
);
export const currentFingerprint = pgTable("currentFingerprint", { export const currentFingerprint = pgTable("currentFingerprint", {
fingerprintId: serial("id").primaryKey(), fingerprintId: serial("id").primaryKey(),

View File

@@ -1,5 +1,6 @@
import { drizzle as DrizzleSqlite } from "drizzle-orm/better-sqlite3"; import { drizzle as DrizzleSqlite } from "drizzle-orm/better-sqlite3";
import Database from "better-sqlite3"; import Database from "better-sqlite3";
import type BetterSqlite3 from "better-sqlite3";
import * as schema from "./schema/schema"; import * as schema from "./schema/schema";
import path from "path"; import path from "path";
import fs from "fs"; import fs from "fs";
@@ -11,31 +12,68 @@ export const exists = checkFileExists(location);
bootstrapVolume(); bootstrapVolume();
/**
* Wraps better-sqlite3 Statement to call `finalize()` immediately after
* execution, freeing native sqlite3_stmt memory deterministically instead
* of waiting for GC. Fixes steady off-heap growth under load (#2120).
* WARNING: Finalizes after first execution — incompatible with drizzle's
* reusable .prepare() builders. No such usage exists in this codebase.
*/
function autoFinalizeStatement(
stmt: BetterSqlite3.Statement
): BetterSqlite3.Statement {
const wrapExec = <T extends (...args: any[]) => any>(fn: T): T => {
return function (this: any, ...args: any[]) {
try {
return fn.apply(this, args);
} finally {
try {
// finalize() exists on the native Statement at runtime but
// is missing from @types/better-sqlite3.
(stmt as any).finalize();
} catch {
// Already finalized — harmless
}
}
} as unknown as T;
};
stmt.run = wrapExec(stmt.run);
stmt.get = wrapExec(stmt.get);
stmt.all = wrapExec(stmt.all);
return stmt;
}
function createDb() { function createDb() {
const sqlite = new Database(location); const sqlite = new Database(location);
if (process.env.ENABLE_SQLITE_WAL_MODE == "true") { if (process.env.ENABLE_SQLITE_WAL_MODE == "true") {
// Enable WAL mode — allows concurrent readers + single writer, preventing // Enable WAL mode — allows concurrent readers + single writer, preventing
// contention across subsystems (verifySession, Traefik, audit, ping). // contention across subsystems (verifySession, Traefik, audit, ping).
// NOTE: journal_mode persists in the DB file once set; unsetting this
// env var does NOT revert an existing WAL database.
sqlite.pragma("journal_mode = WAL"); sqlite.pragma("journal_mode = WAL");
// NORMAL sync mode: safe with WAL, reduces write lock hold time. // NORMAL sync mode: safe with WAL, reduces write lock hold time.
sqlite.pragma("synchronous = NORMAL"); sqlite.pragma("synchronous = NORMAL");
} }
// No busy_timeout pragma: better-sqlite3 already arms // Wait up to 5s on SQLITE_BUSY instead of failing — prevents audit log
// sqlite3_busy_timeout(db, 5000) via its default `timeout` option // retry loops that accumulate memory.
// (lib/database.js), so an explicit pragma is redundant. sqlite.pragma("busy_timeout = 5000");
// Intentionally NOT setting cache_size or mmap_size: a large page cache plus // 64 MB page cache (default 2 MB) — reduces I/O round-trips on large
// a multi-hundred-MB mmap region inflate RSS and cause page-cache thrashing // TraefikConfigManager JOINs that block the event loop.
// on small (~1 GB) instances. Leave SQLite on its conservative defaults. sqlite.pragma("cache_size = -65536");
// Intentionally NOT wrapping prepare()/statements: better-sqlite3 finalizes // 256 MB memory-mapped I/O — OS serves reads from page cache directly,
// sqlite3_stmt in the Statement destructor at GC, and drizzle-orm prepares a // reducing event-loop blocking.
// fresh statement per query (no statement cache), so statements cannot sqlite.pragma("mmap_size = 268435456");
// accumulate. better-sqlite3 11.x exposes no Statement.finalize() at all.
// Wrap prepare() so every drizzle-orm statement is auto-finalized after
// first use, preventing sqlite3_stmt accumulation between GC cycles.
const originalPrepare = sqlite.prepare.bind(sqlite);
(sqlite as any).prepare = function autoFinalizePrepare(source: string) {
return autoFinalizeStatement(originalPrepare(source));
};
return DrizzleSqlite(sqlite, { return DrizzleSqlite(sqlite, {
schema schema

View File

@@ -24,7 +24,6 @@ import license from "#dynamic/license/license";
import { initLogCleanupInterval } from "@server/lib/cleanupLogs"; import { initLogCleanupInterval } from "@server/lib/cleanupLogs";
import { initAcmeCertSync } from "#dynamic/lib/acmeCertSync"; import { initAcmeCertSync } from "#dynamic/lib/acmeCertSync";
import { fetchServerIp } from "@server/lib/serverIpService"; import { fetchServerIp } from "@server/lib/serverIpService";
import { startRebuildQueueProcessor } from "@server/lib/rebuildClientAssociations";
async function startServers() { async function startServers() {
await setHostMeta(); await setHostMeta();
@@ -42,7 +41,6 @@ async function startServers() {
initLogCleanupInterval(); initLogCleanupInterval();
initAcmeCertSync(); initAcmeCertSync();
startRebuildQueueProcessor();
// Start all servers // Start all servers
const apiServer = createApiServer(); const apiServer = createApiServer();

View File

@@ -12,7 +12,7 @@ import {
import { FeatureId, getFeatureMeterId } from "./features"; import { FeatureId, getFeatureMeterId } from "./features";
import logger from "@server/logger"; import logger from "@server/logger";
import { build } from "@server/build"; import { build } from "@server/build";
import { regionalCache as cache } from "#dynamic/lib/cache"; import cache from "#dynamic/lib/cache";
export function noop() { export function noop() {
if (build !== "saas") { if (build !== "saas") {
@@ -22,6 +22,7 @@ export function noop() {
} }
export class UsageService { export class UsageService {
constructor() { constructor() {
if (noop()) { if (noop()) {
return; return;
@@ -56,10 +57,7 @@ export class UsageService {
try { try {
let usage; let usage;
if (transaction) { if (transaction) {
const orgIdToUse = await this.getBillingOrg( const orgIdToUse = await this.getBillingOrg(orgId, transaction);
orgId,
transaction
);
usage = await this.internalAddUsage( usage = await this.internalAddUsage(
orgIdToUse, orgIdToUse,
featureId, featureId,

View File

@@ -48,18 +48,18 @@ export async function applyBlueprint({
name, name,
source = "API" source = "API"
}: ApplyBlueprintArgs): Promise<Blueprint> { }: ApplyBlueprintArgs): Promise<Blueprint> {
let blueprintSucceeded: boolean = false; // Validate the input data
let blueprintMessage = "";
let error: any | null = null;
try {
const validationResult = ConfigSchema.safeParse(configData); const validationResult = ConfigSchema.safeParse(configData);
if (!validationResult.success) { if (!validationResult.success) {
throw new Error(fromError(validationResult.error).toString()); throw new Error(fromError(validationResult.error).toString());
} }
const config: Config = validationResult.data; const config: Config = validationResult.data;
let blueprintSucceeded: boolean = false;
let blueprintMessage: string;
let error: any | null = null;
try {
let proxyResourcesResults: PublicResourcesResults = []; let proxyResourcesResults: PublicResourcesResults = [];
let clientResourcesResults: ClientResourcesResults = []; let clientResourcesResults: ClientResourcesResults = [];
await db.transaction(async (trx) => { await db.transaction(async (trx) => {

View File

@@ -1,74 +0,0 @@
const MAX_RECURSION_DEPTH = 100;
const segmentRegexCache = new Map<string, RegExp>();
function getSegmentRegex(patternPart: string): RegExp {
let regex = segmentRegexCache.get(patternPart);
if (!regex) {
const regexPattern = patternPart
.replace(/[.+^${}()|[\]\\]/g, "\\$&")
.replace(/\*/g, ".*")
.replace(/\?/g, ".");
regex = new RegExp(`^${regexPattern}$`);
segmentRegexCache.set(patternPart, regex);
}
return regex;
}
export function isPathAllowed(pattern: string, path: string): boolean {
const normalize = (p: string) => p.split("/").filter(Boolean);
const patternParts = normalize(pattern);
const pathParts = normalize(path);
function matchSegments(
patternIndex: number,
pathIndex: number,
depth: number = 0
): boolean {
if (depth > MAX_RECURSION_DEPTH) {
return false;
}
const currentPatternPart = patternParts[patternIndex];
const currentPathPart = pathParts[pathIndex];
if (patternIndex >= patternParts.length) {
return pathIndex >= pathParts.length;
}
if (pathIndex >= pathParts.length) {
return patternParts.slice(patternIndex).every((p) => p === "*");
}
if (currentPatternPart === "*") {
if (matchSegments(patternIndex + 1, pathIndex, depth + 1)) {
return true;
}
if (matchSegments(patternIndex, pathIndex + 1, depth + 1)) {
return true;
}
return false;
}
if (currentPatternPart.includes("*")) {
const regex = getSegmentRegex(currentPatternPart);
if (regex.test(currentPathPart)) {
return matchSegments(
patternIndex + 1,
pathIndex + 1,
depth + 1
);
}
return false;
}
if (currentPatternPart !== currentPathPart) {
return false;
}
return matchSegments(patternIndex + 1, pathIndex + 1, depth + 1);
}
return matchSegments(0, 0, 0);
}

View File

@@ -8,7 +8,6 @@ import {
exitNodes, exitNodes,
newts, newts,
olms, olms,
primaryDb,
roleSiteResources, roleSiteResources,
Site, Site,
SiteResource, SiteResource,
@@ -21,10 +20,10 @@ import {
} from "@server/db"; } from "@server/db";
import { and, count, eq, inArray, ne } from "drizzle-orm"; import { and, count, eq, inArray, ne } from "drizzle-orm";
import { deletePeersBatch as newtDeletePeersBatch } from "@server/routers/newt/peers"; import { deletePeer as newtDeletePeer } from "@server/routers/newt/peers";
import { import {
initPeerAddHandshakeBatch, initPeerAddHandshake,
deletePeersBatch as olmDeletePeersBatch deletePeer as olmDeletePeer
} from "@server/routers/olm/peers"; } from "@server/routers/olm/peers";
import { sendToExitNode } from "#dynamic/lib/exitNodes"; import { sendToExitNode } from "#dynamic/lib/exitNodes";
import logger from "@server/logger"; import logger from "@server/logger";
@@ -35,13 +34,12 @@ import {
parseEndpoint parseEndpoint
} from "@server/lib/ip"; } from "@server/lib/ip";
import { import {
addPeerDataBatch, addPeerData,
addTargetsBatch as addSubnetProxyTargetsBatch, addTargets as addSubnetProxyTargets,
removePeerDataBatch, removePeerData,
removeTargetsBatch as removeSubnetProxyTargetsBatch removeTargets as removeSubnetProxyTargets
} from "@server/routers/client/targets"; } from "@server/routers/client/targets";
import { lockManager } from "#dynamic/lib/lock"; import { lockManager } from "#dynamic/lib/lock";
import { rebuildQueue } from "#dynamic/lib/rebuildQueue";
// TTL for rebuild-association locks. These functions can fan out into many // TTL for rebuild-association locks. These functions can fan out into many
// peer/proxy updates, so give them a generous window. // peer/proxy updates, so give them a generous window.
@@ -162,33 +160,18 @@ export async function getClientSiteResourceAccess(
export async function rebuildClientAssociationsFromSiteResource( export async function rebuildClientAssociationsFromSiteResource(
siteResource: SiteResource, siteResource: SiteResource,
trx: Transaction | typeof db = db trx: Transaction | typeof db = db
) { ): Promise<{
try { mergedAllClients: {
clientId: number;
pubKey: string | null;
subnet: string | null;
}[];
}> {
return await lockManager.withLock( return await lockManager.withLock(
`rebuild-client-associations:site-resource:${siteResource.siteResourceId}`, `rebuild-client-associations:site-resource:${siteResource.siteResourceId}`,
() => () => rebuildClientAssociationsFromSiteResourceImpl(siteResource, trx),
rebuildClientAssociationsFromSiteResourceImpl(
siteResource,
trx
),
REBUILD_ASSOCIATIONS_LOCK_TTL_MS REBUILD_ASSOCIATIONS_LOCK_TTL_MS
); );
} catch (err: any) {
if (
typeof err?.message === "string" &&
err.message.startsWith("Failed to acquire lock")
) {
logger.warn(
`rebuildClientAssociations: could not acquire lock for site resource ${siteResource.siteResourceId}, queuing for deferred processing`
);
await rebuildQueue.enqueue({
type: "site-resource",
id: siteResource.siteResourceId
});
return { mergedAllClients: [] };
}
throw err;
}
} }
async function rebuildClientAssociationsFromSiteResourceImpl( async function rebuildClientAssociationsFromSiteResourceImpl(
@@ -553,28 +536,6 @@ async function handleMessagesForSiteClients(
const newtJobs: Promise<any>[] = []; const newtJobs: Promise<any>[] = [];
const olmJobs: Promise<any>[] = []; const olmJobs: Promise<any>[] = [];
const exitNodeJobs: Promise<any>[] = []; const exitNodeJobs: Promise<any>[] = [];
const newtPeerDeletes: {
siteId: number;
publicKey: string;
newtId: string;
}[] = [];
const olmPeerDeletes: {
clientId: number;
siteId: number;
publicKey: string;
olmId: string;
}[] = [];
const olmPeerAddHandshakes: {
clientId: number;
peer: {
siteId: number;
exitNode: {
publicKey: string;
endpoint: string;
};
};
olmId: string;
}[] = [];
// Combine all clients that need processing (those being added or removed) // Combine all clients that need processing (those being added or removed)
const clientsToProcess = new Map< const clientsToProcess = new Map<
@@ -623,21 +584,6 @@ async function handleMessagesForSiteClients(
} }
} }
// Batch-fetch all olm IDs for the clients we need to process
const clientIdsToProcess = Array.from(clientsToProcess.keys());
const olmRows =
clientIdsToProcess.length > 0
? await trx
.select({ olmId: olms.olmId, clientId: olms.clientId })
.from(olms)
.where(inArray(olms.clientId, clientIdsToProcess))
: [];
const olmByClientId = new Map<number, string>(
olmRows
.filter((r) => r.clientId !== null)
.map((r) => [r.clientId as number, r.olmId])
);
for (const client of clientsToProcess.values()) { for (const client of clientsToProcess.values()) {
// UPDATE THE NEWT // UPDATE THE NEWT
if (!client.subnet || !client.pubKey) { if (!client.subnet || !client.pubKey) {
@@ -654,8 +600,14 @@ async function handleMessagesForSiteClients(
continue; continue;
} }
const olmId = olmByClientId.get(client.clientId); const [olm] = await trx
if (!olmId) { .select({
olmId: olms.olmId
})
.from(olms)
.where(eq(olms.clientId, client.clientId))
.limit(1);
if (!olm) {
logger.warn( logger.warn(
`Olm not found for client ${client.clientId} so cannot add/delete peers` `Olm not found for client ${client.clientId} so cannot add/delete peers`
); );
@@ -663,17 +615,15 @@ async function handleMessagesForSiteClients(
} }
if (isDelete) { if (isDelete) {
newtPeerDeletes.push({ newtJobs.push(newtDeletePeer(siteId, client.pubKey, newt.newtId));
olmJobs.push(
olmDeletePeer(
client.clientId,
siteId, siteId,
publicKey: client.pubKey, site.publicKey,
newtId: newt.newtId olm.olmId
}); )
olmPeerDeletes.push({ );
clientId: client.clientId,
siteId,
publicKey: site.publicKey,
olmId
});
} }
if (isAdd) { if (isAdd) {
@@ -685,34 +635,23 @@ async function handleMessagesForSiteClients(
continue; continue;
} }
olmPeerAddHandshakes.push({ await initPeerAddHandshake(
clientId: client.clientId, // this will kick off the add peer process for the client
peer: { client.clientId,
{
siteId, siteId,
exitNode: { exitNode: {
publicKey: exitNode.publicKey, publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint endpoint: exitNode.endpoint
} }
}, },
olmId olm.olmId
}); );
} }
exitNodeJobs.push(updateClientSiteDestinations(client, trx)); exitNodeJobs.push(updateClientSiteDestinations(client, trx));
} }
if (newtPeerDeletes.length > 0) {
newtJobs.push(newtDeletePeersBatch(newtPeerDeletes));
}
if (olmPeerDeletes.length > 0) {
olmJobs.push(olmDeletePeersBatch(olmPeerDeletes));
}
if (olmPeerAddHandshakes.length > 0) {
olmJobs.push(initPeerAddHandshakeBatch(olmPeerAddHandshakes));
}
Promise.all(exitNodeJobs).catch((error) => { Promise.all(exitNodeJobs).catch((error) => {
logger.error( logger.error(
`rebuildClientAssociations: Error updating client site destinations for site ${site.siteId}:`, `rebuildClientAssociations: Error updating client site destinations for site ${site.siteId}:`,
@@ -873,20 +812,6 @@ async function handleSubnetProxyTargetUpdates(
): Promise<void> { ): Promise<void> {
const proxyJobs: Promise<any>[] = []; const proxyJobs: Promise<any>[] = [];
const olmJobs: Promise<any>[] = []; const olmJobs: Promise<any>[] = [];
const targetsToAddBatch: {
newtId: string;
targets: NonNullable<
Awaited<ReturnType<typeof generateSubnetProxyTargetV2>>
>;
version: string | null;
}[] = [];
const targetsToRemoveBatch: {
newtId: string;
targets: NonNullable<
Awaited<ReturnType<typeof generateSubnetProxyTargetV2>>
>;
version: string | null;
}[] = [];
for (const siteData of sitesList) { for (const siteData of sitesList) {
const siteId = siteData.siteId; const siteId = siteData.siteId;
@@ -918,26 +843,26 @@ async function handleSubnetProxyTargetUpdates(
); );
if (targetsToAdd) { if (targetsToAdd) {
targetsToAddBatch.push({ proxyJobs.push(
newtId: newt.newtId, addSubnetProxyTargets(
targets: targetsToAdd, newt.newtId,
version: newt.version targetsToAdd,
}); newt.version
}
olmJobs.push(
addPeerDataBatch(
addedClients.map((client) => ({
clientId: client.clientId,
siteId,
remoteSubnets: generateRemoteSubnets([
siteResource
]),
aliases: generateAliasConfig([siteResource])
}))
) )
); );
} }
for (const client of addedClients) {
olmJobs.push(
addPeerData(
client.clientId,
siteId,
generateRemoteSubnets([siteResource]),
generateAliasConfig([siteResource])
)
);
}
}
} }
// here we use the existingSiteResource from BEFORE we updated the destination so we dont need to worry about updating destinations here // here we use the existingSiteResource from BEFORE we updated the destination so we dont need to worry about updating destinations here
@@ -955,20 +880,15 @@ async function handleSubnetProxyTargetUpdates(
); );
if (targetsToRemove) { if (targetsToRemove) {
targetsToRemoveBatch.push({ proxyJobs.push(
newtId: newt.newtId, removeSubnetProxyTargets(
targets: targetsToRemove, newt.newtId,
version: newt.version targetsToRemove,
}); newt.version
)
);
} }
const peerDataRemovals: {
clientId: number;
siteId: number;
remoteSubnets: string[];
aliases: ReturnType<typeof generateAliasConfig>;
}[] = [];
for (const client of removedClients) { for (const client of removedClients) {
if (!siteResource.destination) { if (!siteResource.destination) {
continue; continue;
@@ -1016,58 +936,31 @@ async function handleSubnetProxyTargetUpdates(
? [] ? []
: generateRemoteSubnets([siteResource]); : generateRemoteSubnets([siteResource]);
peerDataRemovals.push({ olmJobs.push(
clientId: client.clientId, removePeerData(
client.clientId,
siteId, siteId,
remoteSubnets: remoteSubnetsToRemove, remoteSubnetsToRemove,
aliases: generateAliasConfig([siteResource]) generateAliasConfig([siteResource])
}); )
} );
if (peerDataRemovals.length > 0) {
olmJobs.push(removePeerDataBatch(peerDataRemovals));
} }
} }
} }
} }
if (targetsToAddBatch.length > 0) { await Promise.all(proxyJobs);
proxyJobs.push(addSubnetProxyTargetsBatch(targetsToAddBatch));
}
if (targetsToRemoveBatch.length > 0) {
proxyJobs.push(removeSubnetProxyTargetsBatch(targetsToRemoveBatch));
}
await Promise.all([...proxyJobs, ...olmJobs]);
} }
export async function rebuildClientAssociationsFromClient( export async function rebuildClientAssociationsFromClient(
client: Client, client: Client,
trx: Transaction | typeof db = db trx: Transaction | typeof db = db
): Promise<void> { ): Promise<void> {
try {
return await lockManager.withLock( return await lockManager.withLock(
`rebuild-client-associations:client:${client.clientId}`, `rebuild-client-associations:client:${client.clientId}`,
() => rebuildClientAssociationsFromClientImpl(client, trx), () => rebuildClientAssociationsFromClientImpl(client, trx),
REBUILD_ASSOCIATIONS_LOCK_TTL_MS REBUILD_ASSOCIATIONS_LOCK_TTL_MS
); );
} catch (err: any) {
if (
typeof err?.message === "string" &&
err.message.startsWith("Failed to acquire lock")
) {
logger.warn(
`rebuildClientAssociations: could not acquire lock for client ${client.clientId}, queuing for deferred processing`
);
await rebuildQueue.enqueue({
type: "client",
id: client.clientId
});
return;
}
throw err;
}
} }
async function rebuildClientAssociationsFromClientImpl( async function rebuildClientAssociationsFromClientImpl(
@@ -1344,28 +1237,6 @@ async function handleMessagesForClientSites(
const newtJobs: Promise<any>[] = []; const newtJobs: Promise<any>[] = [];
const olmJobs: Promise<any>[] = []; const olmJobs: Promise<any>[] = [];
const exitNodeJobs: Promise<any>[] = []; const exitNodeJobs: Promise<any>[] = [];
const newtPeerDeletes: {
siteId: number;
publicKey: string;
newtId: string;
}[] = [];
const olmPeerDeletes: {
clientId: number;
siteId: number;
publicKey: string;
olmId: string;
}[] = [];
const olmPeerAddHandshakes: {
clientId: number;
peer: {
siteId: number;
exitNode: {
publicKey: string;
endpoint: string;
};
};
olmId: string;
}[] = [];
const totalSitesOnClient = await trx const totalSitesOnClient = await trx
.select({ count: count(clientSitesAssociationsCache.siteId) }) .select({ count: count(clientSitesAssociationsCache.siteId) })
@@ -1397,19 +1268,19 @@ async function handleMessagesForClientSites(
if (isRemove) { if (isRemove) {
// Remove peer from newt // Remove peer from newt
newtPeerDeletes.push({ newtJobs.push(
siteId: site.siteId, newtDeletePeer(site.siteId, client.pubKey, newt.newtId)
publicKey: client.pubKey, );
newtId: newt.newtId
});
try { try {
// Remove peer from olm // Remove peer from olm
olmPeerDeletes.push({ olmJobs.push(
clientId: client.clientId, olmDeletePeer(
siteId: site.siteId, client.clientId,
publicKey: site.publicKey, site.siteId,
site.publicKey,
olmId olmId
}); )
);
} catch (error) { } catch (error) {
// if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send
if ( if (
@@ -1441,9 +1312,10 @@ async function handleMessagesForClientSites(
continue; continue;
} }
olmPeerAddHandshakes.push({ await initPeerAddHandshake(
clientId: client.clientId, // this will kick off the add peer process for the client
peer: { client.clientId,
{
siteId: site.siteId, siteId: site.siteId,
exitNode: { exitNode: {
publicKey: exitNode.publicKey, publicKey: exitNode.publicKey,
@@ -1451,7 +1323,7 @@ async function handleMessagesForClientSites(
} }
}, },
olmId olmId
}); );
} }
// Update exit node destinations // Update exit node destinations
@@ -1467,18 +1339,6 @@ async function handleMessagesForClientSites(
); );
} }
if (newtPeerDeletes.length > 0) {
newtJobs.push(newtDeletePeersBatch(newtPeerDeletes));
}
if (olmPeerDeletes.length > 0) {
olmJobs.push(olmDeletePeersBatch(olmPeerDeletes));
}
if (olmPeerAddHandshakes.length > 0) {
olmJobs.push(initPeerAddHandshakeBatch(olmPeerAddHandshakes));
}
Promise.all(exitNodeJobs).catch((error) => { Promise.all(exitNodeJobs).catch((error) => {
logger.error( logger.error(
`rebuildClientAssociations: Error updating client site destinations for client ${client.clientId}:`, `rebuildClientAssociations: Error updating client site destinations for client ${client.clientId}:`,
@@ -1577,20 +1437,6 @@ async function handleMessagesForClientResources(
continue; continue;
} }
const targetsToAddBatch: {
newtId: string;
targets: NonNullable<
Awaited<ReturnType<typeof generateSubnetProxyTargetV2>>
>;
version: string | null;
}[] = [];
const peerDataAdds: {
clientId: number;
siteId: number;
remoteSubnets: string[];
aliases: ReturnType<typeof generateAliasConfig>;
}[] = [];
for (const resource of resources) { for (const resource of resources) {
const targets = await generateSubnetProxyTargetV2(resource, [ const targets = await generateSubnetProxyTargetV2(resource, [
{ {
@@ -1601,21 +1447,25 @@ async function handleMessagesForClientResources(
]); ]);
if (targets) { if (targets) {
targetsToAddBatch.push({ proxyJobs.push(
newtId: newt.newtId, addSubnetProxyTargets(
newt.newtId,
targets, targets,
version: newt.version newt.version
}); )
);
} }
try { try {
// Add peer data to olm // Add peer data to olm
peerDataAdds.push({ olmJobs.push(
clientId: client.clientId, addPeerData(
client.clientId,
siteId, siteId,
remoteSubnets: generateRemoteSubnets([resource]), generateRemoteSubnets([resource]),
aliases: generateAliasConfig([resource]) generateAliasConfig([resource])
}); )
);
} catch (error) { } catch (error) {
// if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send
if ( if (
@@ -1630,14 +1480,6 @@ async function handleMessagesForClientResources(
} }
} }
} }
if (targetsToAddBatch.length > 0) {
proxyJobs.push(addSubnetProxyTargetsBatch(targetsToAddBatch));
}
if (peerDataAdds.length > 0) {
olmJobs.push(addPeerDataBatch(peerDataAdds));
}
} }
} }
@@ -1704,20 +1546,6 @@ async function handleMessagesForClientResources(
continue; continue;
} }
const targetsToRemoveBatch: {
newtId: string;
targets: NonNullable<
Awaited<ReturnType<typeof generateSubnetProxyTargetV2>>
>;
version: string | null;
}[] = [];
const peerDataRemovals: {
clientId: number;
siteId: number;
remoteSubnets: string[];
aliases: ReturnType<typeof generateAliasConfig>;
}[] = [];
for (const resource of resources) { for (const resource of resources) {
const targets = await generateSubnetProxyTargetV2(resource, [ const targets = await generateSubnetProxyTargetV2(resource, [
{ {
@@ -1728,11 +1556,13 @@ async function handleMessagesForClientResources(
]); ]);
if (targets) { if (targets) {
targetsToRemoveBatch.push({ proxyJobs.push(
newtId: newt.newtId, removeSubnetProxyTargets(
newt.newtId,
targets, targets,
version: newt.version newt.version
}); )
);
} }
try { try {
@@ -1783,12 +1613,14 @@ async function handleMessagesForClientResources(
: generateRemoteSubnets([resource]); : generateRemoteSubnets([resource]);
// Remove peer data from olm // Remove peer data from olm
peerDataRemovals.push({ olmJobs.push(
clientId: client.clientId, removePeerData(
client.clientId,
siteId, siteId,
remoteSubnets: remoteSubnetsToRemove, remoteSubnetsToRemove,
aliases: generateAliasConfig([resource]) generateAliasConfig([resource])
}); )
);
} catch (error) { } catch (error) {
// if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send
if ( if (
@@ -1803,16 +1635,6 @@ async function handleMessagesForClientResources(
} }
} }
} }
if (targetsToRemoveBatch.length > 0) {
proxyJobs.push(
removeSubnetProxyTargetsBatch(targetsToRemoveBatch)
);
}
if (peerDataRemovals.length > 0) {
olmJobs.push(removePeerDataBatch(peerDataRemovals));
}
} }
} }
@@ -2062,20 +1884,11 @@ export async function cleanupSiteAssociations(
// 7. Fire all removal messages in parallel. // 7. Fire all removal messages in parallel.
const jobs: Promise<any>[] = []; const jobs: Promise<any>[] = [];
const olmPeerDeletes: {
clientId: number;
siteId: number;
publicKey: string;
}[] = [];
for (const client of allClients) { for (const client of allClients) {
// Tell each olm to drop the site's WireGuard peer. // Tell each olm to drop the site's WireGuard peer.
if (site.publicKey) { if (site.publicKey) {
olmPeerDeletes.push({ jobs.push(olmDeletePeer(client.clientId, siteId, site.publicKey));
clientId: client.clientId,
siteId,
publicKey: site.publicKey
});
} }
// Recompute and push updated relay destinations (now excluding this site). // Recompute and push updated relay destinations (now excluding this site).
@@ -2084,10 +1897,6 @@ export async function cleanupSiteAssociations(
} }
} }
if (olmPeerDeletes.length > 0) {
jobs.push(olmDeletePeersBatch(olmPeerDeletes));
}
await Promise.all(jobs).catch((error) => { await Promise.all(jobs).catch((error) => {
logger.error( logger.error(
`cleanupSiteAssociations: error sending cleanup messages for siteId=${siteId}:`, `cleanupSiteAssociations: error sending cleanup messages for siteId=${siteId}:`,
@@ -2097,47 +1906,3 @@ export async function cleanupSiteAssociations(
logger.debug(`cleanupSiteAssociations: DONE siteId=${siteId}`); logger.debug(`cleanupSiteAssociations: DONE siteId=${siteId}`);
} }
/**
* Start the background rebuild queue processor. This should be called once
* during server startup. Only one server instance at a time will actively
* consume the queue (enforced via a distributed Redis lock); all other
* instances will poll and wait until the lock becomes available.
*/
export function startRebuildQueueProcessor(): void {
rebuildQueue.startProcessing({
onSiteResource: async (siteResourceId: number) => {
const [siteResource] = await primaryDb
.select()
.from(siteResources)
.where(eq(siteResources.siteResourceId, siteResourceId));
if (!siteResource) {
logger.warn(
`Rebuild queue: site resource ${siteResourceId} not found, skipping`
);
return;
}
await rebuildClientAssociationsFromSiteResource(
siteResource,
primaryDb
);
},
onClient: async (clientId: number) => {
const [client] = await primaryDb
.select()
.from(clients)
.where(eq(clients.clientId, clientId));
if (!client) {
logger.warn(
`Rebuild queue: client ${clientId} not found, skipping`
);
return;
}
await rebuildClientAssociationsFromClient(client, primaryDb);
}
});
}

View File

@@ -1,23 +0,0 @@
export type RebuildJobType = "site-resource" | "client";
export interface RebuildJob {
type: RebuildJobType;
id: number;
}
export interface RebuildJobHandlers {
onSiteResource(siteResourceId: number): Promise<void>;
onClient(clientId: number): Promise<void>;
}
export interface RebuildQueueManager {
enqueue(job: RebuildJob): Promise<void>;
startProcessing(handlers: RebuildJobHandlers): void;
}
class NoopRebuildQueue implements RebuildQueueManager {
async enqueue(_job: RebuildJob): Promise<void> {}
startProcessing(_handlers: RebuildJobHandlers): void {}
}
export const rebuildQueue: RebuildQueueManager = new NoopRebuildQueue();

View File

@@ -17,7 +17,7 @@ import { certificates, db } from "@server/db";
import { and, eq, isNotNull, or, inArray, sql } from "drizzle-orm"; import { and, eq, isNotNull, or, inArray, sql } from "drizzle-orm";
import { decrypt } from "@server/lib/crypto"; import { decrypt } from "@server/lib/crypto";
import logger from "@server/logger"; import logger from "@server/logger";
import { regionalCache as cache } from "#private/lib/cache"; import cache from "#private/lib/cache";
import { build } from "@server/build"; import { build } from "@server/build";
// Define the return type for clarity and type safety // Define the return type for clarity and type safety

View File

@@ -1,198 +0,0 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025-2026 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { redis } from "#private/lib/redis";
import { lockManager } from "#dynamic/lib/lock";
import logger from "@server/logger";
export type RebuildJobType = "site-resource" | "client";
export interface RebuildJob {
type: RebuildJobType;
id: number;
}
export interface RebuildJobHandlers {
onSiteResource(siteResourceId: number): Promise<void>;
onClient(clientId: number): Promise<void>;
}
// Redis list holding pending rebuild jobs (RPUSH to enqueue, LPOP to dequeue — FIFO order).
const QUEUE_KEY = "rebuild-client-associations:queue";
const QUEUED_SET_KEY = "rebuild-client-associations:queued";
// Distributed lock that serialises queue consumption to a single server instance
// at a time. TTL is generous enough to cover a full batch of expensive rebuilds.
const PROCESSOR_LOCK_KEY = "rebuild-client-associations:processor";
// Each rebuild can take up to REBUILD_ASSOCIATIONS_LOCK_TTL_MS (120 s) per
// resource. Allow BATCH_SIZE resources per processor-lock acquisition, plus a
// small buffer.
const BATCH_SIZE = 5;
const PROCESSOR_LOCK_TTL_MS = 120000 * BATCH_SIZE + 30000; // ~630 s
const POLL_INTERVAL_MS = 500;
class RedisRebuildQueue {
private processingStarted = false;
async enqueue(job: RebuildJob): Promise<void> {
if (!redis || redis.status !== "ready") {
logger.warn(
`Rebuild queue: Redis not available — rebuild for ${job.type}:${job.id} will not be retried`
);
return;
}
try {
const dedupeKey = `${job.type}:${job.id}`;
const added = await redis.sadd(QUEUED_SET_KEY, dedupeKey);
if (added === 0) {
logger.debug(
`Rebuild queue: skipped duplicate queued job ${job.type}:${job.id}`
);
return;
}
await redis.rpush(QUEUE_KEY, JSON.stringify(job));
logger.debug(
`Rebuild queue: enqueued ${job.type}:${job.id} (queue position: tail)`
);
} catch (err) {
await redis
.srem(QUEUED_SET_KEY, `${job.type}:${job.id}`)
.catch((cleanupErr) =>
logger.warn(
`Rebuild queue: failed to cleanup dedupe key for ${job.type}:${job.id} after enqueue failure:`,
cleanupErr
)
);
logger.error(
`Rebuild queue: failed to enqueue ${job.type}:${job.id}:`,
err
);
}
}
startProcessing(handlers: RebuildJobHandlers): void {
if (this.processingStarted) return;
this.processingStarted = true;
this.processLoop(handlers).catch((err) => {
logger.error("Rebuild queue processor loop crashed:", err);
});
logger.info("Rebuild queue processor started");
}
private async processLoop(handlers: RebuildJobHandlers): Promise<void> {
while (true) {
try {
await this.tryProcessBatch(handlers);
} catch (err) {
logger.error(
"Rebuild queue: unhandled error in process loop:",
err
);
}
await new Promise((resolve) =>
setTimeout(resolve, POLL_INTERVAL_MS)
);
}
}
private async tryProcessBatch(handlers: RebuildJobHandlers): Promise<void> {
if (!redis || redis.status !== "ready") return;
// Peek before acquiring the processor lock to avoid unnecessary Redis
// round-trips and lock contention when the queue is idle.
const queueLength = await redis.llen(QUEUE_KEY).catch(() => 0);
if (queueLength === 0) return;
try {
await lockManager.withLock(
PROCESSOR_LOCK_KEY,
async () => {
for (let i = 0; i < BATCH_SIZE; i++) {
if (!redis || redis.status !== "ready") break;
const payload = await redis.lpop(QUEUE_KEY);
if (payload === null) break; // queue drained
let job: RebuildJob;
try {
job = JSON.parse(payload) as RebuildJob;
} catch {
logger.error(
`Rebuild queue: could not parse job payload, discarding: ${payload}`
);
continue;
}
// Remove from dedupe set once dequeued so the same job
// can be re-queued while this one is in progress.
await redis
.srem(QUEUED_SET_KEY, `${job.type}:${job.id}`)
.catch((cleanupErr) =>
logger.warn(
`Rebuild queue: failed to remove dedupe key for ${job.type}:${job.id} on dequeue:`,
cleanupErr
)
);
logger.debug(
`Rebuild queue: processing ${job.type}:${job.id}`
);
try {
if (job.type === "site-resource") {
await handlers.onSiteResource(job.id);
} else if (job.type === "client") {
await handlers.onClient(job.id);
} else {
logger.warn(
`Rebuild queue: unknown job type "${(job as any).type}", discarding`
);
}
logger.debug(
`Rebuild queue: completed ${job.type}:${job.id}`
);
} catch (err) {
logger.error(
`Rebuild queue: job ${job.type}:${job.id} threw an error:`,
err
);
}
}
},
PROCESSOR_LOCK_TTL_MS
);
} catch (err: any) {
if (
typeof err?.message === "string" &&
err.message.startsWith("Failed to acquire lock")
) {
// Another server instance currently holds the processor lock and
// is consuming the queue — nothing to do this cycle.
logger.debug(
"Rebuild queue: processor lock held by another instance, skipping this cycle"
);
} else {
throw err;
}
}
}
}
export const rebuildQueue: RedisRebuildQueue = new RedisRebuildQueue();

View File

@@ -22,7 +22,7 @@ import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { ListRemoteExitNodesResponse } from "@server/routers/remoteExitNode/types"; import { ListRemoteExitNodesResponse } from "@server/routers/remoteExitNode/types";
import { regionalCache as cache } from "#private/lib/cache"; import cache from "#private/lib/cache";
import semver from "semver"; import semver from "semver";
let stalePangolinNodeVersion: string | null = null; let stalePangolinNodeVersion: string | null = null;

View File

@@ -38,7 +38,6 @@ import { messageHandlers } from "@server/routers/ws/messageHandlers";
import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers"; import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers";
import { import {
AuthenticatedWebSocket, AuthenticatedWebSocket,
BatchSendMessage,
ClientType, ClientType,
WSMessage, WSMessage,
TokenPayload, TokenPayload,
@@ -188,8 +187,6 @@ const wss: WebSocketServer = new WebSocketServer({ noServer: true });
// Generate unique node ID for this instance // Generate unique node ID for this instance
const NODE_ID = uuidv4(); const NODE_ID = uuidv4();
const REDIS_CHANNEL = "websocket_messages"; const REDIS_CHANNEL = "websocket_messages";
const REDIS_DIRECT_BATCH_SIZE = 250;
const REDIS_DIRECT_FLUSH_INTERVAL_MS = 10;
// Client tracking map (local to this node) // Client tracking map (local to this node)
const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map(); const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
@@ -200,15 +197,6 @@ const clientConfigVersions: Map<string, number> = new Map();
// Recovery tracking // Recovery tracking
let isRedisRecoveryInProgress = false; let isRedisRecoveryInProgress = false;
interface RedisDirectBatchEntry {
targetClientId: string;
message: WSMessage;
resolve: () => void;
}
let pendingRedisDirectMessages: RedisDirectBatchEntry[] = [];
let redisDirectFlushTimer: NodeJS.Timeout | null = null;
// Helper to get map key // Helper to get map key
const getClientMapKey = (clientId: string) => clientId; const getClientMapKey = (clientId: string) => clientId;
@@ -219,78 +207,6 @@ const getNodeConnectionsKey = (nodeId: string, clientId: string) =>
const getConfigVersionKey = (clientId: string) => const getConfigVersionKey = (clientId: string) =>
`ws:configVersion:${clientId}`; `ws:configVersion:${clientId}`;
const clearRedisDirectFlushTimer = (): void => {
if (redisDirectFlushTimer) {
clearTimeout(redisDirectFlushTimer);
redisDirectFlushTimer = null;
}
};
const publishDirectBatch = async (
entries: RedisDirectBatchEntry[]
): Promise<void> => {
const redisMessage: RedisMessage = {
type: "direct-batch",
messages: entries.map((entry) => ({
targetClientId: entry.targetClientId,
message: entry.message
})),
fromNodeId: NODE_ID
};
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
};
const flushPendingRedisDirectMessages = async (): Promise<void> => {
clearRedisDirectFlushTimer();
if (pendingRedisDirectMessages.length === 0) {
return;
}
const entries = pendingRedisDirectMessages;
pendingRedisDirectMessages = [];
if (!redisManager.isRedisEnabled()) {
entries.forEach((entry) => entry.resolve());
return;
}
for (let i = 0; i < entries.length; i += REDIS_DIRECT_BATCH_SIZE) {
const batch = entries.slice(i, i + REDIS_DIRECT_BATCH_SIZE);
try {
await publishDirectBatch(batch);
} catch (error) {
logger.error(
"Failed to send batched direct messages via Redis, messages may be lost:",
error
);
} finally {
batch.forEach((entry) => entry.resolve());
}
}
};
const enqueueRedisDirectMessage = async (
targetClientId: string,
message: WSMessage
): Promise<void> => {
await new Promise<void>((resolve) => {
pendingRedisDirectMessages.push({ targetClientId, message, resolve });
if (pendingRedisDirectMessages.length >= REDIS_DIRECT_BATCH_SIZE) {
void flushPendingRedisDirectMessages();
return;
}
if (!redisDirectFlushTimer) {
redisDirectFlushTimer = setTimeout(() => {
void flushPendingRedisDirectMessages();
}, REDIS_DIRECT_FLUSH_INTERVAL_MS);
}
});
};
// Initialize Redis subscription for cross-node messaging // Initialize Redis subscription for cross-node messaging
const initializeRedisSubscription = async (): Promise<void> => { const initializeRedisSubscription = async (): Promise<void> => {
if (!redisManager.isRedisEnabled()) return; if (!redisManager.isRedisEnabled()) return;
@@ -311,16 +227,7 @@ const initializeRedisSubscription = async (): Promise<void> => {
// Send to specific client on this node // Send to specific client on this node
await sendToClientLocal( await sendToClientLocal(
redisMessage.targetClientId, redisMessage.targetClientId,
redisMessage.message, redisMessage.message
{},
redisMessage.message.configVersion
);
} else if (
redisMessage.type === "direct-batch" &&
redisMessage.messages
) {
await sendRedisDirectBatchToLocalClients(
redisMessage.messages
); );
} else if (redisMessage.type === "broadcast") { } else if (redisMessage.type === "broadcast") {
// Broadcast to all clients on this node except excluded // Broadcast to all clients on this node except excluded
@@ -596,8 +503,7 @@ const incrementClientConfigVersion = async (
const sendToClientLocal = async ( const sendToClientLocal = async (
clientId: string, clientId: string,
message: WSMessage, message: WSMessage,
options: SendMessageOptions = {}, options: SendMessageOptions = {}
preResolvedConfigVersion?: number
): Promise<boolean> => { ): Promise<boolean> => {
const mapKey = getClientMapKey(clientId); const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey); const clients = connectedClients.get(mapKey);
@@ -606,8 +512,7 @@ const sendToClientLocal = async (
} }
// Handle config version // Handle config version
const configVersion = const configVersion = await getClientConfigVersion(clientId);
preResolvedConfigVersion ?? (await getClientConfigVersion(clientId));
// Add config version to message // Add config version to message
const messageWithVersion = { const messageWithVersion = {
@@ -640,73 +545,45 @@ const sendToClientLocal = async (
return true; return true;
}; };
const sendRedisDirectBatchToLocalClients = async (
entries: { targetClientId: string; message: WSMessage }[]
): Promise<void> => {
const jobs = entries.map((entry) =>
sendToClientLocal(
entry.targetClientId,
entry.message,
{},
entry.message.configVersion
)
);
await Promise.all(jobs);
};
const broadcastToAllExceptLocal = async ( const broadcastToAllExceptLocal = async (
message: WSMessage, message: WSMessage,
excludeClientId?: string, excludeClientId?: string,
options: SendMessageOptions = {} options: SendMessageOptions = {}
): Promise<void> => { ): Promise<void> => {
const sendPlans = await Promise.all( for (const [mapKey, clients] of connectedClients.entries()) {
Array.from(connectedClients.entries()).map( const [type, id] = mapKey.split(":");
async ([mapKey, clients]) => {
const clientId = mapKey; // mapKey is the clientId const clientId = mapKey; // mapKey is the clientId
if (excludeClientId && clientId === excludeClientId) { if (!(excludeClientId && clientId === excludeClientId)) {
return null; // Handle config version per client
}
let configVersion = await getClientConfigVersion(clientId); let configVersion = await getClientConfigVersion(clientId);
if (options.incrementConfigVersion) { if (options.incrementConfigVersion) {
configVersion = configVersion = await incrementClientConfigVersion(clientId);
await incrementClientConfigVersion(clientId);
} }
return { // Add config version to message
clients, const messageWithVersion = {
messageWithVersion: {
...message, ...message,
configVersion configVersion
}
}; };
}
)
);
for (const plan of sendPlans) {
if (!plan) {
continue;
}
if (options.compress) { if (options.compress) {
const compressed = zlib.gzipSync( const compressed = zlib.gzipSync(
Buffer.from(JSON.stringify(plan.messageWithVersion), "utf8") Buffer.from(JSON.stringify(messageWithVersion), "utf8")
); );
plan.clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(compressed); client.send(compressed);
} }
}); });
} else { } else {
const messageString = JSON.stringify(plan.messageWithVersion); clients.forEach((client) => {
plan.clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(messageString); client.send(JSON.stringify(messageWithVersion));
} }
}); });
} }
} }
}
}; };
// Cross-node message sending (via Redis) // Cross-node message sending (via Redis)
@@ -725,23 +602,28 @@ const sendToClient = async (
); );
// Try to send locally first // Try to send locally first
const localSent = await sendToClientLocal( const localSent = await sendToClientLocal(clientId, message, options);
clientId,
message,
options,
configVersion
);
// Only send via Redis if the client is not connected locally and Redis is enabled // Only send via Redis if the client is not connected locally and Redis is enabled
if (!localSent && redisManager.isRedisEnabled()) { if (!localSent && redisManager.isRedisEnabled()) {
try { try {
await enqueueRedisDirectMessage(clientId, { const redisMessage: RedisMessage = {
type: "direct",
targetClientId: clientId,
message: {
...message, ...message,
configVersion configVersion
}); },
fromNodeId: NODE_ID
};
await redisManager.publish(
REDIS_CHANNEL,
JSON.stringify(redisMessage)
);
} catch (error) { } catch (error) {
logger.error( logger.error(
"Failed to queue batched direct message for Redis delivery, message may be lost:", "Failed to send message via Redis, message may be lost:",
error error
); );
// Continue execution - local delivery already attempted // Continue execution - local delivery already attempted
@@ -756,95 +638,6 @@ const sendToClient = async (
return localSent; return localSent;
}; };
const sendToClientsBatch = async (
entries: BatchSendMessage[]
): Promise<void> => {
if (entries.length === 0) {
return;
}
const remoteEntries: { targetClientId: string; message: WSMessage }[] = [];
const clientsWithIncrement = new Set(
entries
.filter((entry) => !!entry.options?.incrementConfigVersion)
.map((entry) => entry.clientId)
);
const nonIncrementOnlyClientIds = Array.from(
new Set(
entries
.map((entry) => entry.clientId)
.filter((clientId) => !clientsWithIncrement.has(clientId))
)
);
const stableConfigVersionByClient = new Map<string, number | undefined>(
await Promise.all(
nonIncrementOnlyClientIds.map(
async (clientId) =>
[clientId, await getClientConfigVersion(clientId)] as const
)
)
);
for (const entry of entries) {
const options = entry.options || {};
const { clientId, message } = entry;
const configVersion = options.incrementConfigVersion
? await incrementClientConfigVersion(clientId)
: stableConfigVersionByClient.get(clientId);
logger.debug(
`sendToClientsBatch: Message type ${message.type} queued for clientId ${clientId} (new configVersion: ${configVersion})`
);
const localSent = await sendToClientLocal(
clientId,
message,
options,
configVersion
);
if (!localSent && redisManager.isRedisEnabled()) {
remoteEntries.push({
targetClientId: clientId,
message: {
...message,
configVersion
}
});
} else if (!localSent && !redisManager.isRedisEnabled()) {
logger.debug(
`Could not deliver batch message to ${clientId} - not connected locally and Redis unavailable`
);
}
}
if (!redisManager.isRedisEnabled() || remoteEntries.length === 0) {
return;
}
for (let i = 0; i < remoteEntries.length; i += REDIS_DIRECT_BATCH_SIZE) {
const messages = remoteEntries.slice(i, i + REDIS_DIRECT_BATCH_SIZE);
try {
const redisMessage: RedisMessage = {
type: "direct-batch",
messages,
fromNodeId: NODE_ID
};
await redisManager.publish(
REDIS_CHANNEL,
JSON.stringify(redisMessage)
);
} catch (error) {
logger.error(
"Failed to send explicit direct batch via Redis, messages may be lost:",
error
);
}
}
};
const broadcastToAllExcept = async ( const broadcastToAllExcept = async (
message: WSMessage, message: WSMessage,
excludeClientId?: string, excludeClientId?: string,
@@ -1316,8 +1109,6 @@ const disconnectClient = async (clientId: string): Promise<boolean> => {
// Cleanup function for graceful shutdown // Cleanup function for graceful shutdown
const cleanup = async (): Promise<void> => { const cleanup = async (): Promise<void> => {
try { try {
await flushPendingRedisDirectMessages();
// Close all WebSocket connections // Close all WebSocket connections
connectedClients.forEach((clients) => { connectedClients.forEach((clients) => {
clients.forEach((client) => { clients.forEach((client) => {
@@ -1348,7 +1139,6 @@ export {
router, router,
handleWSUpgrade, handleWSUpgrade,
sendToClient, sendToClient,
sendToClientsBatch,
broadcastToAllExcept, broadcastToAllExcept,
connectedClients, connectedClients,
hasActiveConnections, hasActiveConnections,

View File

@@ -1,6 +1,5 @@
import { assertEquals } from "@test/assert"; import { assertEquals } from "@test/assert";
import { REGIONS } from "@server/db/regions"; import { REGIONS } from "@server/db/regions";
import { isPathAllowed } from "@server/lib/pathMatch";
function isIpInRegion( function isIpInRegion(
ipCountryCode: string | undefined, ipCountryCode: string | undefined,
@@ -34,6 +33,76 @@ function isIpInRegion(
return false; return false;
} }
function isPathAllowed(pattern: string, path: string): boolean {
// Normalize and split paths into segments
const normalize = (p: string) => p.split("/").filter(Boolean);
const patternParts = normalize(pattern);
const pathParts = normalize(path);
// Recursive function to try different wildcard matches
function matchSegments(patternIndex: number, pathIndex: number): boolean {
const indent = " ".repeat(pathIndex); // Indent based on recursion depth
const currentPatternPart = patternParts[patternIndex];
const currentPathPart = pathParts[pathIndex];
// If we've consumed all pattern parts, we should have consumed all path parts
if (patternIndex >= patternParts.length) {
const result = pathIndex >= pathParts.length;
return result;
}
// If we've consumed all path parts but still have pattern parts
if (pathIndex >= pathParts.length) {
// The only way this can match is if all remaining pattern parts are wildcards
const remainingPattern = patternParts.slice(patternIndex);
const result = remainingPattern.every((p) => p === "*");
return result;
}
// For full segment wildcards, try consuming different numbers of path segments
if (currentPatternPart === "*") {
// Try consuming 0 segments (skip the wildcard)
if (matchSegments(patternIndex + 1, pathIndex)) {
return true;
}
// Try consuming current segment and recursively try rest
if (matchSegments(patternIndex, pathIndex + 1)) {
return true;
}
return false;
}
// Check for in-segment wildcard (e.g., "prefix*" or "prefix*suffix")
if (currentPatternPart.includes("*")) {
// Convert the pattern segment to a regex pattern
const regexPattern = currentPatternPart
.replace(/\*/g, ".*") // Replace * with .* for regex wildcard
.replace(/\?/g, "."); // Replace ? with . for single character wildcard if needed
const regex = new RegExp(`^${regexPattern}$`);
if (regex.test(currentPathPart)) {
return matchSegments(patternIndex + 1, pathIndex + 1);
}
return false;
}
// For regular segments, they must match exactly
if (currentPatternPart !== currentPathPart) {
return false;
}
// Move to next segments in both pattern and path
return matchSegments(patternIndex + 1, pathIndex + 1);
}
const result = matchSegments(0, 0);
return result;
}
function runTests() { function runTests() {
console.log("Running path matching tests..."); console.log("Running path matching tests...");
@@ -239,121 +308,6 @@ function runTests() {
console.log("All path matching tests passed!"); console.log("All path matching tests passed!");
} }
function runSpecialCharacterTests() {
console.log("\nRunning special character tests...");
let threw = false;
try {
isPathAllowed("(api*", "anything");
isPathAllowed("a(b*", "a(bc");
isPathAllowed("c[d*", "c[de");
isPathAllowed("x{2}*", "x{2}y");
isPathAllowed("a|b*", "a|bc");
isPathAllowed("back\\slash*", "back\\slashed");
} catch (e) {
threw = true;
console.error(
"Patterns accepted by isValidUrlGlobPattern crashed the matcher:",
e instanceof Error ? e.message : e
);
}
assertEquals(
threw,
false,
"Patterns with regex metacharacters must not throw"
);
assertEquals(
isPathAllowed("(api*", "(api-v1"),
true,
"Parenthesis should be treated as a literal character"
);
assertEquals(
isPathAllowed("(api*", "xapi-v1"),
false,
"Parenthesis should not match other characters"
);
assertEquals(
isPathAllowed("a(b)*", "a(b)c"),
true,
"Parentheses pair should be treated as literal characters"
);
assertEquals(
isPathAllowed("*.png", "image.png"),
true,
"Dot should match a literal dot"
);
assertEquals(
isPathAllowed("*.png", "imageXpng"),
false,
"Dot should not act as a regex wildcard"
);
assertEquals(
isPathAllowed("v1.0*", "v1.0.1"),
true,
"Version-like literal should match itself"
);
assertEquals(
isPathAllowed("v1.0*", "v1x0-beta"),
false,
"Version-like literal should not match arbitrary characters"
);
assertEquals(
isPathAllowed("a+b*", "a+bc"),
true,
"Plus should be treated as a literal character"
);
assertEquals(
isPathAllowed("a+b*", "aaabc"),
false,
"Plus should not act as a regex quantifier"
);
assertEquals(
isPathAllowed("$ref*", "$refs"),
true,
"Dollar sign should be treated as a literal character"
);
assertEquals(
isPathAllowed("price$*", "price$100"),
true,
"Dollar sign mid-pattern should be treated as a literal character"
);
assertEquals(
isPathAllowed("^start*", "^started"),
true,
"Caret should be treated as a literal character"
);
assertEquals(
isPathAllowed("a|b*", "a|bc"),
true,
"Pipe should be treated as a literal character"
);
assertEquals(
isPathAllowed("a|b*", "a"),
false,
"Pipe should not act as regex alternation"
);
assertEquals(
isPathAllowed("file?*", "fileX"),
true,
"Question mark should still act as a single-character wildcard"
);
assertEquals(
isPathAllowed("api/*", "api/" + "x/".repeat(50)),
true,
"Deeply nested paths should still match"
);
console.log("All special character tests passed!");
}
function runRegionTests() { function runRegionTests() {
console.log("\nRunning isIpInRegion tests..."); console.log("\nRunning isIpInRegion tests...");
@@ -413,7 +367,6 @@ function runRegionTests() {
// Run all tests // Run all tests
try { try {
runTests(); runTests();
runSpecialCharacterTests();
runRegionTests(); runRegionTests();
console.log("\n✅ All tests passed!"); console.log("\n✅ All tests passed!");
} catch (error) { } catch (error) {

View File

@@ -25,7 +25,6 @@ import {
} from "@server/db"; } from "@server/db";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { isIpInCidr, stripPortFromHost } from "@server/lib/ip"; import { isIpInCidr, stripPortFromHost } from "@server/lib/ip";
import { isPathAllowed } from "@server/lib/pathMatch";
import { response } from "@server/lib/response"; import { response } from "@server/lib/response";
import logger from "@server/logger"; import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -1091,7 +1090,143 @@ async function checkRules(
return; return;
} }
export { isPathAllowed }; export function isPathAllowed(pattern: string, path: string): boolean {
logger.debug(`\nMatching path "${path}" against pattern "${pattern}"`);
// Normalize and split paths into segments
const normalize = (p: string) => p.split("/").filter(Boolean);
const patternParts = normalize(pattern);
const pathParts = normalize(path);
logger.debug(`Normalized pattern parts: [${patternParts.join(", ")}]`);
logger.debug(`Normalized path parts: [${pathParts.join(", ")}]`);
// Maximum recursion depth to prevent stack overflow and memory issues
const MAX_RECURSION_DEPTH = 100;
// Recursive function to try different wildcard matches
function matchSegments(
patternIndex: number,
pathIndex: number,
depth: number = 0
): boolean {
// Check recursion depth limit
if (depth > MAX_RECURSION_DEPTH) {
logger.warn(
`Path matching exceeded maximum recursion depth (${MAX_RECURSION_DEPTH}) for pattern "${pattern}" and path "${path}"`
);
return false;
}
const indent = " ".repeat(depth); // Indent based on recursion depth
const currentPatternPart = patternParts[patternIndex];
const currentPathPart = pathParts[pathIndex];
logger.debug(
`${indent}Checking patternIndex=${patternIndex} (${currentPatternPart || "END"}) vs pathIndex=${pathIndex} (${currentPathPart || "END"}) [depth=${depth}]`
);
// If we've consumed all pattern parts, we should have consumed all path parts
if (patternIndex >= patternParts.length) {
const result = pathIndex >= pathParts.length;
logger.debug(
`${indent}Reached end of pattern, remaining path: ${pathParts.slice(pathIndex).join("/")} -> ${result}`
);
return result;
}
// If we've consumed all path parts but still have pattern parts
if (pathIndex >= pathParts.length) {
// The only way this can match is if all remaining pattern parts are wildcards
const remainingPattern = patternParts.slice(patternIndex);
const result = remainingPattern.every((p) => p === "*");
logger.debug(
`${indent}Reached end of path, remaining pattern: ${remainingPattern.join("/")} -> ${result}`
);
return result;
}
// For full segment wildcards, try consuming different numbers of path segments
if (currentPatternPart === "*") {
logger.debug(
`${indent}Found wildcard at pattern index ${patternIndex}`
);
// Try consuming 0 segments (skip the wildcard)
logger.debug(
`${indent}Trying to skip wildcard (consume 0 segments)`
);
if (matchSegments(patternIndex + 1, pathIndex, depth + 1)) {
logger.debug(
`${indent}Successfully matched by skipping wildcard`
);
return true;
}
// Try consuming current segment and recursively try rest
logger.debug(
`${indent}Trying to consume segment "${currentPathPart}" for wildcard`
);
if (matchSegments(patternIndex, pathIndex + 1, depth + 1)) {
logger.debug(
`${indent}Successfully matched by consuming segment for wildcard`
);
return true;
}
logger.debug(`${indent}Failed to match wildcard`);
return false;
}
// Check for in-segment wildcard (e.g., "prefix*" or "prefix*suffix")
if (currentPatternPart.includes("*")) {
logger.debug(
`${indent}Found in-segment wildcard in "${currentPatternPart}"`
);
// Convert the pattern segment to a regex pattern
const regexPattern = currentPatternPart
.replace(/\*/g, ".*") // Replace * with .* for regex wildcard
.replace(/\?/g, "."); // Replace ? with . for single character wildcard if needed
const regex = new RegExp(`^${regexPattern}$`);
if (regex.test(currentPathPart)) {
logger.debug(
`${indent}Segment with wildcard matches: "${currentPatternPart}" matches "${currentPathPart}"`
);
return matchSegments(
patternIndex + 1,
pathIndex + 1,
depth + 1
);
}
logger.debug(
`${indent}Segment with wildcard mismatch: "${currentPatternPart}" doesn't match "${currentPathPart}"`
);
return false;
}
// For regular segments, they must match exactly
if (currentPatternPart !== currentPathPart) {
logger.debug(
`${indent}Segment mismatch: "${currentPatternPart}" != "${currentPathPart}"`
);
return false;
}
logger.debug(
`${indent}Segments match: "${currentPatternPart}" = "${currentPathPart}"`
);
// Move to next segments in both pattern and path
return matchSegments(patternIndex + 1, pathIndex + 1, depth + 1);
}
const result = matchSegments(0, 0, 0);
logger.debug(`Final result: ${result}`);
return result;
}
async function isIpInGeoIP( async function isIpInGeoIP(
ipCountryCode: string | undefined, ipCountryCode: string | undefined,

View File

@@ -1,4 +1,4 @@
import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws"; import { sendToClient } from "#dynamic/routers/ws";
import { db, newts, olms } from "@server/db"; import { db, newts, olms } from "@server/db";
import { import {
Alias, Alias,
@@ -8,7 +8,7 @@ import {
} from "@server/lib/ip"; } from "@server/lib/ip";
import { canCompress } from "@server/lib/clientVersionChecks"; import { canCompress } from "@server/lib/clientVersionChecks";
import logger from "@server/logger"; import logger from "@server/logger";
import { eq, inArray } from "drizzle-orm"; import { eq } from "drizzle-orm";
import semver from "semver"; import semver from "semver";
const NEWT_V2_TARGETS_VERSION = ">=1.10.3"; const NEWT_V2_TARGETS_VERSION = ">=1.10.3";
@@ -59,42 +59,6 @@ export async function addTargets(
); );
} }
export async function addTargetsBatch(
entries: {
newtId: string;
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[];
version?: string | null;
}[]
) {
if (entries.length === 0) {
return;
}
const resolved = await Promise.all(
entries.map(async (entry) => ({
...entry,
targets: await convertTargetsIfNecessary(
entry.newtId,
entry.targets
)
}))
);
await sendToClientsBatch(
resolved.map((entry) => ({
clientId: entry.newtId,
message: {
type: `newt/wg/targets/add`,
data: entry.targets
},
options: {
incrementConfigVersion: true,
compress: canCompress(entry.version, "newt")
}
}))
);
}
export async function removeTargets( export async function removeTargets(
newtId: string, newtId: string,
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[], targets: SubnetProxyTarget[] | SubnetProxyTargetV2[],
@@ -112,42 +76,6 @@ export async function removeTargets(
); );
} }
export async function removeTargetsBatch(
entries: {
newtId: string;
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[];
version?: string | null;
}[]
) {
if (entries.length === 0) {
return;
}
const resolved = await Promise.all(
entries.map(async (entry) => ({
...entry,
targets: await convertTargetsIfNecessary(
entry.newtId,
entry.targets
)
}))
);
await sendToClientsBatch(
resolved.map((entry) => ({
clientId: entry.newtId,
message: {
type: `newt/wg/targets/remove`,
data: entry.targets
},
options: {
incrementConfigVersion: true,
compress: canCompress(entry.version, "newt")
}
}))
);
}
export async function updateTargets( export async function updateTargets(
newtId: string, newtId: string,
targets: { targets: {
@@ -273,171 +201,6 @@ export async function removePeerData(
}); });
} }
const resolveOlmTargets = async (
entries: {
clientId: number;
olmId?: string;
version?: string | null;
}[]
) => {
const unresolvedClientIds = entries
.filter((entry) => !entry.olmId)
.map((entry) => entry.clientId);
const olmMap = new Map<number, { olmId: string; version: string | null }>();
if (unresolvedClientIds.length > 0) {
const olmRows = await db
.select({
clientId: olms.clientId,
olmId: olms.olmId,
version: olms.version
})
.from(olms)
.where(inArray(olms.clientId, unresolvedClientIds));
for (const row of olmRows) {
if (row.clientId !== null) {
olmMap.set(row.clientId, {
olmId: row.olmId,
version: row.version
});
}
}
}
return entries
.map((entry) => {
if (entry.olmId) {
return {
clientId: entry.clientId,
olmId: entry.olmId,
version: entry.version
};
}
const resolved = olmMap.get(entry.clientId);
if (!resolved) {
return null;
}
return {
clientId: entry.clientId,
olmId: resolved.olmId,
version: entry.version ?? resolved.version
};
})
.filter((entry) => entry !== null);
};
export async function addPeerDataBatch(
entries: {
clientId: number;
siteId: number;
remoteSubnets: string[];
aliases: Alias[];
olmId?: string;
version?: string | null;
}[]
) {
if (entries.length === 0) {
return;
}
const resolvedTargets = await resolveOlmTargets(entries);
if (resolvedTargets.length === 0) {
return;
}
const payloads = entries
.map((entry) => {
const resolved = resolvedTargets.find(
(target) => target.clientId === entry.clientId
);
if (!resolved) {
return null;
}
return {
clientId: resolved.olmId,
message: {
type: `olm/wg/peer/data/add`,
data: {
siteId: entry.siteId,
remoteSubnets: entry.remoteSubnets,
aliases: entry.aliases
}
},
options: {
incrementConfigVersion: true,
compress: canCompress(resolved.version, "olm")
}
};
})
.filter((entry) => entry !== null);
if (payloads.length === 0) {
return;
}
await sendToClientsBatch(payloads);
}
export async function removePeerDataBatch(
entries: {
clientId: number;
siteId: number;
remoteSubnets: string[];
aliases: Alias[];
olmId?: string;
version?: string | null;
}[]
) {
if (entries.length === 0) {
return;
}
const resolvedTargets = await resolveOlmTargets(entries);
if (resolvedTargets.length === 0) {
return;
}
const payloads = entries
.map((entry) => {
const resolved = resolvedTargets.find(
(target) => target.clientId === entry.clientId
);
if (!resolved) {
return null;
}
return {
clientId: resolved.olmId,
message: {
type: `olm/wg/peer/data/remove`,
data: {
siteId: entry.siteId,
remoteSubnets: entry.remoteSubnets,
aliases: entry.aliases
}
},
options: {
incrementConfigVersion: true,
compress: canCompress(resolved.version, "olm")
}
};
})
.filter((entry) => entry !== null);
if (payloads.length === 0) {
return;
}
await sendToClientsBatch(payloads);
}
export async function updatePeerData( export async function updatePeerData(
clientId: number, clientId: number,
siteId: number, siteId: number,

View File

@@ -10,7 +10,7 @@ import { verifyPassword } from "@server/auth/password";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger"; import logger from "@server/logger";
import { regionalCache as cache } from "#dynamic/lib/cache"; import cache from "#dynamic/lib/cache";
import config from "@server/lib/config"; import config from "@server/lib/config";
// Stale-while-revalidate in-memory fallback for the releases API. // Stale-while-revalidate in-memory fallback for the releases API.

View File

@@ -2,7 +2,7 @@ import { MessageHandler } from "@server/routers/ws";
import logger from "@server/logger"; import logger from "@server/logger";
import { Newt } from "@server/db"; import { Newt } from "@server/db";
import { applyNewtDockerBlueprint } from "@server/lib/blueprints/applyNewtDockerBlueprint"; import { applyNewtDockerBlueprint } from "@server/lib/blueprints/applyNewtDockerBlueprint";
import cache from "#dynamic/lib/cache"; // not using regional here because we dont know where the site is import cache from "#dynamic/lib/cache";
export const handleDockerStatusMessage: MessageHandler = async (context) => { export const handleDockerStatusMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context; const { message, client, sendToClient } = context;

View File

@@ -1,7 +1,7 @@
import { db, Site } from "@server/db"; import { db, Site } from "@server/db";
import { newts, sites } from "@server/db"; import { newts, sites } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws"; import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger"; import logger from "@server/logger";
export async function addPeer( export async function addPeer(
@@ -36,14 +36,10 @@ export async function addPeer(
newtId = newt.newtId; newtId = newt.newtId;
} }
await sendToClient( await sendToClient(newtId, {
newtId,
{
type: "newt/wg/peer/add", type: "newt/wg/peer/add",
data: peer data: peer
}, }, { incrementConfigVersion: true }).catch((error) => {
{ incrementConfigVersion: true }
).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -80,16 +76,12 @@ export async function deletePeer(
newtId = newt.newtId; newtId = newt.newtId;
} }
await sendToClient( await sendToClient(newtId, {
newtId,
{
type: "newt/wg/peer/remove", type: "newt/wg/peer/remove",
data: { data: {
publicKey publicKey
} }
}, }, { incrementConfigVersion: true }).catch((error) => {
{ incrementConfigVersion: true }
).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -98,35 +90,6 @@ export async function deletePeer(
return site; return site;
} }
export async function deletePeersBatch(
peers: {
siteId: number;
publicKey: string;
newtId: string;
}[]
) {
if (peers.length === 0) {
return;
}
await sendToClientsBatch(
peers.map((peer) => ({
clientId: peer.newtId,
message: {
type: "newt/wg/peer/remove",
data: {
publicKey: peer.publicKey
}
},
options: { incrementConfigVersion: true }
}))
).catch((error) => {
logger.warn(`Error sending batched newt peer removals:`, error);
});
logger.info(`Deleted ${peers.length} peer(s) from newts (batch)`);
}
export async function updatePeer( export async function updatePeer(
siteId: number, siteId: number,
publicKey: string, publicKey: string,
@@ -159,17 +122,13 @@ export async function updatePeer(
newtId = newt.newtId; newtId = newt.newtId;
} }
await sendToClient( await sendToClient(newtId, {
newtId,
{
type: "newt/wg/peer/update", type: "newt/wg/peer/update",
data: { data: {
publicKey, publicKey,
...peer ...peer
} }
}, }, { incrementConfigVersion: true }).catch((error) => {
{ incrementConfigVersion: true }
).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });

View File

@@ -20,7 +20,7 @@ import { handleFingerprintInsertion } from "./fingerprintingUtils";
import { build } from "@server/build"; import { build } from "@server/build";
import { canCompress } from "@server/lib/clientVersionChecks"; import { canCompress } from "@server/lib/clientVersionChecks";
import config from "@server/lib/config"; import config from "@server/lib/config";
import cache from "#dynamic/lib/cache"; // not using regional here because we need this in the register message handler before we know where the client is import cache from "#dynamic/lib/cache";
const HOLEPUNCH_STALE_CHAIN_THRESHOLD = 18; const HOLEPUNCH_STALE_CHAIN_THRESHOLD = 18;
const HOLEPUNCH_STALE_CHAIN_TTL_SECONDS = 1800; const HOLEPUNCH_STALE_CHAIN_TTL_SECONDS = 1800;

View File

@@ -1,9 +1,9 @@
import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws"; import { sendToClient } from "#dynamic/routers/ws";
import { clientSitesAssociationsCache, db, olms } from "@server/db"; import { clientSitesAssociationsCache, db, olms } from "@server/db";
import { canCompress } from "@server/lib/clientVersionChecks"; import { canCompress } from "@server/lib/clientVersionChecks";
import config from "@server/lib/config"; import config from "@server/lib/config";
import logger from "@server/logger"; import logger from "@server/logger";
import { and, eq, inArray } from "drizzle-orm"; import { and, eq } from "drizzle-orm";
import { Alias } from "yaml"; import { Alias } from "yaml";
export async function addPeer( export async function addPeer(
@@ -205,150 +205,3 @@ export async function initPeerAddHandshake(
`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}` `Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`
); );
} }
export async function deletePeersBatch(
peers: {
clientId: number;
siteId: number;
publicKey: string;
olmId?: string;
version?: string | null;
}[]
) {
if (peers.length === 0) {
return;
}
const unresolvedClientIds = peers
.filter((peer) => !peer.olmId)
.map((peer) => peer.clientId);
const olmByClientId = new Map<
number,
{ olmId: string; version: string | null }
>();
if (unresolvedClientIds.length > 0) {
const olmRows = await db
.select({
clientId: olms.clientId,
olmId: olms.olmId,
version: olms.version
})
.from(olms)
.where(inArray(olms.clientId, unresolvedClientIds));
for (const row of olmRows) {
if (row.clientId !== null) {
olmByClientId.set(row.clientId, {
olmId: row.olmId,
version: row.version
});
}
}
}
const batchPayloads = peers
.map((peer) => {
const resolved = peer.olmId
? { olmId: peer.olmId, version: peer.version ?? null }
: olmByClientId.get(peer.clientId);
if (!resolved) {
return null;
}
return {
clientId: resolved.olmId,
message: {
type: "olm/wg/peer/remove",
data: {
publicKey: peer.publicKey,
siteId: peer.siteId
}
},
options: {
incrementConfigVersion: true,
compress: canCompress(
peer.version ?? resolved.version,
"olm"
)
}
};
})
.filter((payload) => payload !== null);
if (batchPayloads.length === 0) {
return;
}
await sendToClientsBatch(batchPayloads).catch((error) => {
logger.warn(`Error sending batched olm peer removals:`, error);
});
logger.info(`Deleted ${batchPayloads.length} peer(s) from olms (batch)`);
}
export async function initPeerAddHandshakeBatch(
handshakes: {
clientId: number;
peer: {
siteId: number;
exitNode: {
publicKey: string;
endpoint: string;
};
};
olmId: string;
chainId?: string;
}[]
) {
if (handshakes.length === 0) {
return;
}
await sendToClientsBatch(
handshakes.map((item) => ({
clientId: item.olmId,
message: {
type: "olm/wg/peer/holepunch/site/add",
data: {
siteId: item.peer.siteId,
exitNode: {
publicKey: item.peer.exitNode.publicKey,
relayPort:
config.getRawConfig().gerbil.clients_start_port,
endpoint: item.peer.exitNode.endpoint
},
chainId: item.chainId
}
},
options: { incrementConfigVersion: true }
}))
).catch((error) => {
logger.warn(`Error sending batched olm handshakes:`, error);
});
await Promise.all(
handshakes.map((item) =>
db
.update(clientSitesAssociationsCache)
.set({ isJitMode: false })
.where(
and(
eq(
clientSitesAssociationsCache.clientId,
item.clientId
),
eq(
clientSitesAssociationsCache.siteId,
item.peer.siteId
)
)
)
)
);
logger.info(
`Initiated ${handshakes.length} peer add handshake(s) to olms (batch)`
);
}

View File

@@ -15,7 +15,8 @@ import logger from "@server/logger";
import { z } from "zod"; import { z } from "zod";
import { fromZodError } from "zod-validation-error"; import { fromZodError } from "zod-validation-error";
import type { PaginatedResponse } from "@server/types/Pagination"; import type { PaginatedResponse } from "@server/types/Pagination";
import { regionalCache as cache } from "#dynamic/lib/cache"; import { OpenAPITags, registry } from "@server/openApi";
import { localCache } from "#dynamic/lib/cache";
const USER_RESOURCE_ALIASES_CACHE_TTL_SEC = 60; const USER_RESOURCE_ALIASES_CACHE_TTL_SEC = 60;
@@ -152,7 +153,7 @@ export async function listUserResourceAliases(
pageSize pageSize
); );
const cachedData: ListUserResourceAliasesResponse | undefined = const cachedData: ListUserResourceAliasesResponse | undefined =
await cache.get(cacheKey); localCache.get(cacheKey);
if (cachedData) { if (cachedData) {
return response<ListUserResourceAliasesResponse>(res, { return response<ListUserResourceAliasesResponse>(res, {
@@ -210,11 +211,7 @@ export async function listUserResourceAliases(
page page
} }
}; };
await cache.set( localCache.set(cacheKey, data, USER_RESOURCE_ALIASES_CACHE_TTL_SEC);
cacheKey,
data,
USER_RESOURCE_ALIASES_CACHE_TTL_SEC
);
return response<ListUserResourceAliasesResponse>(res, { return response<ListUserResourceAliasesResponse>(res, {
data, data,
success: true, success: true,
@@ -259,7 +256,7 @@ export async function listUserResourceAliases(
page page
} }
}; };
await cache.set(cacheKey, data, USER_RESOURCE_ALIASES_CACHE_TTL_SEC); localCache.set(cacheKey, data, USER_RESOURCE_ALIASES_CACHE_TTL_SEC);
return response<ListUserResourceAliasesResponse>(res, { return response<ListUserResourceAliasesResponse>(res, {
data, data,

View File

@@ -14,7 +14,7 @@ import {
siteLabels, siteLabels,
type Label type Label
} from "@server/db"; } from "@server/db";
import { regionalCache as cache } from "#dynamic/lib/cache"; import cache from "#dynamic/lib/cache";
import response from "@server/lib/response"; import response from "@server/lib/response";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";

View File

@@ -28,10 +28,7 @@ import {
isIpInCidr, isIpInCidr,
portRangeStringSchema portRangeStringSchema
} from "@server/lib/ip"; } from "@server/lib/ip";
import { import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations";
getClientSiteResourceAccess,
rebuildClientAssociationsFromSiteResource
} from "@server/lib/rebuildClientAssociations";
import logger from "@server/logger"; import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
@@ -849,17 +846,12 @@ export async function handleMessagingForUpdatedSiteResource(
updatedSiteResource updatedSiteResource
); );
const { mergedAllClients } =
await rebuildClientAssociationsFromSiteResource( await rebuildClientAssociationsFromSiteResource(
existingSiteResource || updatedSiteResource, // we want to rebuild based on the existing resource then we will apply the change to the destination below existingSiteResource || updatedSiteResource, // we want to rebuild based on the existing resource then we will apply the change to the destination below
trx trx
); );
const { sitesList, mergedAllClients, mergedAllClientIds } =
await getClientSiteResourceAccess(
existingSiteResource || updatedSiteResource,
trx
);
// after everything is rebuilt above we still need to update the targets and remote subnets if the destination changed // after everything is rebuilt above we still need to update the targets and remote subnets if the destination changed
const destinationChanged = const destinationChanged =
existingSiteResource && existingSiteResource &&

View File

@@ -76,32 +76,12 @@ export interface SendMessageOptions {
compress?: boolean; compress?: boolean;
} }
export interface BatchSendMessage { // Redis message type for cross-node communication
clientId: string; export interface RedisMessage {
message: WSMessage; type: "direct" | "broadcast";
options?: SendMessageOptions; targetClientId?: string;
}
// Redis message types for cross-node communication
export type RedisMessage =
| {
type: "direct";
targetClientId: string;
message: WSMessage;
fromNodeId: string;
}
| {
type: "direct-batch";
messages: {
targetClientId: string;
message: WSMessage;
}[];
fromNodeId: string;
}
| {
type: "broadcast";
excludeClientId?: string; excludeClientId?: string;
message: WSMessage; message: WSMessage;
fromNodeId: string; fromNodeId: string;
options?: SendMessageOptions; options?: SendMessageOptions;
}; }

View File

@@ -26,8 +26,7 @@ import {
WebSocketRequest, WebSocketRequest,
WSMessage, WSMessage,
AuthenticatedWebSocket, AuthenticatedWebSocket,
SendMessageOptions, SendMessageOptions
BatchSendMessage
} from "./types"; } from "./types";
import { validateSessionToken } from "@server/auth/sessions/app"; import { validateSessionToken } from "@server/auth/sessions/app";
@@ -213,20 +212,6 @@ const sendToClient = async (
return localSent; return localSent;
}; };
const sendToClientsBatch = async (
entries: BatchSendMessage[]
): Promise<void> => {
if (entries.length === 0) {
return;
}
await Promise.all(
entries.map((entry) =>
sendToClient(entry.clientId, entry.message, entry.options)
)
);
};
const broadcastToAllExcept = async ( const broadcastToAllExcept = async (
message: WSMessage, message: WSMessage,
excludeClientId?: string, excludeClientId?: string,
@@ -567,7 +552,6 @@ export {
router, router,
handleWSUpgrade, handleWSUpgrade,
sendToClient, sendToClient,
sendToClientsBatch,
broadcastToAllExcept, broadcastToAllExcept,
connectedClients, connectedClients,
hasActiveConnections, hasActiveConnections,