Merge pull request #3294 from fosrl/copilot/add-batching-mechanism-to-rebuild-client-associati

Batch Redis websocket fan-out and deduplicate queued rebuild jobs
This commit is contained in:
Owen Schwartz
2026-06-21 14:20:56 -07:00
committed by GitHub
8 changed files with 934 additions and 143 deletions

View File

@@ -21,10 +21,10 @@ import {
} from "@server/db";
import { and, count, eq, inArray, ne } from "drizzle-orm";
import { deletePeer as newtDeletePeer } from "@server/routers/newt/peers";
import { deletePeersBatch as newtDeletePeersBatch } from "@server/routers/newt/peers";
import {
initPeerAddHandshake,
deletePeer as olmDeletePeer
initPeerAddHandshakeBatch,
deletePeersBatch as olmDeletePeersBatch
} from "@server/routers/olm/peers";
import { sendToExitNode } from "#dynamic/lib/exitNodes";
import logger from "@server/logger";
@@ -35,10 +35,10 @@ import {
parseEndpoint
} from "@server/lib/ip";
import {
addPeerData,
addTargets as addSubnetProxyTargets,
removePeerData,
removeTargets as removeSubnetProxyTargets
addPeerDataBatch,
addTargetsBatch as addSubnetProxyTargetsBatch,
removePeerDataBatch,
removeTargetsBatch as removeSubnetProxyTargetsBatch
} from "@server/routers/client/targets";
import { lockManager } from "#dynamic/lib/lock";
import { rebuildQueue } from "#dynamic/lib/rebuildQueue";
@@ -559,6 +559,28 @@ async function handleMessagesForSiteClients(
const newtJobs: Promise<any>[] = [];
const olmJobs: Promise<any>[] = [];
const exitNodeJobs: Promise<any>[] = [];
const newtPeerDeletes: {
siteId: number;
publicKey: string;
newtId: string;
}[] = [];
const olmPeerDeletes: {
clientId: number;
siteId: number;
publicKey: string;
olmId: string;
}[] = [];
const olmPeerAddHandshakes: {
clientId: number;
peer: {
siteId: number;
exitNode: {
publicKey: string;
endpoint: string;
};
};
olmId: string;
}[] = [];
// Combine all clients that need processing (those being added or removed)
const clientsToProcess = new Map<
@@ -638,15 +660,17 @@ async function handleMessagesForSiteClients(
}
if (isDelete) {
newtJobs.push(newtDeletePeer(siteId, client.pubKey, newt.newtId));
olmJobs.push(
olmDeletePeer(
client.clientId,
siteId,
site.publicKey,
olm.olmId
)
);
newtPeerDeletes.push({
siteId,
publicKey: client.pubKey,
newtId: newt.newtId
});
olmPeerDeletes.push({
clientId: client.clientId,
siteId,
publicKey: site.publicKey,
olmId: olm.olmId
});
}
if (isAdd) {
@@ -658,23 +682,34 @@ async function handleMessagesForSiteClients(
continue;
}
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId,
{
olmPeerAddHandshakes.push({
clientId: client.clientId,
peer: {
siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
}
},
olm.olmId
);
olmId: olm.olmId
});
}
exitNodeJobs.push(updateClientSiteDestinations(client, trx));
}
if (newtPeerDeletes.length > 0) {
newtJobs.push(newtDeletePeersBatch(newtPeerDeletes));
}
if (olmPeerDeletes.length > 0) {
olmJobs.push(olmDeletePeersBatch(olmPeerDeletes));
}
if (olmPeerAddHandshakes.length > 0) {
olmJobs.push(initPeerAddHandshakeBatch(olmPeerAddHandshakes));
}
Promise.all(exitNodeJobs).catch((error) => {
logger.error(
`rebuildClientAssociations: Error updating client site destinations for site ${site.siteId}:`,
@@ -867,24 +902,28 @@ async function handleSubnetProxyTargetUpdates(
if (targetsToAdd) {
proxyJobs.push(
addSubnetProxyTargets(
newt.newtId,
targetsToAdd,
newt.version
)
addSubnetProxyTargetsBatch([
{
newtId: newt.newtId,
targets: targetsToAdd,
version: newt.version
}
])
);
}
for (const client of addedClients) {
olmJobs.push(
addPeerData(
client.clientId,
olmJobs.push(
addPeerDataBatch(
addedClients.map((client) => ({
clientId: client.clientId,
siteId,
generateRemoteSubnets([siteResource]),
generateAliasConfig([siteResource])
)
);
}
remoteSubnets: generateRemoteSubnets([
siteResource
]),
aliases: generateAliasConfig([siteResource])
}))
)
);
}
}
@@ -904,14 +943,23 @@ async function handleSubnetProxyTargetUpdates(
if (targetsToRemove) {
proxyJobs.push(
removeSubnetProxyTargets(
newt.newtId,
targetsToRemove,
newt.version
)
removeSubnetProxyTargetsBatch([
{
newtId: newt.newtId,
targets: targetsToRemove,
version: newt.version
}
])
);
}
const peerDataRemovals: {
clientId: number;
siteId: number;
remoteSubnets: string[];
aliases: ReturnType<typeof generateAliasConfig>;
}[] = [];
for (const client of removedClients) {
if (!siteResource.destination) {
continue;
@@ -959,14 +1007,16 @@ async function handleSubnetProxyTargetUpdates(
? []
: generateRemoteSubnets([siteResource]);
olmJobs.push(
removePeerData(
client.clientId,
siteId,
remoteSubnetsToRemove,
generateAliasConfig([siteResource])
)
);
peerDataRemovals.push({
clientId: client.clientId,
siteId,
remoteSubnets: remoteSubnetsToRemove,
aliases: generateAliasConfig([siteResource])
});
}
if (peerDataRemovals.length > 0) {
olmJobs.push(removePeerDataBatch(peerDataRemovals));
}
}
}
@@ -1277,6 +1327,28 @@ async function handleMessagesForClientSites(
const newtJobs: Promise<any>[] = [];
const olmJobs: Promise<any>[] = [];
const exitNodeJobs: Promise<any>[] = [];
const newtPeerDeletes: {
siteId: number;
publicKey: string;
newtId: string;
}[] = [];
const olmPeerDeletes: {
clientId: number;
siteId: number;
publicKey: string;
olmId: string;
}[] = [];
const olmPeerAddHandshakes: {
clientId: number;
peer: {
siteId: number;
exitNode: {
publicKey: string;
endpoint: string;
};
};
olmId: string;
}[] = [];
const totalSitesOnClient = await trx
.select({ count: count(clientSitesAssociationsCache.siteId) })
@@ -1308,19 +1380,19 @@ async function handleMessagesForClientSites(
if (isRemove) {
// Remove peer from newt
newtJobs.push(
newtDeletePeer(site.siteId, client.pubKey, newt.newtId)
);
newtPeerDeletes.push({
siteId: site.siteId,
publicKey: client.pubKey,
newtId: newt.newtId
});
try {
// Remove peer from olm
olmJobs.push(
olmDeletePeer(
client.clientId,
site.siteId,
site.publicKey,
olmId
)
);
olmPeerDeletes.push({
clientId: client.clientId,
siteId: site.siteId,
publicKey: site.publicKey,
olmId
});
} catch (error) {
// if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send
if (
@@ -1352,10 +1424,9 @@ async function handleMessagesForClientSites(
continue;
}
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId,
{
olmPeerAddHandshakes.push({
clientId: client.clientId,
peer: {
siteId: site.siteId,
exitNode: {
publicKey: exitNode.publicKey,
@@ -1363,7 +1434,7 @@ async function handleMessagesForClientSites(
}
},
olmId
);
});
}
// Update exit node destinations
@@ -1379,6 +1450,18 @@ async function handleMessagesForClientSites(
);
}
if (newtPeerDeletes.length > 0) {
newtJobs.push(newtDeletePeersBatch(newtPeerDeletes));
}
if (olmPeerDeletes.length > 0) {
olmJobs.push(olmDeletePeersBatch(olmPeerDeletes));
}
if (olmPeerAddHandshakes.length > 0) {
olmJobs.push(initPeerAddHandshakeBatch(olmPeerAddHandshakes));
}
Promise.all(exitNodeJobs).catch((error) => {
logger.error(
`rebuildClientAssociations: Error updating client site destinations for client ${client.clientId}:`,
@@ -1477,6 +1560,20 @@ async function handleMessagesForClientResources(
continue;
}
const targetsToAddBatch: {
newtId: string;
targets: NonNullable<
Awaited<ReturnType<typeof generateSubnetProxyTargetV2>>
>;
version: string | null;
}[] = [];
const peerDataAdds: {
clientId: number;
siteId: number;
remoteSubnets: string[];
aliases: ReturnType<typeof generateAliasConfig>;
}[] = [];
for (const resource of resources) {
const targets = await generateSubnetProxyTargetV2(resource, [
{
@@ -1487,25 +1584,21 @@ async function handleMessagesForClientResources(
]);
if (targets) {
proxyJobs.push(
addSubnetProxyTargets(
newt.newtId,
targets,
newt.version
)
);
targetsToAddBatch.push({
newtId: newt.newtId,
targets,
version: newt.version
});
}
try {
// Add peer data to olm
olmJobs.push(
addPeerData(
client.clientId,
siteId,
generateRemoteSubnets([resource]),
generateAliasConfig([resource])
)
);
peerDataAdds.push({
clientId: client.clientId,
siteId,
remoteSubnets: generateRemoteSubnets([resource]),
aliases: generateAliasConfig([resource])
});
} catch (error) {
// if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send
if (
@@ -1520,6 +1613,14 @@ async function handleMessagesForClientResources(
}
}
}
if (targetsToAddBatch.length > 0) {
proxyJobs.push(addSubnetProxyTargetsBatch(targetsToAddBatch));
}
if (peerDataAdds.length > 0) {
olmJobs.push(addPeerDataBatch(peerDataAdds));
}
}
}
@@ -1586,6 +1687,20 @@ async function handleMessagesForClientResources(
continue;
}
const targetsToRemoveBatch: {
newtId: string;
targets: NonNullable<
Awaited<ReturnType<typeof generateSubnetProxyTargetV2>>
>;
version: string | null;
}[] = [];
const peerDataRemovals: {
clientId: number;
siteId: number;
remoteSubnets: string[];
aliases: ReturnType<typeof generateAliasConfig>;
}[] = [];
for (const resource of resources) {
const targets = await generateSubnetProxyTargetV2(resource, [
{
@@ -1596,13 +1711,11 @@ async function handleMessagesForClientResources(
]);
if (targets) {
proxyJobs.push(
removeSubnetProxyTargets(
newt.newtId,
targets,
newt.version
)
);
targetsToRemoveBatch.push({
newtId: newt.newtId,
targets,
version: newt.version
});
}
try {
@@ -1653,14 +1766,12 @@ async function handleMessagesForClientResources(
: generateRemoteSubnets([resource]);
// Remove peer data from olm
olmJobs.push(
removePeerData(
client.clientId,
siteId,
remoteSubnetsToRemove,
generateAliasConfig([resource])
)
);
peerDataRemovals.push({
clientId: client.clientId,
siteId,
remoteSubnets: remoteSubnetsToRemove,
aliases: generateAliasConfig([resource])
});
} catch (error) {
// if the error includes not found then its just because the olm does not exist anymore or yet and its fine if we dont send
if (
@@ -1675,6 +1786,16 @@ async function handleMessagesForClientResources(
}
}
}
if (targetsToRemoveBatch.length > 0) {
proxyJobs.push(
removeSubnetProxyTargetsBatch(targetsToRemoveBatch)
);
}
if (peerDataRemovals.length > 0) {
olmJobs.push(removePeerDataBatch(peerDataRemovals));
}
}
}
@@ -1928,7 +2049,15 @@ export async function cleanupSiteAssociations(
for (const client of allClients) {
// Tell each olm to drop the site's WireGuard peer.
if (site.publicKey) {
jobs.push(olmDeletePeer(client.clientId, siteId, site.publicKey));
jobs.push(
olmDeletePeersBatch([
{
clientId: client.clientId,
siteId,
publicKey: site.publicKey
}
])
);
}
// Recompute and push updated relay destinations (now excluding this site).

View File

@@ -29,6 +29,7 @@ export interface RebuildJobHandlers {
// Redis list holding pending rebuild jobs (RPUSH to enqueue, LPOP to dequeue — FIFO order).
const QUEUE_KEY = "rebuild-client-associations:queue";
const QUEUED_SET_KEY = "rebuild-client-associations:queued";
// Distributed lock that serialises queue consumption to a single server instance
// at a time. TTL is generous enough to cover a full batch of expensive rebuilds.
@@ -54,11 +55,28 @@ class RedisRebuildQueue {
}
try {
const dedupeKey = `${job.type}:${job.id}`;
const added = await redis.sadd(QUEUED_SET_KEY, dedupeKey);
if (added === 0) {
logger.debug(
`Rebuild queue: skipped duplicate queued job ${job.type}:${job.id}`
);
return;
}
await redis.rpush(QUEUE_KEY, JSON.stringify(job));
logger.debug(
`Rebuild queue: enqueued ${job.type}:${job.id} (queue position: tail)`
);
} catch (err) {
await redis
.srem(QUEUED_SET_KEY, `${job.type}:${job.id}`)
.catch((cleanupErr) =>
logger.warn(
`Rebuild queue: failed to cleanup dedupe key for ${job.type}:${job.id} after enqueue failure:`,
cleanupErr
)
);
logger.error(
`Rebuild queue: failed to enqueue ${job.type}:${job.id}:`,
err
@@ -121,6 +139,17 @@ class RedisRebuildQueue {
continue;
}
// Remove from dedupe set once dequeued so the same job
// can be re-queued while this one is in progress.
await redis
.srem(QUEUED_SET_KEY, `${job.type}:${job.id}`)
.catch((cleanupErr) =>
logger.warn(
`Rebuild queue: failed to remove dedupe key for ${job.type}:${job.id} on dequeue:`,
cleanupErr
)
);
logger.debug(
`Rebuild queue: processing ${job.type}:${job.id}`
);

View File

@@ -38,6 +38,7 @@ import { messageHandlers } from "@server/routers/ws/messageHandlers";
import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers";
import {
AuthenticatedWebSocket,
BatchSendMessage,
ClientType,
WSMessage,
TokenPayload,
@@ -187,6 +188,8 @@ const wss: WebSocketServer = new WebSocketServer({ noServer: true });
// Generate unique node ID for this instance
const NODE_ID = uuidv4();
const REDIS_CHANNEL = "websocket_messages";
const REDIS_DIRECT_BATCH_SIZE = 250;
const REDIS_DIRECT_FLUSH_INTERVAL_MS = 10;
// Client tracking map (local to this node)
const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
@@ -197,6 +200,15 @@ const clientConfigVersions: Map<string, number> = new Map();
// Recovery tracking
let isRedisRecoveryInProgress = false;
interface RedisDirectBatchEntry {
targetClientId: string;
message: WSMessage;
resolve: () => void;
}
let pendingRedisDirectMessages: RedisDirectBatchEntry[] = [];
let redisDirectFlushTimer: NodeJS.Timeout | null = null;
// Helper to get map key
const getClientMapKey = (clientId: string) => clientId;
@@ -207,6 +219,78 @@ const getNodeConnectionsKey = (nodeId: string, clientId: string) =>
const getConfigVersionKey = (clientId: string) =>
`ws:configVersion:${clientId}`;
const clearRedisDirectFlushTimer = (): void => {
if (redisDirectFlushTimer) {
clearTimeout(redisDirectFlushTimer);
redisDirectFlushTimer = null;
}
};
const publishDirectBatch = async (
entries: RedisDirectBatchEntry[]
): Promise<void> => {
const redisMessage: RedisMessage = {
type: "direct-batch",
messages: entries.map((entry) => ({
targetClientId: entry.targetClientId,
message: entry.message
})),
fromNodeId: NODE_ID
};
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
};
const flushPendingRedisDirectMessages = async (): Promise<void> => {
clearRedisDirectFlushTimer();
if (pendingRedisDirectMessages.length === 0) {
return;
}
const entries = pendingRedisDirectMessages;
pendingRedisDirectMessages = [];
if (!redisManager.isRedisEnabled()) {
entries.forEach((entry) => entry.resolve());
return;
}
for (let i = 0; i < entries.length; i += REDIS_DIRECT_BATCH_SIZE) {
const batch = entries.slice(i, i + REDIS_DIRECT_BATCH_SIZE);
try {
await publishDirectBatch(batch);
} catch (error) {
logger.error(
"Failed to send batched direct messages via Redis, messages may be lost:",
error
);
} finally {
batch.forEach((entry) => entry.resolve());
}
}
};
const enqueueRedisDirectMessage = async (
targetClientId: string,
message: WSMessage
): Promise<void> => {
await new Promise<void>((resolve) => {
pendingRedisDirectMessages.push({ targetClientId, message, resolve });
if (pendingRedisDirectMessages.length >= REDIS_DIRECT_BATCH_SIZE) {
void flushPendingRedisDirectMessages();
return;
}
if (!redisDirectFlushTimer) {
redisDirectFlushTimer = setTimeout(() => {
void flushPendingRedisDirectMessages();
}, REDIS_DIRECT_FLUSH_INTERVAL_MS);
}
});
};
// Initialize Redis subscription for cross-node messaging
const initializeRedisSubscription = async (): Promise<void> => {
if (!redisManager.isRedisEnabled()) return;
@@ -227,7 +311,16 @@ const initializeRedisSubscription = async (): Promise<void> => {
// Send to specific client on this node
await sendToClientLocal(
redisMessage.targetClientId,
redisMessage.message
redisMessage.message,
{},
redisMessage.message.configVersion
);
} else if (
redisMessage.type === "direct-batch" &&
redisMessage.messages
) {
await sendRedisDirectBatchToLocalClients(
redisMessage.messages
);
} else if (redisMessage.type === "broadcast") {
// Broadcast to all clients on this node except excluded
@@ -503,7 +596,8 @@ const incrementClientConfigVersion = async (
const sendToClientLocal = async (
clientId: string,
message: WSMessage,
options: SendMessageOptions = {}
options: SendMessageOptions = {},
preResolvedConfigVersion?: number
): Promise<boolean> => {
const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey);
@@ -512,7 +606,8 @@ const sendToClientLocal = async (
}
// Handle config version
const configVersion = await getClientConfigVersion(clientId);
const configVersion =
preResolvedConfigVersion ?? (await getClientConfigVersion(clientId));
// Add config version to message
const messageWithVersion = {
@@ -545,6 +640,20 @@ const sendToClientLocal = async (
return true;
};
const sendRedisDirectBatchToLocalClients = async (
entries: { targetClientId: string; message: WSMessage }[]
): Promise<void> => {
const jobs = entries.map((entry) =>
sendToClientLocal(
entry.targetClientId,
entry.message,
{},
entry.message.configVersion
)
);
await Promise.all(jobs);
};
const broadcastToAllExceptLocal = async (
message: WSMessage,
excludeClientId?: string,
@@ -607,23 +716,13 @@ const sendToClient = async (
// Only send via Redis if the client is not connected locally and Redis is enabled
if (!localSent && redisManager.isRedisEnabled()) {
try {
const redisMessage: RedisMessage = {
type: "direct",
targetClientId: clientId,
message: {
...message,
configVersion
},
fromNodeId: NODE_ID
};
await redisManager.publish(
REDIS_CHANNEL,
JSON.stringify(redisMessage)
);
await enqueueRedisDirectMessage(clientId, {
...message,
configVersion
});
} catch (error) {
logger.error(
"Failed to send message via Redis, message may be lost:",
"Failed to queue batched direct message for Redis delivery, message may be lost:",
error
);
// Continue execution - local delivery already attempted
@@ -638,6 +737,76 @@ const sendToClient = async (
return localSent;
};
const sendToClientsBatch = async (
entries: BatchSendMessage[]
): Promise<void> => {
if (entries.length === 0) {
return;
}
const remoteEntries: { targetClientId: string; message: WSMessage }[] = [];
for (const entry of entries) {
const options = entry.options || {};
const { clientId, message } = entry;
let configVersion = await getClientConfigVersion(clientId);
if (options.incrementConfigVersion) {
configVersion = await incrementClientConfigVersion(clientId);
}
logger.debug(
`sendToClientsBatch: Message type ${message.type} queued for clientId ${clientId} (new configVersion: ${configVersion})`
);
const localSent = await sendToClientLocal(
clientId,
message,
options,
configVersion
);
if (!localSent && redisManager.isRedisEnabled()) {
remoteEntries.push({
targetClientId: clientId,
message: {
...message,
configVersion
}
});
} else if (!localSent && !redisManager.isRedisEnabled()) {
logger.debug(
`Could not deliver batch message to ${clientId} - not connected locally and Redis unavailable`
);
}
}
if (!redisManager.isRedisEnabled() || remoteEntries.length === 0) {
return;
}
for (let i = 0; i < remoteEntries.length; i += REDIS_DIRECT_BATCH_SIZE) {
const messages = remoteEntries.slice(i, i + REDIS_DIRECT_BATCH_SIZE);
try {
const redisMessage: RedisMessage = {
type: "direct-batch",
messages,
fromNodeId: NODE_ID
};
await redisManager.publish(
REDIS_CHANNEL,
JSON.stringify(redisMessage)
);
} catch (error) {
logger.error(
"Failed to send explicit direct batch via Redis, messages may be lost:",
error
);
}
}
};
const broadcastToAllExcept = async (
message: WSMessage,
excludeClientId?: string,
@@ -1109,6 +1278,8 @@ const disconnectClient = async (clientId: string): Promise<boolean> => {
// Cleanup function for graceful shutdown
const cleanup = async (): Promise<void> => {
try {
await flushPendingRedisDirectMessages();
// Close all WebSocket connections
connectedClients.forEach((clients) => {
clients.forEach((client) => {
@@ -1139,6 +1310,7 @@ export {
router,
handleWSUpgrade,
sendToClient,
sendToClientsBatch,
broadcastToAllExcept,
connectedClients,
hasActiveConnections,

View File

@@ -1,4 +1,4 @@
import { sendToClient } from "#dynamic/routers/ws";
import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws";
import { db, newts, olms } from "@server/db";
import {
Alias,
@@ -8,7 +8,7 @@ import {
} from "@server/lib/ip";
import { canCompress } from "@server/lib/clientVersionChecks";
import logger from "@server/logger";
import { eq } from "drizzle-orm";
import { eq, inArray } from "drizzle-orm";
import semver from "semver";
const NEWT_V2_TARGETS_VERSION = ">=1.10.3";
@@ -59,6 +59,42 @@ export async function addTargets(
);
}
export async function addTargetsBatch(
entries: {
newtId: string;
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[];
version?: string | null;
}[]
) {
if (entries.length === 0) {
return;
}
const resolved = await Promise.all(
entries.map(async (entry) => ({
...entry,
targets: await convertTargetsIfNecessary(
entry.newtId,
entry.targets
)
}))
);
await sendToClientsBatch(
resolved.map((entry) => ({
clientId: entry.newtId,
message: {
type: `newt/wg/targets/add`,
data: entry.targets
},
options: {
incrementConfigVersion: true,
compress: canCompress(entry.version, "newt")
}
}))
);
}
export async function removeTargets(
newtId: string,
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[],
@@ -76,6 +112,42 @@ export async function removeTargets(
);
}
export async function removeTargetsBatch(
entries: {
newtId: string;
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[];
version?: string | null;
}[]
) {
if (entries.length === 0) {
return;
}
const resolved = await Promise.all(
entries.map(async (entry) => ({
...entry,
targets: await convertTargetsIfNecessary(
entry.newtId,
entry.targets
)
}))
);
await sendToClientsBatch(
resolved.map((entry) => ({
clientId: entry.newtId,
message: {
type: `newt/wg/targets/remove`,
data: entry.targets
},
options: {
incrementConfigVersion: true,
compress: canCompress(entry.version, "newt")
}
}))
);
}
export async function updateTargets(
newtId: string,
targets: {
@@ -201,6 +273,171 @@ export async function removePeerData(
});
}
const resolveOlmTargets = async (
entries: {
clientId: number;
olmId?: string;
version?: string | null;
}[]
) => {
const unresolvedClientIds = entries
.filter((entry) => !entry.olmId)
.map((entry) => entry.clientId);
const olmMap = new Map<number, { olmId: string; version: string | null }>();
if (unresolvedClientIds.length > 0) {
const olmRows = await db
.select({
clientId: olms.clientId,
olmId: olms.olmId,
version: olms.version
})
.from(olms)
.where(inArray(olms.clientId, unresolvedClientIds));
for (const row of olmRows) {
if (row.clientId !== null) {
olmMap.set(row.clientId, {
olmId: row.olmId,
version: row.version
});
}
}
}
return entries
.map((entry) => {
if (entry.olmId) {
return {
clientId: entry.clientId,
olmId: entry.olmId,
version: entry.version
};
}
const resolved = olmMap.get(entry.clientId);
if (!resolved) {
return null;
}
return {
clientId: entry.clientId,
olmId: resolved.olmId,
version: entry.version ?? resolved.version
};
})
.filter((entry) => entry !== null);
};
export async function addPeerDataBatch(
entries: {
clientId: number;
siteId: number;
remoteSubnets: string[];
aliases: Alias[];
olmId?: string;
version?: string | null;
}[]
) {
if (entries.length === 0) {
return;
}
const resolvedTargets = await resolveOlmTargets(entries);
if (resolvedTargets.length === 0) {
return;
}
const payloads = entries
.map((entry) => {
const resolved = resolvedTargets.find(
(target) => target.clientId === entry.clientId
);
if (!resolved) {
return null;
}
return {
clientId: resolved.olmId,
message: {
type: `olm/wg/peer/data/add`,
data: {
siteId: entry.siteId,
remoteSubnets: entry.remoteSubnets,
aliases: entry.aliases
}
},
options: {
incrementConfigVersion: true,
compress: canCompress(resolved.version, "olm")
}
};
})
.filter((entry) => entry !== null);
if (payloads.length === 0) {
return;
}
await sendToClientsBatch(payloads);
}
export async function removePeerDataBatch(
entries: {
clientId: number;
siteId: number;
remoteSubnets: string[];
aliases: Alias[];
olmId?: string;
version?: string | null;
}[]
) {
if (entries.length === 0) {
return;
}
const resolvedTargets = await resolveOlmTargets(entries);
if (resolvedTargets.length === 0) {
return;
}
const payloads = entries
.map((entry) => {
const resolved = resolvedTargets.find(
(target) => target.clientId === entry.clientId
);
if (!resolved) {
return null;
}
return {
clientId: resolved.olmId,
message: {
type: `olm/wg/peer/data/remove`,
data: {
siteId: entry.siteId,
remoteSubnets: entry.remoteSubnets,
aliases: entry.aliases
}
},
options: {
incrementConfigVersion: true,
compress: canCompress(resolved.version, "olm")
}
};
})
.filter((entry) => entry !== null);
if (payloads.length === 0) {
return;
}
await sendToClientsBatch(payloads);
}
export async function updatePeerData(
clientId: number,
siteId: number,

View File

@@ -1,7 +1,7 @@
import { db, Site } from "@server/db";
import { newts, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { sendToClient } from "#dynamic/routers/ws";
import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws";
import logger from "@server/logger";
export async function addPeer(
@@ -36,10 +36,14 @@ export async function addPeer(
newtId = newt.newtId;
}
await sendToClient(newtId, {
type: "newt/wg/peer/add",
data: peer
}, { incrementConfigVersion: true }).catch((error) => {
await sendToClient(
newtId,
{
type: "newt/wg/peer/add",
data: peer
},
{ incrementConfigVersion: true }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
@@ -76,12 +80,16 @@ export async function deletePeer(
newtId = newt.newtId;
}
await sendToClient(newtId, {
type: "newt/wg/peer/remove",
data: {
publicKey
}
}, { incrementConfigVersion: true }).catch((error) => {
await sendToClient(
newtId,
{
type: "newt/wg/peer/remove",
data: {
publicKey
}
},
{ incrementConfigVersion: true }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
@@ -90,6 +98,35 @@ export async function deletePeer(
return site;
}
export async function deletePeersBatch(
peers: {
siteId: number;
publicKey: string;
newtId: string;
}[]
) {
if (peers.length === 0) {
return;
}
await sendToClientsBatch(
peers.map((peer) => ({
clientId: peer.newtId,
message: {
type: "newt/wg/peer/remove",
data: {
publicKey: peer.publicKey
}
},
options: { incrementConfigVersion: true }
}))
).catch((error) => {
logger.warn(`Error sending batched newt peer removals:`, error);
});
logger.info(`Deleted ${peers.length} peer(s) from newts (batch)`);
}
export async function updatePeer(
siteId: number,
publicKey: string,
@@ -122,13 +159,17 @@ export async function updatePeer(
newtId = newt.newtId;
}
await sendToClient(newtId, {
type: "newt/wg/peer/update",
data: {
publicKey,
...peer
}
}, { incrementConfigVersion: true }).catch((error) => {
await sendToClient(
newtId,
{
type: "newt/wg/peer/update",
data: {
publicKey,
...peer
}
},
{ incrementConfigVersion: true }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});

View File

@@ -1,9 +1,9 @@
import { sendToClient } from "#dynamic/routers/ws";
import { sendToClient, sendToClientsBatch } from "#dynamic/routers/ws";
import { clientSitesAssociationsCache, db, olms } from "@server/db";
import { canCompress } from "@server/lib/clientVersionChecks";
import config from "@server/lib/config";
import logger from "@server/logger";
import { and, eq } from "drizzle-orm";
import { and, eq, inArray } from "drizzle-orm";
import { Alias } from "yaml";
export async function addPeer(
@@ -205,3 +205,150 @@ export async function initPeerAddHandshake(
`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`
);
}
export async function deletePeersBatch(
peers: {
clientId: number;
siteId: number;
publicKey: string;
olmId?: string;
version?: string | null;
}[]
) {
if (peers.length === 0) {
return;
}
const unresolvedClientIds = peers
.filter((peer) => !peer.olmId)
.map((peer) => peer.clientId);
const olmByClientId = new Map<
number,
{ olmId: string; version: string | null }
>();
if (unresolvedClientIds.length > 0) {
const olmRows = await db
.select({
clientId: olms.clientId,
olmId: olms.olmId,
version: olms.version
})
.from(olms)
.where(inArray(olms.clientId, unresolvedClientIds));
for (const row of olmRows) {
if (row.clientId !== null) {
olmByClientId.set(row.clientId, {
olmId: row.olmId,
version: row.version
});
}
}
}
const batchPayloads = peers
.map((peer) => {
const resolved = peer.olmId
? { olmId: peer.olmId, version: peer.version ?? null }
: olmByClientId.get(peer.clientId);
if (!resolved) {
return null;
}
return {
clientId: resolved.olmId,
message: {
type: "olm/wg/peer/remove",
data: {
publicKey: peer.publicKey,
siteId: peer.siteId
}
},
options: {
incrementConfigVersion: true,
compress: canCompress(
peer.version ?? resolved.version,
"olm"
)
}
};
})
.filter((payload) => payload !== null);
if (batchPayloads.length === 0) {
return;
}
await sendToClientsBatch(batchPayloads).catch((error) => {
logger.warn(`Error sending batched olm peer removals:`, error);
});
logger.info(`Deleted ${batchPayloads.length} peer(s) from olms (batch)`);
}
export async function initPeerAddHandshakeBatch(
handshakes: {
clientId: number;
peer: {
siteId: number;
exitNode: {
publicKey: string;
endpoint: string;
};
};
olmId: string;
chainId?: string;
}[]
) {
if (handshakes.length === 0) {
return;
}
await sendToClientsBatch(
handshakes.map((item) => ({
clientId: item.olmId,
message: {
type: "olm/wg/peer/holepunch/site/add",
data: {
siteId: item.peer.siteId,
exitNode: {
publicKey: item.peer.exitNode.publicKey,
relayPort:
config.getRawConfig().gerbil.clients_start_port,
endpoint: item.peer.exitNode.endpoint
},
chainId: item.chainId
}
},
options: { incrementConfigVersion: true }
}))
).catch((error) => {
logger.warn(`Error sending batched olm handshakes:`, error);
});
await Promise.all(
handshakes.map((item) =>
db
.update(clientSitesAssociationsCache)
.set({ isJitMode: false })
.where(
and(
eq(
clientSitesAssociationsCache.clientId,
item.clientId
),
eq(
clientSitesAssociationsCache.siteId,
item.peer.siteId
)
)
)
)
);
logger.info(
`Initiated ${handshakes.length} peer add handshake(s) to olms (batch)`
);
}

View File

@@ -76,12 +76,32 @@ export interface SendMessageOptions {
compress?: boolean;
}
// Redis message type for cross-node communication
export interface RedisMessage {
type: "direct" | "broadcast";
targetClientId?: string;
excludeClientId?: string;
export interface BatchSendMessage {
clientId: string;
message: WSMessage;
fromNodeId: string;
options?: SendMessageOptions;
}
// Redis message types for cross-node communication
export type RedisMessage =
| {
type: "direct";
targetClientId: string;
message: WSMessage;
fromNodeId: string;
}
| {
type: "direct-batch";
messages: {
targetClientId: string;
message: WSMessage;
}[];
fromNodeId: string;
}
| {
type: "broadcast";
excludeClientId?: string;
message: WSMessage;
fromNodeId: string;
options?: SendMessageOptions;
};

View File

@@ -26,7 +26,8 @@ import {
WebSocketRequest,
WSMessage,
AuthenticatedWebSocket,
SendMessageOptions
SendMessageOptions,
BatchSendMessage
} from "./types";
import { validateSessionToken } from "@server/auth/sessions/app";
@@ -212,6 +213,20 @@ const sendToClient = async (
return localSent;
};
const sendToClientsBatch = async (
entries: BatchSendMessage[]
): Promise<void> => {
if (entries.length === 0) {
return;
}
await Promise.all(
entries.map((entry) =>
sendToClient(entry.clientId, entry.message, entry.options)
)
);
};
const broadcastToAllExcept = async (
message: WSMessage,
excludeClientId?: string,
@@ -552,6 +567,7 @@ export {
router,
handleWSUpgrade,
sendToClient,
sendToClientsBatch,
broadcastToAllExcept,
connectedClients,
hasActiveConnections,