diff --git a/server/private/lib/exitNodes/exitNodes.ts b/server/private/lib/exitNodes/exitNodes.ts index f6417dae2..1f9517725 100644 --- a/server/private/lib/exitNodes/exitNodes.ts +++ b/server/private/lib/exitNodes/exitNodes.ts @@ -18,12 +18,15 @@ import { resources, targets, sites, + siteLabels, + remoteExitNodes, + remoteExitNodePreferenceLabels, targetHealthCheck, Transaction } from "@server/db"; import logger from "@server/logger"; import { ExitNodePingResult } from "@server/routers/newt"; -import { eq, and, or, ne, isNull } from "drizzle-orm"; +import { eq, and, or, ne, isNull, inArray } from "drizzle-orm"; import axios from "axios"; import config from "../config"; @@ -150,7 +153,8 @@ export async function verifyExitNodeOrgAccess( export async function listExitNodes( orgId: string, filterOnline = false, - noCloud = false + noCloud = false, + siteId?: number ) { const allExitNodes = await db .select({ @@ -237,7 +241,7 @@ export async function listExitNodes( // }) // ); - const remoteExitNodes = allExitNodes.filter( + let remoteExitNodesList = allExitNodes.filter( (node) => node.type === "remoteExitNode" && (!filterOnline || node.online) ); @@ -246,9 +250,82 @@ export async function listExitNodes( node.type === "gerbil" && (!filterOnline || node.online) && !noCloud ); + // Apply label-based filtering to remote exit nodes if siteId is provided + if (siteId !== undefined && remoteExitNodesList.length > 0) { + // Get the site's labels + const siteLabelRows = await db + .select({ labelId: siteLabels.labelId }) + .from(siteLabels) + .where(eq(siteLabels.siteId, siteId)); + const siteLabelIds = new Set(siteLabelRows.map((r) => r.labelId)); + + // Get the remoteExitNode records for these exit nodes so we have the remoteExitNodeId + const exitNodeIds = remoteExitNodesList.map((n) => n.exitNodeId); + const remoteNodeRows = await db + .select({ + exitNodeId: remoteExitNodes.exitNodeId, + remoteExitNodeId: remoteExitNodes.remoteExitNodeId + }) + .from(remoteExitNodes) + .where(inArray(remoteExitNodes.exitNodeId, exitNodeIds)); + + const exitNodeIdToRemoteId = new Map( + remoteNodeRows + .filter((r) => r.exitNodeId !== null) + .map((r) => [r.exitNodeId!, r.remoteExitNodeId]) + ); + + // Get preference labels for all remote exit nodes + const remoteExitNodeIds = remoteNodeRows.map((r) => r.remoteExitNodeId); + const prefLabelRows = + remoteExitNodeIds.length > 0 + ? await db + .select({ + remoteExitNodeId: + remoteExitNodePreferenceLabels.remoteExitNodeId, + labelId: remoteExitNodePreferenceLabels.labelId + }) + .from(remoteExitNodePreferenceLabels) + .where( + inArray( + remoteExitNodePreferenceLabels.remoteExitNodeId, + remoteExitNodeIds + ) + ) + : []; + + // Build a map of remoteExitNodeId -> Set of labelIds + const prefLabelsMap = new Map>(); + for (const row of prefLabelRows) { + if (!prefLabelsMap.has(row.remoteExitNodeId)) { + prefLabelsMap.set(row.remoteExitNodeId, new Set()); + } + prefLabelsMap.get(row.remoteExitNodeId)!.add(row.labelId); + } + + // Filter: include node if it has no preference labels, or if site shares at least one label + const filtered = remoteExitNodesList.filter((node) => { + const remoteId = exitNodeIdToRemoteId.get(node.exitNodeId); + if (!remoteId) return true; // no remoteExitNode record, don't filter + const prefLabels = prefLabelsMap.get(remoteId); + if (!prefLabels || prefLabels.size === 0) return true; // no preference labels, include + // include only if site has at least one matching label + for (const labelId of siteLabelIds) { + if (prefLabels.has(labelId)) return true; + } + return false; + }); + + // Only apply the filtered list if at least one remote node remains; + // otherwise fall through to the gerbil fallback below + if (filtered.length > 0 || remoteExitNodesList.length === 0) { + remoteExitNodesList = filtered; + } + } + // THIS PROVIDES THE FALL const exitNodesList = - remoteExitNodes.length > 0 ? remoteExitNodes : gerbilExitNodes; + remoteExitNodesList.length > 0 ? remoteExitNodesList : gerbilExitNodes; return exitNodesList; } diff --git a/server/routers/newt/handleNewtPingRequestMessage.ts b/server/routers/newt/handleNewtPingRequestMessage.ts index 8f6df4bec..f239dc4de 100644 --- a/server/routers/newt/handleNewtPingRequestMessage.ts +++ b/server/routers/newt/handleNewtPingRequestMessage.ts @@ -38,7 +38,8 @@ export const handleNewtPingRequestMessage: MessageHandler = async (context) => { const exitNodesList = await listExitNodes( site.orgId, true, - noCloud || false + noCloud || false, + newt.siteId ); // filter for only the online ones let lastExitNodeId = null;