diff --git a/server/db/pg/schema/privateSchema.ts b/server/db/pg/schema/privateSchema.ts index 229fc9ff0..ae73b97ac 100644 --- a/server/db/pg/schema/privateSchema.ts +++ b/server/db/pg/schema/privateSchema.ts @@ -11,7 +11,7 @@ import { primaryKey, uniqueIndex } from "drizzle-orm/pg-core"; -import { InferSelectModel } from "drizzle-orm"; +import { InferSelectModel, sql } from "drizzle-orm"; import { domains, orgs, @@ -207,17 +207,28 @@ export const remoteExitNodeSessions = pgTable("remoteExitNodeSession", { expiresAt: bigint("expiresAt", { mode: "number" }).notNull() }); -export const loginPage = pgTable("loginPage", { - loginPageId: serial("loginPageId").primaryKey(), - subdomain: varchar("subdomain"), - fullDomain: varchar("fullDomain"), - exitNodeId: integer("exitNodeId").references(() => exitNodes.exitNodeId, { - onDelete: "set null" - }), - domainId: varchar("domainId").references(() => domains.domainId, { - onDelete: "set null" - }) -}); +export const loginPage = pgTable( + "loginPage", + { + loginPageId: serial("loginPageId").primaryKey(), + subdomain: varchar("subdomain"), + fullDomain: varchar("fullDomain"), + exitNodeId: integer("exitNodeId").references( + () => exitNodes.exitNodeId, + { + onDelete: "set null" + } + ), + domainId: varchar("domainId").references(() => domains.domainId, { + onDelete: "set null" + }) + }, + (t) => [ + index("idx_loginpage_fulldomain") + .on(t.fullDomain) + .where(sql`${t.fullDomain} IS NOT NULL`) + ] +); export const loginPageOrg = pgTable("loginPageOrg", { loginPageId: integer("loginPageId") diff --git a/server/db/pg/schema/schema.ts b/server/db/pg/schema/schema.ts index 025bdf923..1b48aa520 100644 --- a/server/db/pg/schema/schema.ts +++ b/server/db/pg/schema/schema.ts @@ -1,5 +1,5 @@ import { randomUUID } from "crypto"; -import { InferSelectModel } from "drizzle-orm"; +import { InferSelectModel, sql } from "drizzle-orm"; import { bigint, boolean, @@ -82,107 +82,130 @@ export const orgDomains = pgTable("orgDomains", { .references(() => domains.domainId, { onDelete: "cascade" }) }); -export const sites = pgTable("sites", { - siteId: serial("siteId").primaryKey(), - orgId: varchar("orgId") - .references(() => orgs.orgId, { - onDelete: "cascade" - }) - .notNull(), - niceId: varchar("niceId").notNull(), - exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, { - onDelete: "set null" - }), - name: varchar("name").notNull(), - pubKey: varchar("pubKey"), - subnet: varchar("subnet"), - megabytesIn: real("bytesIn").default(0), - megabytesOut: real("bytesOut").default(0), - lastBandwidthUpdate: varchar("lastBandwidthUpdate"), - type: varchar("type").notNull(), // "newt" or "wireguard" - online: boolean("online").notNull().default(false), - lastPing: integer("lastPing"), - address: varchar("address"), - endpoint: varchar("endpoint"), - publicKey: varchar("publicKey"), - lastHolePunch: bigint("lastHolePunch", { mode: "number" }), - listenPort: integer("listenPort"), - dockerSocketEnabled: boolean("dockerSocketEnabled").notNull().default(true), - autoUpdateEnabled: boolean("autoUpdateEnabled").notNull().default(false), - autoUpdateOverrideOrg: boolean("autoUpdateOverrideOrg") - .notNull() - .default(false), - status: varchar("status") - .$type<"pending" | "approved">() - .default("approved") -}); +export const sites = pgTable( + "sites", + { + siteId: serial("siteId").primaryKey(), + orgId: varchar("orgId") + .references(() => orgs.orgId, { + onDelete: "cascade" + }) + .notNull(), + niceId: varchar("niceId").notNull(), + exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, { + onDelete: "set null" + }), + name: varchar("name").notNull(), + pubKey: varchar("pubKey"), + subnet: varchar("subnet"), + megabytesIn: real("bytesIn").default(0), + megabytesOut: real("bytesOut").default(0), + lastBandwidthUpdate: varchar("lastBandwidthUpdate"), + type: varchar("type").notNull(), // "newt" or "wireguard" + online: boolean("online").notNull().default(false), + lastPing: integer("lastPing"), + address: varchar("address"), + endpoint: varchar("endpoint"), + publicKey: varchar("publicKey"), + lastHolePunch: bigint("lastHolePunch", { mode: "number" }), + listenPort: integer("listenPort"), + dockerSocketEnabled: boolean("dockerSocketEnabled") + .notNull() + .default(true), + autoUpdateEnabled: boolean("autoUpdateEnabled") + .notNull() + .default(false), + autoUpdateOverrideOrg: boolean("autoUpdateOverrideOrg") + .notNull() + .default(false), + status: varchar("status") + .$type<"pending" | "approved">() + .default("approved") + }, + (t) => [ + index("idx_sites_exitnodeid").on(t.exitNodeId), + index("idx_sites_exitnode_type_siteid").on( + t.exitNodeId, + t.type, + t.siteId + ) + ] +); -export const resources = pgTable("resources", { - resourceId: serial("resourceId").primaryKey(), - resourcePolicyId: integer("resourcePolicyId").references( - () => resourcePolicies.resourcePolicyId, - { onDelete: "set null" } - ), - defaultResourcePolicyId: integer("defaultResourcePolicyId").references( - () => resourcePolicies.resourcePolicyId, - { - onDelete: "restrict" - } - ), - resourceGuid: varchar("resourceGuid", { length: 36 }) - .unique() - .notNull() - .$defaultFn(() => randomUUID()), - orgId: varchar("orgId") - .references(() => orgs.orgId, { - onDelete: "cascade" - }) - .notNull(), - niceId: text("niceId").notNull(), - name: varchar("name").notNull(), - subdomain: varchar("subdomain"), - fullDomain: varchar("fullDomain"), - domainId: varchar("domainId").references(() => domains.domainId, { - onDelete: "set null" - }), - ssl: boolean("ssl").notNull().default(false), - blockAccess: boolean("blockAccess").notNull().default(false), - proxyPort: integer("proxyPort"), - sso: boolean("sso"), - emailWhitelistEnabled: boolean("emailWhitelistEnabled"), - applyRules: boolean("applyRules"), - enabled: boolean("enabled").notNull().default(true), - stickySession: boolean("stickySession").notNull().default(false), - tlsServerName: varchar("tlsServerName"), - setHostHeader: varchar("setHostHeader"), - enableProxy: boolean("enableProxy").default(true), - skipToIdpId: integer("skipToIdpId").references(() => idp.idpId, { - onDelete: "set null" - }), - headers: text("headers"), // comma-separated list of headers to add to the request - proxyProtocol: boolean("proxyProtocol").notNull().default(false), - proxyProtocolVersion: integer("proxyProtocolVersion").default(1), - maintenanceModeEnabled: boolean("maintenanceModeEnabled") - .notNull() - .default(false), - maintenanceModeType: text("maintenanceModeType", { - enum: ["forced", "automatic"] - }).default("forced"), // "forced" = always show, "automatic" = only when down - maintenanceTitle: text("maintenanceTitle"), - maintenanceMessage: text("maintenanceMessage"), - maintenanceEstimatedTime: text("maintenanceEstimatedTime"), - postAuthPath: text("postAuthPath"), - health: varchar("health").default("unknown"), // "healthy", "unhealthy", "unknown" - wildcard: boolean("wildcard").notNull().default(false), - mode: text("mode").default("http").notNull(), // rdp, ssh, http, vnc - pamMode: varchar("pamMode", { length: 32 }) - .$type<"passthrough" | "push">() - .default("passthrough"), - authDaemonMode: varchar("authDaemonMode", { length: 32 }) - .$type<"site" | "remote" | "native">() - .default("site"), - authDaemonPort: integer("authDaemonPort").default(22123) -}); +export const resources = pgTable( + "resources", + { + resourceId: serial("resourceId").primaryKey(), + resourcePolicyId: integer("resourcePolicyId").references( + () => resourcePolicies.resourcePolicyId, + { onDelete: "set null" } + ), + defaultResourcePolicyId: integer("defaultResourcePolicyId").references( + () => resourcePolicies.resourcePolicyId, + { + onDelete: "restrict" + } + ), + resourceGuid: varchar("resourceGuid", { length: 36 }) + .unique() + .notNull() + .$defaultFn(() => randomUUID()), + orgId: varchar("orgId") + .references(() => orgs.orgId, { + onDelete: "cascade" + }) + .notNull(), + niceId: text("niceId").notNull(), + name: varchar("name").notNull(), + subdomain: varchar("subdomain"), + fullDomain: varchar("fullDomain"), + domainId: varchar("domainId").references(() => domains.domainId, { + onDelete: "set null" + }), + ssl: boolean("ssl").notNull().default(false), + blockAccess: boolean("blockAccess").notNull().default(false), + proxyPort: integer("proxyPort"), + sso: boolean("sso"), + emailWhitelistEnabled: boolean("emailWhitelistEnabled"), + applyRules: boolean("applyRules"), + enabled: boolean("enabled").notNull().default(true), + stickySession: boolean("stickySession").notNull().default(false), + tlsServerName: varchar("tlsServerName"), + setHostHeader: varchar("setHostHeader"), + enableProxy: boolean("enableProxy").default(true), + skipToIdpId: integer("skipToIdpId").references(() => idp.idpId, { + onDelete: "set null" + }), + headers: text("headers"), // comma-separated list of headers to add to the request + proxyProtocol: boolean("proxyProtocol").notNull().default(false), + proxyProtocolVersion: integer("proxyProtocolVersion").default(1), + maintenanceModeEnabled: boolean("maintenanceModeEnabled") + .notNull() + .default(false), + maintenanceModeType: text("maintenanceModeType", { + enum: ["forced", "automatic"] + }).default("forced"), // "forced" = always show, "automatic" = only when down + maintenanceTitle: text("maintenanceTitle"), + maintenanceMessage: text("maintenanceMessage"), + maintenanceEstimatedTime: text("maintenanceEstimatedTime"), + postAuthPath: text("postAuthPath"), + health: varchar("health").default("unknown"), // "healthy", "unhealthy", "unknown" + wildcard: boolean("wildcard").notNull().default(false), + mode: text("mode").default("http").notNull(), // rdp, ssh, http, vnc + pamMode: varchar("pamMode", { length: 32 }) + .$type<"passthrough" | "push">() + .default("passthrough"), + authDaemonMode: varchar("authDaemonMode", { length: 32 }) + .$type<"site" | "remote" | "native">() + .default("site"), + authDaemonPort: integer("authDaemonPort").default(22123) + }, + (t) => [ + index("idx_resources_fulldomain") + .on(t.fullDomain) + .where(sql`${t.fullDomain} IS NOT NULL`) + ] +); export const labels = pgTable("labels", { labelId: serial("labelId").primaryKey(), @@ -267,71 +290,84 @@ export const clientLabels = pgTable( (t) => [unique("client_label_uniq").on(t.clientId, t.labelId)] ); -export const targets = pgTable("targets", { - targetId: serial("targetId").primaryKey(), - resourceId: integer("resourceId") - .references(() => resources.resourceId, { - onDelete: "cascade" - }) - .notNull(), - siteId: integer("siteId") - .references(() => sites.siteId, { - onDelete: "cascade" - }) - .notNull(), - ip: varchar("ip").notNull(), - method: varchar("method"), - port: integer("port").notNull(), - internalPort: integer("internalPort"), - enabled: boolean("enabled").notNull().default(true), - path: text("path"), - pathMatchType: text("pathMatchType"), // exact, prefix, regex - rewritePath: text("rewritePath"), // if set, rewrites the path to this value before sending to the target - rewritePathType: text("rewritePathType"), // exact, prefix, regex, stripPrefix - priority: integer("priority").notNull().default(100), - mode: varchar("mode") - .$type<"http" | "tcp" | "udp" | "ssh" | "rdp" | "vnc">() - .notNull() - .default("http"), - authToken: varchar("authToken") -}); +export const targets = pgTable( + "targets", + { + targetId: serial("targetId").primaryKey(), + resourceId: integer("resourceId") + .references(() => resources.resourceId, { + onDelete: "cascade" + }) + .notNull(), + siteId: integer("siteId") + .references(() => sites.siteId, { + onDelete: "cascade" + }) + .notNull(), + ip: varchar("ip").notNull(), + method: varchar("method"), + port: integer("port").notNull(), + internalPort: integer("internalPort"), + enabled: boolean("enabled").notNull().default(true), + path: text("path"), + pathMatchType: text("pathMatchType"), // exact, prefix, regex + rewritePath: text("rewritePath"), // if set, rewrites the path to this value before sending to the target + rewritePathType: text("rewritePathType"), // exact, prefix, regex, stripPrefix + priority: integer("priority").notNull().default(100), + mode: varchar("mode") + .$type<"http" | "tcp" | "udp" | "ssh" | "rdp" | "vnc">() + .notNull() + .default("http"), + authToken: varchar("authToken") + }, + (t) => [ + index("idx_targets_resourceid_siteid").on(t.resourceId, t.siteId), + index("idx_targets_site_enabled_priority_target_resource") + .on(t.siteId, t.priority.desc(), t.targetId, t.resourceId) + .where(sql`${t.enabled} = true`) + ] +); -export const targetHealthCheck = pgTable("targetHealthCheck", { - targetHealthCheckId: serial("targetHealthCheckId").primaryKey(), - targetId: integer("targetId").references(() => targets.targetId, { - onDelete: "cascade" - }), - orgId: varchar("orgId") - .references(() => orgs.orgId, { +export const targetHealthCheck = pgTable( + "targetHealthCheck", + { + targetHealthCheckId: serial("targetHealthCheckId").primaryKey(), + targetId: integer("targetId").references(() => targets.targetId, { onDelete: "cascade" - }) - .notNull(), - siteId: integer("siteId") - .references(() => sites.siteId, { - onDelete: "cascade" - }) - .notNull(), - name: varchar("name"), - hcEnabled: boolean("hcEnabled").notNull().default(false), - hcPath: varchar("hcPath"), - hcScheme: varchar("hcScheme"), - hcMode: varchar("hcMode").default("http"), - hcHostname: varchar("hcHostname"), - hcPort: integer("hcPort"), - hcInterval: integer("hcInterval").default(30), // in seconds - hcUnhealthyInterval: integer("hcUnhealthyInterval").default(30), // in seconds - hcTimeout: integer("hcTimeout").default(5), // in seconds - hcHeaders: varchar("hcHeaders"), - hcFollowRedirects: boolean("hcFollowRedirects").default(true), - hcMethod: varchar("hcMethod").default("GET"), - hcStatus: integer("hcStatus"), // http code - hcHealth: text("hcHealth") - .$type<"unknown" | "healthy" | "unhealthy">() - .default("unknown"), // "unknown", "healthy", "unhealthy" - hcTlsServerName: text("hcTlsServerName"), - hcHealthyThreshold: integer("hcHealthyThreshold").default(1), - hcUnhealthyThreshold: integer("hcUnhealthyThreshold").default(1) -}); + }), + orgId: varchar("orgId") + .references(() => orgs.orgId, { + onDelete: "cascade" + }) + .notNull(), + siteId: integer("siteId") + .references(() => sites.siteId, { + onDelete: "cascade" + }) + .notNull(), + name: varchar("name"), + hcEnabled: boolean("hcEnabled").notNull().default(false), + hcPath: varchar("hcPath"), + hcScheme: varchar("hcScheme"), + hcMode: varchar("hcMode").default("http"), + hcHostname: varchar("hcHostname"), + hcPort: integer("hcPort"), + hcInterval: integer("hcInterval").default(30), // in seconds + hcUnhealthyInterval: integer("hcUnhealthyInterval").default(30), // in seconds + hcTimeout: integer("hcTimeout").default(5), // in seconds + hcHeaders: varchar("hcHeaders"), + hcFollowRedirects: boolean("hcFollowRedirects").default(true), + hcMethod: varchar("hcMethod").default("GET"), + hcStatus: integer("hcStatus"), // http code + hcHealth: text("hcHealth") + .$type<"unknown" | "healthy" | "unhealthy">() + .default("unknown"), // "unknown", "healthy", "unhealthy" + hcTlsServerName: text("hcTlsServerName"), + hcHealthyThreshold: integer("hcHealthyThreshold").default(1), + hcUnhealthyThreshold: integer("hcUnhealthyThreshold").default(1) + }, + (t) => [index("idx_targethealthcheck_targetid").on(t.targetId)] +); export const exitNodes = pgTable("exitNodes", { exitNodeId: serial("exitNodeId").primaryKey(), @@ -406,43 +442,74 @@ export const networks = pgTable("networks", { .notNull() }); -export const siteNetworks = pgTable("siteNetworks", { - siteId: integer("siteId") - .notNull() - .references(() => sites.siteId, { - onDelete: "cascade" - }), - networkId: integer("networkId") - .notNull() - .references(() => networks.networkId, { onDelete: "cascade" }) -}); +export const siteNetworks = pgTable( + "siteNetworks", + { + siteId: integer("siteId") + .notNull() + .references(() => sites.siteId, { + onDelete: "cascade" + }), + networkId: integer("networkId") + .notNull() + .references(() => networks.networkId, { onDelete: "cascade" }) + }, + (t) => [ + index("idx_sitenetworks_siteid").on(t.siteId), + index("idx_sitenetworks_networkid").on(t.networkId) + ] +); -export const clientSiteResources = pgTable("clientSiteResources", { - clientId: integer("clientId") - .notNull() - .references(() => clients.clientId, { onDelete: "cascade" }), - siteResourceId: integer("siteResourceId") - .notNull() - .references(() => siteResources.siteResourceId, { onDelete: "cascade" }) -}); +export const clientSiteResources = pgTable( + "clientSiteResources", + { + clientId: integer("clientId") + .notNull() + .references(() => clients.clientId, { onDelete: "cascade" }), + siteResourceId: integer("siteResourceId") + .notNull() + .references(() => siteResources.siteResourceId, { + onDelete: "cascade" + }) + }, + (t) => [ + index("idx_clientsiteresources_clientid").on(t.clientId), + index("idx_clientsiteresources_siteresourceid").on(t.siteResourceId) + ] +); -export const roleSiteResources = pgTable("roleSiteResources", { - roleId: integer("roleId") - .notNull() - .references(() => roles.roleId, { onDelete: "cascade" }), - siteResourceId: integer("siteResourceId") - .notNull() - .references(() => siteResources.siteResourceId, { onDelete: "cascade" }) -}); +export const roleSiteResources = pgTable( + "roleSiteResources", + { + roleId: integer("roleId") + .notNull() + .references(() => roles.roleId, { onDelete: "cascade" }), + siteResourceId: integer("siteResourceId") + .notNull() + .references(() => siteResources.siteResourceId, { + onDelete: "cascade" + }) + }, + (t) => [index("idx_rolesiteresources_siteresourceid").on(t.siteResourceId)] +); -export const userSiteResources = pgTable("userSiteResources", { - userId: varchar("userId") - .notNull() - .references(() => users.userId, { onDelete: "cascade" }), - siteResourceId: integer("siteResourceId") - .notNull() - .references(() => siteResources.siteResourceId, { onDelete: "cascade" }) -}); +export const userSiteResources = pgTable( + "userSiteResources", + { + userId: varchar("userId") + .notNull() + .references(() => users.userId, { onDelete: "cascade" }), + siteResourceId: integer("siteResourceId") + .notNull() + .references(() => siteResources.siteResourceId, { + onDelete: "cascade" + }) + }, + (t) => [ + index("idx_usersiteresources_userid").on(t.userId), + index("idx_usersiteresources_siteresourceid").on(t.siteResourceId) + ] +); export const users = pgTable("user", { userId: varchar("id").primaryKey(), @@ -467,15 +534,19 @@ export const users = pgTable("user", { locale: varchar("locale") }); -export const newts = pgTable("newt", { - newtId: varchar("id").primaryKey(), - secretHash: varchar("secretHash").notNull(), - dateCreated: varchar("dateCreated").notNull(), - version: varchar("version"), - siteId: integer("siteId").references(() => sites.siteId, { - onDelete: "cascade" - }) -}); +export const newts = pgTable( + "newt", + { + newtId: varchar("id").primaryKey(), + secretHash: varchar("secretHash").notNull(), + dateCreated: varchar("dateCreated").notNull(), + version: varchar("version"), + siteId: integer("siteId").references(() => sites.siteId, { + onDelete: "cascade" + }) + }, + (t) => [index("idx_newt_siteid").on(t.siteId)] +); export const twoFactorBackupCodes = pgTable("twoFactorBackupCodes", { codeId: serial("id").primaryKey(), @@ -576,29 +647,49 @@ export const userOrgRoles = pgTable( (t) => [unique().on(t.userId, t.orgId, t.roleId)] ); -export const roleActions = pgTable("roleActions", { - roleId: integer("roleId") - .notNull() - .references(() => roles.roleId, { onDelete: "cascade" }), - actionId: varchar("actionId") - .notNull() - .references(() => actions.actionId, { onDelete: "cascade" }), - orgId: varchar("orgId") - .notNull() - .references(() => orgs.orgId, { onDelete: "cascade" }) -}); +export const roleActions = pgTable( + "roleActions", + { + roleId: integer("roleId") + .notNull() + .references(() => roles.roleId, { onDelete: "cascade" }), + actionId: varchar("actionId") + .notNull() + .references(() => actions.actionId, { onDelete: "cascade" }), + orgId: varchar("orgId") + .notNull() + .references(() => orgs.orgId, { onDelete: "cascade" }) + }, + (t) => [ + index("idx_roleActions_roleId_orgId_actionId").on( + t.roleId, + t.orgId, + t.actionId + ) + ] +); -export const userActions = pgTable("userActions", { - userId: varchar("userId") - .notNull() - .references(() => users.userId, { onDelete: "cascade" }), - actionId: varchar("actionId") - .notNull() - .references(() => actions.actionId, { onDelete: "cascade" }), - orgId: varchar("orgId") - .notNull() - .references(() => orgs.orgId, { onDelete: "cascade" }) -}); +export const userActions = pgTable( + "userActions", + { + userId: varchar("userId") + .notNull() + .references(() => users.userId, { onDelete: "cascade" }), + actionId: varchar("actionId") + .notNull() + .references(() => actions.actionId, { onDelete: "cascade" }), + orgId: varchar("orgId") + .notNull() + .references(() => orgs.orgId, { onDelete: "cascade" }) + }, + (t) => [ + index("idx_userActions_userId_orgId_actionId").on( + t.userId, + t.orgId, + t.actionId + ) + ] +); export const roleSites = pgTable("roleSites", { roleId: integer("roleId") @@ -1004,40 +1095,44 @@ export const idpOrg = pgTable("idpOrg", { orgMapping: varchar("orgMapping") }); -export const clients = pgTable("clients", { - clientId: serial("clientId").primaryKey(), - orgId: varchar("orgId") - .references(() => orgs.orgId, { +export const clients = pgTable( + "clients", + { + clientId: serial("clientId").primaryKey(), + orgId: varchar("orgId") + .references(() => orgs.orgId, { + onDelete: "cascade" + }) + .notNull(), + exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, { + onDelete: "set null" + }), + userId: text("userId").references(() => users.userId, { + // optionally tied to a user and in this case delete when the user deletes onDelete: "cascade" - }) - .notNull(), - exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, { - onDelete: "set null" - }), - userId: text("userId").references(() => users.userId, { - // optionally tied to a user and in this case delete when the user deletes - onDelete: "cascade" - }), - niceId: varchar("niceId").notNull(), - olmId: text("olmId"), // to lock it to a specific olm optionally - name: varchar("name").notNull(), - pubKey: varchar("pubKey"), - subnet: varchar("subnet").notNull(), - megabytesIn: real("bytesIn"), - megabytesOut: real("bytesOut"), - lastBandwidthUpdate: varchar("lastBandwidthUpdate"), - lastPing: integer("lastPing"), - type: varchar("type").notNull(), // "olm" - online: boolean("online").notNull().default(false), - // endpoint: varchar("endpoint"), - lastHolePunch: integer("lastHolePunch"), - maxConnections: integer("maxConnections"), - archived: boolean("archived").notNull().default(false), - blocked: boolean("blocked").notNull().default(false), - approvalState: varchar("approvalState").$type< - "pending" | "approved" | "denied" - >() -}); + }), + niceId: varchar("niceId").notNull(), + olmId: text("olmId"), // to lock it to a specific olm optionally + name: varchar("name").notNull(), + pubKey: varchar("pubKey"), + subnet: varchar("subnet").notNull(), + megabytesIn: real("bytesIn"), + megabytesOut: real("bytesOut"), + lastBandwidthUpdate: varchar("lastBandwidthUpdate"), + lastPing: integer("lastPing"), + type: varchar("type").notNull(), // "olm" + online: boolean("online").notNull().default(false), + // endpoint: varchar("endpoint"), + lastHolePunch: integer("lastHolePunch"), + maxConnections: integer("maxConnections"), + archived: boolean("archived").notNull().default(false), + blocked: boolean("blocked").notNull().default(false), + approvalState: varchar("approvalState").$type< + "pending" | "approved" | "denied" + >() + }, + (t) => [index("idx_clients_userid").on(t.userId)] +); export const clientSitesAssociationsCache = pgTable( "clientSitesAssociationsCache", @@ -1049,7 +1144,11 @@ export const clientSitesAssociationsCache = pgTable( isJitMode: boolean("isJitMode").notNull().default(false), endpoint: varchar("endpoint"), publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes - } + }, + (t) => [ + primaryKey({ columns: [t.clientId, t.siteId] }), + index("idx_clientsitesassociationscache_siteid").on(t.siteId) + ] ); export const clientSiteResourcesAssociationsCache = pgTable( @@ -1058,7 +1157,14 @@ export const clientSiteResourcesAssociationsCache = pgTable( clientId: integer("clientId") // not a foreign key here so after its deleted the rebuild function can delete it and send the message .notNull(), siteResourceId: integer("siteResourceId").notNull() - } + }, + (t) => [ + primaryKey({ columns: [t.clientId, t.siteResourceId] }), + index("idx_clientSiteResourcesAssociationsCache_siteResourceId").on( + t.siteResourceId, + t.clientId + ) + ] ); export const clientPostureSnapshots = pgTable("clientPostureSnapshots", { @@ -1071,23 +1177,27 @@ export const clientPostureSnapshots = pgTable("clientPostureSnapshots", { collectedAt: integer("collectedAt").notNull() }); -export const olms = pgTable("olms", { - olmId: varchar("id").primaryKey(), - secretHash: varchar("secretHash").notNull(), - dateCreated: varchar("dateCreated").notNull(), - version: text("version"), - agent: text("agent"), - name: varchar("name"), - clientId: integer("clientId").references(() => clients.clientId, { - // we will switch this depending on the current org it wants to connect to - onDelete: "set null" - }), - userId: text("userId").references(() => users.userId, { - // optionally tied to a user and in this case delete when the user deletes - onDelete: "cascade" - }), - archived: boolean("archived").notNull().default(false) -}); +export const olms = pgTable( + "olms", + { + olmId: varchar("id").primaryKey(), + secretHash: varchar("secretHash").notNull(), + dateCreated: varchar("dateCreated").notNull(), + version: text("version"), + agent: text("agent"), + name: varchar("name"), + clientId: integer("clientId").references(() => clients.clientId, { + // we will switch this depending on the current org it wants to connect to + onDelete: "set null" + }), + userId: text("userId").references(() => users.userId, { + // optionally tied to a user and in this case delete when the user deletes + onDelete: "cascade" + }), + archived: boolean("archived").notNull().default(false) + }, + (t) => [index("idx_olms_clientid").on(t.clientId)] +); export const currentFingerprint = pgTable("currentFingerprint", { fingerprintId: serial("id").primaryKey(), diff --git a/server/index.ts b/server/index.ts index 99fd20156..53b3e9a69 100644 --- a/server/index.ts +++ b/server/index.ts @@ -24,6 +24,7 @@ import license from "#dynamic/license/license"; import { initLogCleanupInterval } from "@server/lib/cleanupLogs"; import { initAcmeCertSync } from "#dynamic/lib/acmeCertSync"; import { fetchServerIp } from "@server/lib/serverIpService"; +import { startRebuildQueueProcessor } from "@server/lib/rebuildClientAssociations"; async function startServers() { await setHostMeta(); @@ -41,6 +42,7 @@ async function startServers() { initLogCleanupInterval(); initAcmeCertSync(); + startRebuildQueueProcessor(); // Start all servers const apiServer = createApiServer(); diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index 4efc72476..7f271bbe5 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -8,6 +8,7 @@ import { exitNodes, newts, olms, + primaryDb, roleSiteResources, Site, SiteResource, @@ -20,10 +21,10 @@ import { } from "@server/db"; import { and, count, eq, inArray, ne } from "drizzle-orm"; -import { deletePeer as newtDeletePeer } from "@server/routers/newt/peers"; +import { deletePeersBatch as newtDeletePeersBatch } from "@server/routers/newt/peers"; import { - initPeerAddHandshake, - deletePeer as olmDeletePeer + initPeerAddHandshakeBatch, + deletePeersBatch as olmDeletePeersBatch } from "@server/routers/olm/peers"; import { sendToExitNode } from "#dynamic/lib/exitNodes"; import logger from "@server/logger"; @@ -34,12 +35,13 @@ import { parseEndpoint } from "@server/lib/ip"; import { - addPeerData, - addTargets as addSubnetProxyTargets, - removePeerData, - removeTargets as removeSubnetProxyTargets + addPeerDataBatch, + addTargetsBatch as addSubnetProxyTargetsBatch, + removePeerDataBatch, + removeTargetsBatch as removeSubnetProxyTargetsBatch } from "@server/routers/client/targets"; import { lockManager } from "#dynamic/lib/lock"; +import { rebuildQueue } from "#dynamic/lib/rebuildQueue"; // TTL for rebuild-association locks. These functions can fan out into many // peer/proxy updates, so give them a generous window. @@ -160,18 +162,33 @@ export async function getClientSiteResourceAccess( export async function rebuildClientAssociationsFromSiteResource( siteResource: SiteResource, trx: Transaction | typeof db = db -): Promise<{ - mergedAllClients: { - clientId: number; - pubKey: string | null; - subnet: string | null; - }[]; -}> { - return await lockManager.withLock( - `rebuild-client-associations:site-resource:${siteResource.siteResourceId}`, - () => rebuildClientAssociationsFromSiteResourceImpl(siteResource, trx), - REBUILD_ASSOCIATIONS_LOCK_TTL_MS - ); +) { + try { + return await lockManager.withLock( + `rebuild-client-associations:site-resource:${siteResource.siteResourceId}`, + () => + rebuildClientAssociationsFromSiteResourceImpl( + siteResource, + trx + ), + REBUILD_ASSOCIATIONS_LOCK_TTL_MS + ); + } catch (err: any) { + if ( + typeof err?.message === "string" && + err.message.startsWith("Failed to acquire lock") + ) { + logger.warn( + `rebuildClientAssociations: could not acquire lock for site resource ${siteResource.siteResourceId}, queuing for deferred processing` + ); + await rebuildQueue.enqueue({ + type: "site-resource", + id: siteResource.siteResourceId + }); + return { mergedAllClients: [] }; + } + throw err; + } } async function rebuildClientAssociationsFromSiteResourceImpl( @@ -536,6 +553,28 @@ async function handleMessagesForSiteClients( const newtJobs: Promise[] = []; const olmJobs: Promise[] = []; const exitNodeJobs: Promise[] = []; + const newtPeerDeletes: { + siteId: number; + publicKey: string; + newtId: string; + }[] = []; + const olmPeerDeletes: { + clientId: number; + siteId: number; + publicKey: string; + olmId: string; + }[] = []; + const olmPeerAddHandshakes: { + clientId: number; + peer: { + siteId: number; + exitNode: { + publicKey: string; + endpoint: string; + }; + }; + olmId: string; + }[] = []; // Combine all clients that need processing (those being added or removed) const clientsToProcess = new Map< @@ -584,6 +623,21 @@ async function handleMessagesForSiteClients( } } + // Batch-fetch all olm IDs for the clients we need to process + const clientIdsToProcess = Array.from(clientsToProcess.keys()); + const olmRows = + clientIdsToProcess.length > 0 + ? await trx + .select({ olmId: olms.olmId, clientId: olms.clientId }) + .from(olms) + .where(inArray(olms.clientId, clientIdsToProcess)) + : []; + const olmByClientId = new Map( + olmRows + .filter((r) => r.clientId !== null) + .map((r) => [r.clientId as number, r.olmId]) + ); + for (const client of clientsToProcess.values()) { // UPDATE THE NEWT if (!client.subnet || !client.pubKey) { @@ -600,14 +654,8 @@ async function handleMessagesForSiteClients( continue; } - const [olm] = await trx - .select({ - olmId: olms.olmId - }) - .from(olms) - .where(eq(olms.clientId, client.clientId)) - .limit(1); - if (!olm) { + const olmId = olmByClientId.get(client.clientId); + if (!olmId) { logger.warn( `Olm not found for client ${client.clientId} so cannot add/delete peers` ); @@ -615,15 +663,17 @@ async function handleMessagesForSiteClients( } if (isDelete) { - newtJobs.push(newtDeletePeer(siteId, client.pubKey, newt.newtId)); - olmJobs.push( - olmDeletePeer( - client.clientId, - siteId, - site.publicKey, - olm.olmId - ) - ); + newtPeerDeletes.push({ + siteId, + publicKey: client.pubKey, + newtId: newt.newtId + }); + olmPeerDeletes.push({ + clientId: client.clientId, + siteId, + publicKey: site.publicKey, + olmId + }); } if (isAdd) { @@ -635,23 +685,34 @@ async function handleMessagesForSiteClients( continue; } - await initPeerAddHandshake( - // this will kick off the add peer process for the client - client.clientId, - { + olmPeerAddHandshakes.push({ + clientId: client.clientId, + peer: { siteId, exitNode: { publicKey: exitNode.publicKey, endpoint: exitNode.endpoint } }, - olm.olmId - ); + olmId + }); } exitNodeJobs.push(updateClientSiteDestinations(client, trx)); } + if (newtPeerDeletes.length > 0) { + newtJobs.push(newtDeletePeersBatch(newtPeerDeletes)); + } + + if (olmPeerDeletes.length > 0) { + olmJobs.push(olmDeletePeersBatch(olmPeerDeletes)); + } + + if (olmPeerAddHandshakes.length > 0) { + olmJobs.push(initPeerAddHandshakeBatch(olmPeerAddHandshakes)); + } + Promise.all(exitNodeJobs).catch((error) => { logger.error( `rebuildClientAssociations: Error updating client site destinations for site ${site.siteId}:`, @@ -812,6 +873,20 @@ async function handleSubnetProxyTargetUpdates( ): Promise { const proxyJobs: Promise[] = []; const olmJobs: Promise[] = []; + const targetsToAddBatch: { + newtId: string; + targets: NonNullable< + Awaited> + >; + version: string | null; + }[] = []; + const targetsToRemoveBatch: { + newtId: string; + targets: NonNullable< + Awaited> + >; + version: string | null; + }[] = []; for (const siteData of sitesList) { const siteId = siteData.siteId; @@ -843,25 +918,25 @@ async function handleSubnetProxyTargetUpdates( ); if (targetsToAdd) { - proxyJobs.push( - addSubnetProxyTargets( - newt.newtId, - targetsToAdd, - newt.version - ) - ); + targetsToAddBatch.push({ + newtId: newt.newtId, + targets: targetsToAdd, + version: newt.version + }); } - for (const client of addedClients) { - olmJobs.push( - addPeerData( - client.clientId, + olmJobs.push( + addPeerDataBatch( + addedClients.map((client) => ({ + clientId: client.clientId, siteId, - generateRemoteSubnets([siteResource]), - generateAliasConfig([siteResource]) - ) - ); - } + remoteSubnets: generateRemoteSubnets([ + siteResource + ]), + aliases: generateAliasConfig([siteResource]) + })) + ) + ); } } @@ -880,15 +955,20 @@ async function handleSubnetProxyTargetUpdates( ); if (targetsToRemove) { - proxyJobs.push( - removeSubnetProxyTargets( - newt.newtId, - targetsToRemove, - newt.version - ) - ); + targetsToRemoveBatch.push({ + newtId: newt.newtId, + targets: targetsToRemove, + version: newt.version + }); } + const peerDataRemovals: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: ReturnType; + }[] = []; + for (const client of removedClients) { if (!siteResource.destination) { continue; @@ -936,31 +1016,58 @@ async function handleSubnetProxyTargetUpdates( ? [] : generateRemoteSubnets([siteResource]); - olmJobs.push( - removePeerData( - client.clientId, - siteId, - remoteSubnetsToRemove, - generateAliasConfig([siteResource]) - ) - ); + peerDataRemovals.push({ + clientId: client.clientId, + siteId, + remoteSubnets: remoteSubnetsToRemove, + aliases: generateAliasConfig([siteResource]) + }); + } + + if (peerDataRemovals.length > 0) { + olmJobs.push(removePeerDataBatch(peerDataRemovals)); } } } } - await Promise.all(proxyJobs); + if (targetsToAddBatch.length > 0) { + proxyJobs.push(addSubnetProxyTargetsBatch(targetsToAddBatch)); + } + + if (targetsToRemoveBatch.length > 0) { + proxyJobs.push(removeSubnetProxyTargetsBatch(targetsToRemoveBatch)); + } + + await Promise.all([...proxyJobs, ...olmJobs]); } export async function rebuildClientAssociationsFromClient( client: Client, trx: Transaction | typeof db = db ): Promise { - return await lockManager.withLock( - `rebuild-client-associations:client:${client.clientId}`, - () => rebuildClientAssociationsFromClientImpl(client, trx), - REBUILD_ASSOCIATIONS_LOCK_TTL_MS - ); + try { + return await lockManager.withLock( + `rebuild-client-associations:client:${client.clientId}`, + () => rebuildClientAssociationsFromClientImpl(client, trx), + REBUILD_ASSOCIATIONS_LOCK_TTL_MS + ); + } catch (err: any) { + if ( + typeof err?.message === "string" && + err.message.startsWith("Failed to acquire lock") + ) { + logger.warn( + `rebuildClientAssociations: could not acquire lock for client ${client.clientId}, queuing for deferred processing` + ); + await rebuildQueue.enqueue({ + type: "client", + id: client.clientId + }); + return; + } + throw err; + } } async function rebuildClientAssociationsFromClientImpl( @@ -1237,6 +1344,28 @@ async function handleMessagesForClientSites( const newtJobs: Promise[] = []; const olmJobs: Promise[] = []; const exitNodeJobs: Promise[] = []; + const newtPeerDeletes: { + siteId: number; + publicKey: string; + newtId: string; + }[] = []; + const olmPeerDeletes: { + clientId: number; + siteId: number; + publicKey: string; + olmId: string; + }[] = []; + const olmPeerAddHandshakes: { + clientId: number; + peer: { + siteId: number; + exitNode: { + publicKey: string; + endpoint: string; + }; + }; + olmId: string; + }[] = []; const totalSitesOnClient = await trx .select({ count: count(clientSitesAssociationsCache.siteId) }) @@ -1268,19 +1397,19 @@ async function handleMessagesForClientSites( if (isRemove) { // Remove peer from newt - newtJobs.push( - newtDeletePeer(site.siteId, client.pubKey, newt.newtId) - ); + newtPeerDeletes.push({ + siteId: site.siteId, + publicKey: client.pubKey, + newtId: newt.newtId + }); try { // Remove peer from olm - olmJobs.push( - olmDeletePeer( - client.clientId, - site.siteId, - site.publicKey, - olmId - ) - ); + olmPeerDeletes.push({ + clientId: client.clientId, + siteId: site.siteId, + publicKey: site.publicKey, + olmId + }); } catch (error) { // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send if ( @@ -1312,10 +1441,9 @@ async function handleMessagesForClientSites( continue; } - await initPeerAddHandshake( - // this will kick off the add peer process for the client - client.clientId, - { + olmPeerAddHandshakes.push({ + clientId: client.clientId, + peer: { siteId: site.siteId, exitNode: { publicKey: exitNode.publicKey, @@ -1323,7 +1451,7 @@ async function handleMessagesForClientSites( } }, olmId - ); + }); } // Update exit node destinations @@ -1339,6 +1467,18 @@ async function handleMessagesForClientSites( ); } + if (newtPeerDeletes.length > 0) { + newtJobs.push(newtDeletePeersBatch(newtPeerDeletes)); + } + + if (olmPeerDeletes.length > 0) { + olmJobs.push(olmDeletePeersBatch(olmPeerDeletes)); + } + + if (olmPeerAddHandshakes.length > 0) { + olmJobs.push(initPeerAddHandshakeBatch(olmPeerAddHandshakes)); + } + Promise.all(exitNodeJobs).catch((error) => { logger.error( `rebuildClientAssociations: Error updating client site destinations for client ${client.clientId}:`, @@ -1437,6 +1577,20 @@ async function handleMessagesForClientResources( continue; } + const targetsToAddBatch: { + newtId: string; + targets: NonNullable< + Awaited> + >; + version: string | null; + }[] = []; + const peerDataAdds: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: ReturnType; + }[] = []; + for (const resource of resources) { const targets = await generateSubnetProxyTargetV2(resource, [ { @@ -1447,25 +1601,21 @@ async function handleMessagesForClientResources( ]); if (targets) { - proxyJobs.push( - addSubnetProxyTargets( - newt.newtId, - targets, - newt.version - ) - ); + targetsToAddBatch.push({ + newtId: newt.newtId, + targets, + version: newt.version + }); } try { // Add peer data to olm - olmJobs.push( - addPeerData( - client.clientId, - siteId, - generateRemoteSubnets([resource]), - generateAliasConfig([resource]) - ) - ); + peerDataAdds.push({ + clientId: client.clientId, + siteId, + remoteSubnets: generateRemoteSubnets([resource]), + aliases: generateAliasConfig([resource]) + }); } catch (error) { // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send if ( @@ -1480,6 +1630,14 @@ async function handleMessagesForClientResources( } } } + + if (targetsToAddBatch.length > 0) { + proxyJobs.push(addSubnetProxyTargetsBatch(targetsToAddBatch)); + } + + if (peerDataAdds.length > 0) { + olmJobs.push(addPeerDataBatch(peerDataAdds)); + } } } @@ -1546,6 +1704,20 @@ async function handleMessagesForClientResources( continue; } + const targetsToRemoveBatch: { + newtId: string; + targets: NonNullable< + Awaited> + >; + version: string | null; + }[] = []; + const peerDataRemovals: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: ReturnType; + }[] = []; + for (const resource of resources) { const targets = await generateSubnetProxyTargetV2(resource, [ { @@ -1556,13 +1728,11 @@ async function handleMessagesForClientResources( ]); if (targets) { - proxyJobs.push( - removeSubnetProxyTargets( - newt.newtId, - targets, - newt.version - ) - ); + targetsToRemoveBatch.push({ + newtId: newt.newtId, + targets, + version: newt.version + }); } try { @@ -1613,14 +1783,12 @@ async function handleMessagesForClientResources( : generateRemoteSubnets([resource]); // Remove peer data from olm - olmJobs.push( - removePeerData( - client.clientId, - siteId, - remoteSubnetsToRemove, - generateAliasConfig([resource]) - ) - ); + peerDataRemovals.push({ + clientId: client.clientId, + siteId, + remoteSubnets: remoteSubnetsToRemove, + aliases: generateAliasConfig([resource]) + }); } catch (error) { // if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send if ( @@ -1635,6 +1803,16 @@ async function handleMessagesForClientResources( } } } + + if (targetsToRemoveBatch.length > 0) { + proxyJobs.push( + removeSubnetProxyTargetsBatch(targetsToRemoveBatch) + ); + } + + if (peerDataRemovals.length > 0) { + olmJobs.push(removePeerDataBatch(peerDataRemovals)); + } } } @@ -1884,11 +2062,20 @@ export async function cleanupSiteAssociations( // 7. Fire all removal messages in parallel. const jobs: Promise[] = []; + const olmPeerDeletes: { + clientId: number; + siteId: number; + publicKey: string; + }[] = []; for (const client of allClients) { // Tell each olm to drop the site's WireGuard peer. if (site.publicKey) { - jobs.push(olmDeletePeer(client.clientId, siteId, site.publicKey)); + olmPeerDeletes.push({ + clientId: client.clientId, + siteId, + publicKey: site.publicKey + }); } // Recompute and push updated relay destinations (now excluding this site). @@ -1897,6 +2084,10 @@ export async function cleanupSiteAssociations( } } + if (olmPeerDeletes.length > 0) { + jobs.push(olmDeletePeersBatch(olmPeerDeletes)); + } + await Promise.all(jobs).catch((error) => { logger.error( `cleanupSiteAssociations: error sending cleanup messages for siteId=${siteId}:`, @@ -1906,3 +2097,47 @@ export async function cleanupSiteAssociations( logger.debug(`cleanupSiteAssociations: DONE siteId=${siteId}`); } + +/** + * Start the background rebuild queue processor. This should be called once + * during server startup. Only one server instance at a time will actively + * consume the queue (enforced via a distributed Redis lock); all other + * instances will poll and wait until the lock becomes available. + */ +export function startRebuildQueueProcessor(): void { + rebuildQueue.startProcessing({ + onSiteResource: async (siteResourceId: number) => { + const [siteResource] = await primaryDb + .select() + .from(siteResources) + .where(eq(siteResources.siteResourceId, siteResourceId)); + + if (!siteResource) { + logger.warn( + `Rebuild queue: site resource ${siteResourceId} not found, skipping` + ); + return; + } + + await rebuildClientAssociationsFromSiteResource( + siteResource, + primaryDb + ); + }, + onClient: async (clientId: number) => { + const [client] = await primaryDb + .select() + .from(clients) + .where(eq(clients.clientId, clientId)); + + if (!client) { + logger.warn( + `Rebuild queue: client ${clientId} not found, skipping` + ); + return; + } + + await rebuildClientAssociationsFromClient(client, primaryDb); + } + }); +} diff --git a/server/lib/rebuildQueue.ts b/server/lib/rebuildQueue.ts new file mode 100644 index 000000000..475858108 --- /dev/null +++ b/server/lib/rebuildQueue.ts @@ -0,0 +1,23 @@ +export type RebuildJobType = "site-resource" | "client"; + +export interface RebuildJob { + type: RebuildJobType; + id: number; +} + +export interface RebuildJobHandlers { + onSiteResource(siteResourceId: number): Promise; + onClient(clientId: number): Promise; +} + +export interface RebuildQueueManager { + enqueue(job: RebuildJob): Promise; + startProcessing(handlers: RebuildJobHandlers): void; +} + +class NoopRebuildQueue implements RebuildQueueManager { + async enqueue(_job: RebuildJob): Promise {} + startProcessing(_handlers: RebuildJobHandlers): void {} +} + +export const rebuildQueue: RebuildQueueManager = new NoopRebuildQueue(); diff --git a/server/private/lib/rebuildQueue.ts b/server/private/lib/rebuildQueue.ts new file mode 100644 index 000000000..2cd1dadc0 --- /dev/null +++ b/server/private/lib/rebuildQueue.ts @@ -0,0 +1,198 @@ +/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025-2026 Fossorial, Inc. + * All rights reserved. + * + * This file is licensed under the Fossorial Commercial License. + * You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. + */ + +import { redis } from "#private/lib/redis"; +import { lockManager } from "#dynamic/lib/lock"; +import logger from "@server/logger"; + +export type RebuildJobType = "site-resource" | "client"; + +export interface RebuildJob { + type: RebuildJobType; + id: number; +} + +export interface RebuildJobHandlers { + onSiteResource(siteResourceId: number): Promise; + onClient(clientId: number): Promise; +} + +// Redis list holding pending rebuild jobs (RPUSH to enqueue, LPOP to dequeue — FIFO order). +const QUEUE_KEY = "rebuild-client-associations:queue"; +const QUEUED_SET_KEY = "rebuild-client-associations:queued"; + +// Distributed lock that serialises queue consumption to a single server instance +// at a time. TTL is generous enough to cover a full batch of expensive rebuilds. +const PROCESSOR_LOCK_KEY = "rebuild-client-associations:processor"; + +// Each rebuild can take up to REBUILD_ASSOCIATIONS_LOCK_TTL_MS (120 s) per +// resource. Allow BATCH_SIZE resources per processor-lock acquisition, plus a +// small buffer. +const BATCH_SIZE = 5; +const PROCESSOR_LOCK_TTL_MS = 120000 * BATCH_SIZE + 30000; // ~630 s + +const POLL_INTERVAL_MS = 500; + +class RedisRebuildQueue { + private processingStarted = false; + + async enqueue(job: RebuildJob): Promise { + if (!redis || redis.status !== "ready") { + logger.warn( + `Rebuild queue: Redis not available — rebuild for ${job.type}:${job.id} will not be retried` + ); + return; + } + + try { + const dedupeKey = `${job.type}:${job.id}`; + const added = await redis.sadd(QUEUED_SET_KEY, dedupeKey); + if (added === 0) { + logger.debug( + `Rebuild queue: skipped duplicate queued job ${job.type}:${job.id}` + ); + return; + } + + await redis.rpush(QUEUE_KEY, JSON.stringify(job)); + logger.debug( + `Rebuild queue: enqueued ${job.type}:${job.id} (queue position: tail)` + ); + } catch (err) { + await redis + .srem(QUEUED_SET_KEY, `${job.type}:${job.id}`) + .catch((cleanupErr) => + logger.warn( + `Rebuild queue: failed to cleanup dedupe key for ${job.type}:${job.id} after enqueue failure:`, + cleanupErr + ) + ); + logger.error( + `Rebuild queue: failed to enqueue ${job.type}:${job.id}:`, + err + ); + } + } + + startProcessing(handlers: RebuildJobHandlers): void { + if (this.processingStarted) return; + this.processingStarted = true; + + this.processLoop(handlers).catch((err) => { + logger.error("Rebuild queue processor loop crashed:", err); + }); + + logger.info("Rebuild queue processor started"); + } + + private async processLoop(handlers: RebuildJobHandlers): Promise { + while (true) { + try { + await this.tryProcessBatch(handlers); + } catch (err) { + logger.error( + "Rebuild queue: unhandled error in process loop:", + err + ); + } + await new Promise((resolve) => + setTimeout(resolve, POLL_INTERVAL_MS) + ); + } + } + + private async tryProcessBatch(handlers: RebuildJobHandlers): Promise { + if (!redis || redis.status !== "ready") return; + + // Peek before acquiring the processor lock to avoid unnecessary Redis + // round-trips and lock contention when the queue is idle. + const queueLength = await redis.llen(QUEUE_KEY).catch(() => 0); + if (queueLength === 0) return; + + try { + await lockManager.withLock( + PROCESSOR_LOCK_KEY, + async () => { + for (let i = 0; i < BATCH_SIZE; i++) { + if (!redis || redis.status !== "ready") break; + + const payload = await redis.lpop(QUEUE_KEY); + if (payload === null) break; // queue drained + + let job: RebuildJob; + try { + job = JSON.parse(payload) as RebuildJob; + } catch { + logger.error( + `Rebuild queue: could not parse job payload, discarding: ${payload}` + ); + continue; + } + + // Remove from dedupe set once dequeued so the same job + // can be re-queued while this one is in progress. + await redis + .srem(QUEUED_SET_KEY, `${job.type}:${job.id}`) + .catch((cleanupErr) => + logger.warn( + `Rebuild queue: failed to remove dedupe key for ${job.type}:${job.id} on dequeue:`, + cleanupErr + ) + ); + + logger.debug( + `Rebuild queue: processing ${job.type}:${job.id}` + ); + + try { + if (job.type === "site-resource") { + await handlers.onSiteResource(job.id); + } else if (job.type === "client") { + await handlers.onClient(job.id); + } else { + logger.warn( + `Rebuild queue: unknown job type "${(job as any).type}", discarding` + ); + } + + logger.debug( + `Rebuild queue: completed ${job.type}:${job.id}` + ); + } catch (err) { + logger.error( + `Rebuild queue: job ${job.type}:${job.id} threw an error:`, + err + ); + } + } + }, + PROCESSOR_LOCK_TTL_MS + ); + } catch (err: any) { + if ( + typeof err?.message === "string" && + err.message.startsWith("Failed to acquire lock") + ) { + // Another server instance currently holds the processor lock and + // is consuming the queue — nothing to do this cycle. + logger.debug( + "Rebuild queue: processor lock held by another instance, skipping this cycle" + ); + } else { + throw err; + } + } + } +} + +export const rebuildQueue: RedisRebuildQueue = new RedisRebuildQueue(); diff --git a/server/private/routers/ws/ws.ts b/server/private/routers/ws/ws.ts index a592927cc..5e38c709e 100644 --- a/server/private/routers/ws/ws.ts +++ b/server/private/routers/ws/ws.ts @@ -38,6 +38,7 @@ import { messageHandlers } from "@server/routers/ws/messageHandlers"; import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers"; import { AuthenticatedWebSocket, + BatchSendMessage, ClientType, WSMessage, TokenPayload, @@ -187,6 +188,8 @@ const wss: WebSocketServer = new WebSocketServer({ noServer: true }); // Generate unique node ID for this instance const NODE_ID = uuidv4(); const REDIS_CHANNEL = "websocket_messages"; +const REDIS_DIRECT_BATCH_SIZE = 250; +const REDIS_DIRECT_FLUSH_INTERVAL_MS = 10; // Client tracking map (local to this node) const connectedClients: Map = new Map(); @@ -197,6 +200,15 @@ const clientConfigVersions: Map = new Map(); // Recovery tracking let isRedisRecoveryInProgress = false; +interface RedisDirectBatchEntry { + targetClientId: string; + message: WSMessage; + resolve: () => void; +} + +let pendingRedisDirectMessages: RedisDirectBatchEntry[] = []; +let redisDirectFlushTimer: NodeJS.Timeout | null = null; + // Helper to get map key const getClientMapKey = (clientId: string) => clientId; @@ -207,6 +219,78 @@ const getNodeConnectionsKey = (nodeId: string, clientId: string) => const getConfigVersionKey = (clientId: string) => `ws:configVersion:${clientId}`; +const clearRedisDirectFlushTimer = (): void => { + if (redisDirectFlushTimer) { + clearTimeout(redisDirectFlushTimer); + redisDirectFlushTimer = null; + } +}; + +const publishDirectBatch = async ( + entries: RedisDirectBatchEntry[] +): Promise => { + const redisMessage: RedisMessage = { + type: "direct-batch", + messages: entries.map((entry) => ({ + targetClientId: entry.targetClientId, + message: entry.message + })), + fromNodeId: NODE_ID + }; + + await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage)); +}; + +const flushPendingRedisDirectMessages = async (): Promise => { + clearRedisDirectFlushTimer(); + + if (pendingRedisDirectMessages.length === 0) { + return; + } + + const entries = pendingRedisDirectMessages; + pendingRedisDirectMessages = []; + + if (!redisManager.isRedisEnabled()) { + entries.forEach((entry) => entry.resolve()); + return; + } + + for (let i = 0; i < entries.length; i += REDIS_DIRECT_BATCH_SIZE) { + const batch = entries.slice(i, i + REDIS_DIRECT_BATCH_SIZE); + try { + await publishDirectBatch(batch); + } catch (error) { + logger.error( + "Failed to send batched direct messages via Redis, messages may be lost:", + error + ); + } finally { + batch.forEach((entry) => entry.resolve()); + } + } +}; + +const enqueueRedisDirectMessage = async ( + targetClientId: string, + message: WSMessage +): Promise => { + await new Promise((resolve) => { + pendingRedisDirectMessages.push({ targetClientId, message, resolve }); + + if (pendingRedisDirectMessages.length >= REDIS_DIRECT_BATCH_SIZE) { + void flushPendingRedisDirectMessages(); + return; + } + + if (!redisDirectFlushTimer) { + redisDirectFlushTimer = setTimeout(() => { + void flushPendingRedisDirectMessages(); + }, REDIS_DIRECT_FLUSH_INTERVAL_MS); + } + }); +}; + // Initialize Redis subscription for cross-node messaging const initializeRedisSubscription = async (): Promise => { if (!redisManager.isRedisEnabled()) return; @@ -227,7 +311,16 @@ const initializeRedisSubscription = async (): Promise => { // Send to specific client on this node await sendToClientLocal( redisMessage.targetClientId, - redisMessage.message + redisMessage.message, + {}, + redisMessage.message.configVersion + ); + } else if ( + redisMessage.type === "direct-batch" && + redisMessage.messages + ) { + await sendRedisDirectBatchToLocalClients( + redisMessage.messages ); } else if (redisMessage.type === "broadcast") { // Broadcast to all clients on this node except excluded @@ -503,7 +596,8 @@ const incrementClientConfigVersion = async ( const sendToClientLocal = async ( clientId: string, message: WSMessage, - options: SendMessageOptions = {} + options: SendMessageOptions = {}, + preResolvedConfigVersion?: number ): Promise => { const mapKey = getClientMapKey(clientId); const clients = connectedClients.get(mapKey); @@ -512,7 +606,8 @@ const sendToClientLocal = async ( } // Handle config version - const configVersion = await getClientConfigVersion(clientId); + const configVersion = + preResolvedConfigVersion ?? (await getClientConfigVersion(clientId)); // Add config version to message const messageWithVersion = { @@ -545,43 +640,71 @@ const sendToClientLocal = async ( return true; }; +const sendRedisDirectBatchToLocalClients = async ( + entries: { targetClientId: string; message: WSMessage }[] +): Promise => { + const jobs = entries.map((entry) => + sendToClientLocal( + entry.targetClientId, + entry.message, + {}, + entry.message.configVersion + ) + ); + await Promise.all(jobs); +}; + const broadcastToAllExceptLocal = async ( message: WSMessage, excludeClientId?: string, options: SendMessageOptions = {} ): Promise => { - for (const [mapKey, clients] of connectedClients.entries()) { - const [type, id] = mapKey.split(":"); - const clientId = mapKey; // mapKey is the clientId - if (!(excludeClientId && clientId === excludeClientId)) { - // Handle config version per client - let configVersion = await getClientConfigVersion(clientId); - if (options.incrementConfigVersion) { - configVersion = await incrementClientConfigVersion(clientId); - } + const sendPlans = await Promise.all( + Array.from(connectedClients.entries()).map( + async ([mapKey, clients]) => { + const clientId = mapKey; // mapKey is the clientId + if (excludeClientId && clientId === excludeClientId) { + return null; + } - // Add config version to message - const messageWithVersion = { - ...message, - configVersion - }; + let configVersion = await getClientConfigVersion(clientId); + if (options.incrementConfigVersion) { + configVersion = + await incrementClientConfigVersion(clientId); + } - if (options.compress) { - const compressed = zlib.gzipSync( - Buffer.from(JSON.stringify(messageWithVersion), "utf8") - ); - clients.forEach((client) => { - if (client.readyState === WebSocket.OPEN) { - client.send(compressed); + return { + clients, + messageWithVersion: { + ...message, + configVersion } - }); - } else { - clients.forEach((client) => { - if (client.readyState === WebSocket.OPEN) { - client.send(JSON.stringify(messageWithVersion)); - } - }); + }; } + ) + ); + + for (const plan of sendPlans) { + if (!plan) { + continue; + } + + if (options.compress) { + const compressed = zlib.gzipSync( + Buffer.from(JSON.stringify(plan.messageWithVersion), "utf8") + ); + plan.clients.forEach((client) => { + if (client.readyState === WebSocket.OPEN) { + client.send(compressed); + } + }); + } else { + const messageString = JSON.stringify(plan.messageWithVersion); + plan.clients.forEach((client) => { + if (client.readyState === WebSocket.OPEN) { + client.send(messageString); + } + }); } } }; @@ -602,28 +725,23 @@ const sendToClient = async ( ); // Try to send locally first - const localSent = await sendToClientLocal(clientId, message, options); + const localSent = await sendToClientLocal( + clientId, + message, + options, + configVersion + ); // Only send via Redis if the client is not connected locally and Redis is enabled if (!localSent && redisManager.isRedisEnabled()) { try { - const redisMessage: RedisMessage = { - type: "direct", - targetClientId: clientId, - message: { - ...message, - configVersion - }, - fromNodeId: NODE_ID - }; - - await redisManager.publish( - REDIS_CHANNEL, - JSON.stringify(redisMessage) - ); + await enqueueRedisDirectMessage(clientId, { + ...message, + configVersion + }); } catch (error) { logger.error( - "Failed to send message via Redis, message may be lost:", + "Failed to queue batched direct message for Redis delivery, message may be lost:", error ); // Continue execution - local delivery already attempted @@ -638,6 +756,95 @@ const sendToClient = async ( return localSent; }; +const sendToClientsBatch = async ( + entries: BatchSendMessage[] +): Promise => { + if (entries.length === 0) { + return; + } + + const remoteEntries: { targetClientId: string; message: WSMessage }[] = []; + const clientsWithIncrement = new Set( + entries + .filter((entry) => !!entry.options?.incrementConfigVersion) + .map((entry) => entry.clientId) + ); + const nonIncrementOnlyClientIds = Array.from( + new Set( + entries + .map((entry) => entry.clientId) + .filter((clientId) => !clientsWithIncrement.has(clientId)) + ) + ); + const stableConfigVersionByClient = new Map( + await Promise.all( + nonIncrementOnlyClientIds.map( + async (clientId) => + [clientId, await getClientConfigVersion(clientId)] as const + ) + ) + ); + + for (const entry of entries) { + const options = entry.options || {}; + const { clientId, message } = entry; + + const configVersion = options.incrementConfigVersion + ? await incrementClientConfigVersion(clientId) + : stableConfigVersionByClient.get(clientId); + + logger.debug( + `sendToClientsBatch: Message type ${message.type} queued for clientId ${clientId} (new configVersion: ${configVersion})` + ); + + const localSent = await sendToClientLocal( + clientId, + message, + options, + configVersion + ); + + if (!localSent && redisManager.isRedisEnabled()) { + remoteEntries.push({ + targetClientId: clientId, + message: { + ...message, + configVersion + } + }); + } else if (!localSent && !redisManager.isRedisEnabled()) { + logger.debug( + `Could not deliver batch message to ${clientId} - not connected locally and Redis unavailable` + ); + } + } + + if (!redisManager.isRedisEnabled() || remoteEntries.length === 0) { + return; + } + + for (let i = 0; i < remoteEntries.length; i += REDIS_DIRECT_BATCH_SIZE) { + const messages = remoteEntries.slice(i, i + REDIS_DIRECT_BATCH_SIZE); + try { + const redisMessage: RedisMessage = { + type: "direct-batch", + messages, + fromNodeId: NODE_ID + }; + + await redisManager.publish( + REDIS_CHANNEL, + JSON.stringify(redisMessage) + ); + } catch (error) { + logger.error( + "Failed to send explicit direct batch via Redis, messages may be lost:", + error + ); + } + } +}; + const broadcastToAllExcept = async ( message: WSMessage, excludeClientId?: string, @@ -1109,6 +1316,8 @@ const disconnectClient = async (clientId: string): Promise => { // Cleanup function for graceful shutdown const cleanup = async (): Promise => { try { + await flushPendingRedisDirectMessages(); + // Close all WebSocket connections connectedClients.forEach((clients) => { clients.forEach((client) => { @@ -1139,6 +1348,7 @@ export { router, handleWSUpgrade, sendToClient, + sendToClientsBatch, broadcastToAllExcept, connectedClients, hasActiveConnections, diff --git a/server/routers/client/targets.ts b/server/routers/client/targets.ts index c208acd88..c62a64ae0 100644 --- a/server/routers/client/targets.ts +++ b/server/routers/client/targets.ts @@ -1,4 +1,4 @@ -import { sendToClient } from "#dynamic/routers/ws"; +import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws"; import { db, newts, olms } from "@server/db"; import { Alias, @@ -8,7 +8,7 @@ import { } from "@server/lib/ip"; import { canCompress } from "@server/lib/clientVersionChecks"; import logger from "@server/logger"; -import { eq } from "drizzle-orm"; +import { eq, inArray } from "drizzle-orm"; import semver from "semver"; const NEWT_V2_TARGETS_VERSION = ">=1.10.3"; @@ -59,6 +59,42 @@ export async function addTargets( ); } +export async function addTargetsBatch( + entries: { + newtId: string; + targets: SubnetProxyTarget[] | SubnetProxyTargetV2[]; + version?: string | null; + }[] +) { + if (entries.length === 0) { + return; + } + + const resolved = await Promise.all( + entries.map(async (entry) => ({ + ...entry, + targets: await convertTargetsIfNecessary( + entry.newtId, + entry.targets + ) + })) + ); + + await sendToClientsBatch( + resolved.map((entry) => ({ + clientId: entry.newtId, + message: { + type: `newt/wg/targets/add`, + data: entry.targets + }, + options: { + incrementConfigVersion: true, + compress: canCompress(entry.version, "newt") + } + })) + ); +} + export async function removeTargets( newtId: string, targets: SubnetProxyTarget[] | SubnetProxyTargetV2[], @@ -76,6 +112,42 @@ export async function removeTargets( ); } +export async function removeTargetsBatch( + entries: { + newtId: string; + targets: SubnetProxyTarget[] | SubnetProxyTargetV2[]; + version?: string | null; + }[] +) { + if (entries.length === 0) { + return; + } + + const resolved = await Promise.all( + entries.map(async (entry) => ({ + ...entry, + targets: await convertTargetsIfNecessary( + entry.newtId, + entry.targets + ) + })) + ); + + await sendToClientsBatch( + resolved.map((entry) => ({ + clientId: entry.newtId, + message: { + type: `newt/wg/targets/remove`, + data: entry.targets + }, + options: { + incrementConfigVersion: true, + compress: canCompress(entry.version, "newt") + } + })) + ); +} + export async function updateTargets( newtId: string, targets: { @@ -201,6 +273,171 @@ export async function removePeerData( }); } +const resolveOlmTargets = async ( + entries: { + clientId: number; + olmId?: string; + version?: string | null; + }[] +) => { + const unresolvedClientIds = entries + .filter((entry) => !entry.olmId) + .map((entry) => entry.clientId); + + const olmMap = new Map(); + + if (unresolvedClientIds.length > 0) { + const olmRows = await db + .select({ + clientId: olms.clientId, + olmId: olms.olmId, + version: olms.version + }) + .from(olms) + .where(inArray(olms.clientId, unresolvedClientIds)); + + for (const row of olmRows) { + if (row.clientId !== null) { + olmMap.set(row.clientId, { + olmId: row.olmId, + version: row.version + }); + } + } + } + + return entries + .map((entry) => { + if (entry.olmId) { + return { + clientId: entry.clientId, + olmId: entry.olmId, + version: entry.version + }; + } + + const resolved = olmMap.get(entry.clientId); + if (!resolved) { + return null; + } + + return { + clientId: entry.clientId, + olmId: resolved.olmId, + version: entry.version ?? resolved.version + }; + }) + .filter((entry) => entry !== null); +}; + +export async function addPeerDataBatch( + entries: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: Alias[]; + olmId?: string; + version?: string | null; + }[] +) { + if (entries.length === 0) { + return; + } + + const resolvedTargets = await resolveOlmTargets(entries); + + if (resolvedTargets.length === 0) { + return; + } + + const payloads = entries + .map((entry) => { + const resolved = resolvedTargets.find( + (target) => target.clientId === entry.clientId + ); + if (!resolved) { + return null; + } + + return { + clientId: resolved.olmId, + message: { + type: `olm/wg/peer/data/add`, + data: { + siteId: entry.siteId, + remoteSubnets: entry.remoteSubnets, + aliases: entry.aliases + } + }, + options: { + incrementConfigVersion: true, + compress: canCompress(resolved.version, "olm") + } + }; + }) + .filter((entry) => entry !== null); + + if (payloads.length === 0) { + return; + } + + await sendToClientsBatch(payloads); +} + +export async function removePeerDataBatch( + entries: { + clientId: number; + siteId: number; + remoteSubnets: string[]; + aliases: Alias[]; + olmId?: string; + version?: string | null; + }[] +) { + if (entries.length === 0) { + return; + } + + const resolvedTargets = await resolveOlmTargets(entries); + + if (resolvedTargets.length === 0) { + return; + } + + const payloads = entries + .map((entry) => { + const resolved = resolvedTargets.find( + (target) => target.clientId === entry.clientId + ); + if (!resolved) { + return null; + } + + return { + clientId: resolved.olmId, + message: { + type: `olm/wg/peer/data/remove`, + data: { + siteId: entry.siteId, + remoteSubnets: entry.remoteSubnets, + aliases: entry.aliases + } + }, + options: { + incrementConfigVersion: true, + compress: canCompress(resolved.version, "olm") + } + }; + }) + .filter((entry) => entry !== null); + + if (payloads.length === 0) { + return; + } + + await sendToClientsBatch(payloads); +} + export async function updatePeerData( clientId: number, siteId: number, diff --git a/server/routers/newt/peers.ts b/server/routers/newt/peers.ts index 4b74d863d..6c38671f3 100644 --- a/server/routers/newt/peers.ts +++ b/server/routers/newt/peers.ts @@ -1,7 +1,7 @@ import { db, Site } from "@server/db"; import { newts, sites } from "@server/db"; import { eq } from "drizzle-orm"; -import { sendToClient } from "#dynamic/routers/ws"; +import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws"; import logger from "@server/logger"; export async function addPeer( @@ -36,10 +36,14 @@ export async function addPeer( newtId = newt.newtId; } - await sendToClient(newtId, { - type: "newt/wg/peer/add", - data: peer - }, { incrementConfigVersion: true }).catch((error) => { + await sendToClient( + newtId, + { + type: "newt/wg/peer/add", + data: peer + }, + { incrementConfigVersion: true } + ).catch((error) => { logger.warn(`Error sending message:`, error); }); @@ -76,12 +80,16 @@ export async function deletePeer( newtId = newt.newtId; } - await sendToClient(newtId, { - type: "newt/wg/peer/remove", - data: { - publicKey - } - }, { incrementConfigVersion: true }).catch((error) => { + await sendToClient( + newtId, + { + type: "newt/wg/peer/remove", + data: { + publicKey + } + }, + { incrementConfigVersion: true } + ).catch((error) => { logger.warn(`Error sending message:`, error); }); @@ -90,6 +98,35 @@ export async function deletePeer( return site; } +export async function deletePeersBatch( + peers: { + siteId: number; + publicKey: string; + newtId: string; + }[] +) { + if (peers.length === 0) { + return; + } + + await sendToClientsBatch( + peers.map((peer) => ({ + clientId: peer.newtId, + message: { + type: "newt/wg/peer/remove", + data: { + publicKey: peer.publicKey + } + }, + options: { incrementConfigVersion: true } + })) + ).catch((error) => { + logger.warn(`Error sending batched newt peer removals:`, error); + }); + + logger.info(`Deleted ${peers.length} peer(s) from newts (batch)`); +} + export async function updatePeer( siteId: number, publicKey: string, @@ -122,13 +159,17 @@ export async function updatePeer( newtId = newt.newtId; } - await sendToClient(newtId, { - type: "newt/wg/peer/update", - data: { - publicKey, - ...peer - } - }, { incrementConfigVersion: true }).catch((error) => { + await sendToClient( + newtId, + { + type: "newt/wg/peer/update", + data: { + publicKey, + ...peer + } + }, + { incrementConfigVersion: true } + ).catch((error) => { logger.warn(`Error sending message:`, error); }); diff --git a/server/routers/olm/peers.ts b/server/routers/olm/peers.ts index 05e153fea..962d7367e 100644 --- a/server/routers/olm/peers.ts +++ b/server/routers/olm/peers.ts @@ -1,9 +1,9 @@ -import { sendToClient } from "#dynamic/routers/ws"; +import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws"; import { clientSitesAssociationsCache, db, olms } from "@server/db"; import { canCompress } from "@server/lib/clientVersionChecks"; import config from "@server/lib/config"; import logger from "@server/logger"; -import { and, eq } from "drizzle-orm"; +import { and, eq, inArray } from "drizzle-orm"; import { Alias } from "yaml"; export async function addPeer( @@ -205,3 +205,150 @@ export async function initPeerAddHandshake( `Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}` ); } + +export async function deletePeersBatch( + peers: { + clientId: number; + siteId: number; + publicKey: string; + olmId?: string; + version?: string | null; + }[] +) { + if (peers.length === 0) { + return; + } + + const unresolvedClientIds = peers + .filter((peer) => !peer.olmId) + .map((peer) => peer.clientId); + + const olmByClientId = new Map< + number, + { olmId: string; version: string | null } + >(); + + if (unresolvedClientIds.length > 0) { + const olmRows = await db + .select({ + clientId: olms.clientId, + olmId: olms.olmId, + version: olms.version + }) + .from(olms) + .where(inArray(olms.clientId, unresolvedClientIds)); + + for (const row of olmRows) { + if (row.clientId !== null) { + olmByClientId.set(row.clientId, { + olmId: row.olmId, + version: row.version + }); + } + } + } + + const batchPayloads = peers + .map((peer) => { + const resolved = peer.olmId + ? { olmId: peer.olmId, version: peer.version ?? null } + : olmByClientId.get(peer.clientId); + if (!resolved) { + return null; + } + + return { + clientId: resolved.olmId, + message: { + type: "olm/wg/peer/remove", + data: { + publicKey: peer.publicKey, + siteId: peer.siteId + } + }, + options: { + incrementConfigVersion: true, + compress: canCompress( + peer.version ?? resolved.version, + "olm" + ) + } + }; + }) + .filter((payload) => payload !== null); + + if (batchPayloads.length === 0) { + return; + } + + await sendToClientsBatch(batchPayloads).catch((error) => { + logger.warn(`Error sending batched olm peer removals:`, error); + }); + + logger.info(`Deleted ${batchPayloads.length} peer(s) from olms (batch)`); +} + +export async function initPeerAddHandshakeBatch( + handshakes: { + clientId: number; + peer: { + siteId: number; + exitNode: { + publicKey: string; + endpoint: string; + }; + }; + olmId: string; + chainId?: string; + }[] +) { + if (handshakes.length === 0) { + return; + } + + await sendToClientsBatch( + handshakes.map((item) => ({ + clientId: item.olmId, + message: { + type: "olm/wg/peer/holepunch/site/add", + data: { + siteId: item.peer.siteId, + exitNode: { + publicKey: item.peer.exitNode.publicKey, + relayPort: + config.getRawConfig().gerbil.clients_start_port, + endpoint: item.peer.exitNode.endpoint + }, + chainId: item.chainId + } + }, + options: { incrementConfigVersion: true } + })) + ).catch((error) => { + logger.warn(`Error sending batched olm handshakes:`, error); + }); + + await Promise.all( + handshakes.map((item) => + db + .update(clientSitesAssociationsCache) + .set({ isJitMode: false }) + .where( + and( + eq( + clientSitesAssociationsCache.clientId, + item.clientId + ), + eq( + clientSitesAssociationsCache.siteId, + item.peer.siteId + ) + ) + ) + ) + ); + + logger.info( + `Initiated ${handshakes.length} peer add handshake(s) to olms (batch)` + ); +} diff --git a/server/routers/siteResource/updateSiteResource.ts b/server/routers/siteResource/updateSiteResource.ts index db4d4445b..3f271d2f9 100644 --- a/server/routers/siteResource/updateSiteResource.ts +++ b/server/routers/siteResource/updateSiteResource.ts @@ -28,7 +28,10 @@ import { isIpInCidr, portRangeStringSchema } from "@server/lib/ip"; -import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations"; +import { + getClientSiteResourceAccess, + rebuildClientAssociationsFromSiteResource +} from "@server/lib/rebuildClientAssociations"; import logger from "@server/logger"; import HttpCode from "@server/types/HttpCode"; import { NextFunction, Request, Response } from "express"; @@ -846,9 +849,14 @@ export async function handleMessagingForUpdatedSiteResource( updatedSiteResource ); - const { mergedAllClients } = - await rebuildClientAssociationsFromSiteResource( - existingSiteResource || updatedSiteResource, // we want to rebuild based on the existing resource then we will apply the change to the destination below + await rebuildClientAssociationsFromSiteResource( + existingSiteResource || updatedSiteResource, // we want to rebuild based on the existing resource then we will apply the change to the destination below + trx + ); + + const { sitesList, mergedAllClients, mergedAllClientIds } = + await getClientSiteResourceAccess( + existingSiteResource || updatedSiteResource, trx ); diff --git a/server/routers/ws/types.ts b/server/routers/ws/types.ts index e539954ce..eeb272457 100644 --- a/server/routers/ws/types.ts +++ b/server/routers/ws/types.ts @@ -76,12 +76,32 @@ export interface SendMessageOptions { compress?: boolean; } -// Redis message type for cross-node communication -export interface RedisMessage { - type: "direct" | "broadcast"; - targetClientId?: string; - excludeClientId?: string; +export interface BatchSendMessage { + clientId: string; message: WSMessage; - fromNodeId: string; options?: SendMessageOptions; } + +// Redis message types for cross-node communication +export type RedisMessage = + | { + type: "direct"; + targetClientId: string; + message: WSMessage; + fromNodeId: string; + } + | { + type: "direct-batch"; + messages: { + targetClientId: string; + message: WSMessage; + }[]; + fromNodeId: string; + } + | { + type: "broadcast"; + excludeClientId?: string; + message: WSMessage; + fromNodeId: string; + options?: SendMessageOptions; + }; diff --git a/server/routers/ws/ws.ts b/server/routers/ws/ws.ts index e7dcfe9cb..4ce337a20 100644 --- a/server/routers/ws/ws.ts +++ b/server/routers/ws/ws.ts @@ -26,7 +26,8 @@ import { WebSocketRequest, WSMessage, AuthenticatedWebSocket, - SendMessageOptions + SendMessageOptions, + BatchSendMessage } from "./types"; import { validateSessionToken } from "@server/auth/sessions/app"; @@ -212,6 +213,20 @@ const sendToClient = async ( return localSent; }; +const sendToClientsBatch = async ( + entries: BatchSendMessage[] +): Promise => { + if (entries.length === 0) { + return; + } + + await Promise.all( + entries.map((entry) => + sendToClient(entry.clientId, entry.message, entry.options) + ) + ); +}; + const broadcastToAllExcept = async ( message: WSMessage, excludeClientId?: string, @@ -552,6 +567,7 @@ export { router, handleWSUpgrade, sendToClient, + sendToClientsBatch, broadcastToAllExcept, connectedClients, hasActiveConnections,