Merge branch 'dev' into rdp-ssh

This commit is contained in:
Owen
2026-05-15 11:18:31 -07:00
12 changed files with 560 additions and 382 deletions

View File

@@ -20,9 +20,7 @@ import {
} from "@server/db";
import { and, eq, inArray, ne } from "drizzle-orm";
import {
deletePeer as newtDeletePeer
} from "@server/routers/newt/peers";
import { deletePeer as newtDeletePeer } from "@server/routers/newt/peers";
import {
initPeerAddHandshake,
deletePeer as olmDeletePeer
@@ -33,7 +31,7 @@ import {
generateAliasConfig,
generateRemoteSubnets,
generateSubnetProxyTargetV2,
parseEndpoint,
parseEndpoint
} from "@server/lib/ip";
import {
addPeerData,
@@ -51,10 +49,7 @@ export async function getClientSiteResourceAccess(
? await trx
.select()
.from(sites)
.innerJoin(
siteNetworks,
eq(siteNetworks.siteId, sites.siteId)
)
.innerJoin(siteNetworks, eq(siteNetworks.siteId, sites.siteId))
.where(eq(siteNetworks.networkId, siteResource.networkId))
.then((rows) => rows.map((row) => row.sites))
: [];
@@ -362,7 +357,8 @@ export async function rebuildClientAssociationsFromSiteResource(
.where(inArray(clients.clientId, existingClientSiteIds))
: [];
const otherResourceClientIds = clientsFromOtherResourcesBySite.get(siteId) ?? new Set<number>();
const otherResourceClientIds =
clientsFromOtherResourcesBySite.get(siteId) ?? new Set<number>();
logger.debug(
`rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] siteId=${siteId} otherResourceClientIds=[${[...otherResourceClientIds].join(", ")}] mergedAllClientIds=[${mergedAllClientIds.join(", ")}]`
@@ -709,7 +705,7 @@ export async function updateClientSiteDestinations(
sourcePort: destination.sourcePort,
destinations: destination.destinations
};
logger.info(
logger.debug(
`Payload for update-destinations: ${JSON.stringify(payload, null, 2)}`
);

View File

@@ -11,7 +11,7 @@ import {
ExitNode
} from "@server/db";
import { db } from "@server/db";
import { eq } from "drizzle-orm";
import { eq, inArray } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
@@ -97,86 +97,119 @@ export async function generateRelayMappings(exitNode: ExitNode) {
return {};
}
// Filter to sites with the required fields up front so the rest of the
// function can safely treat endpoint/subnet/listenPort as defined.
const validSites = sitesRes.filter(
(s) => s.endpoint && s.subnet && s.listenPort
);
if (validSites.length === 0) {
return {};
}
const siteIds = validSites.map((s) => s.siteId);
const orgIds = Array.from(
new Set(
validSites
.map((s) => s.orgId)
.filter((id): id is NonNullable<typeof id> => id != null)
)
);
// Batch fetch all client-site associations for these sites in one query.
const clientSitesRes = siteIds.length
? await db
.select()
.from(clientSitesAssociationsCache)
.where(inArray(clientSitesAssociationsCache.siteId, siteIds))
: [];
// Batch fetch all sites in the relevant orgs in one query (covers
// site-to-site communication for every site processed below).
const orgSitesRes = orgIds.length
? await db.select().from(sites).where(inArray(sites.orgId, orgIds))
: [];
// Index org sites by orgId for O(1) lookup per site.
const sitesByOrg = new Map<string, typeof orgSitesRes>();
for (const peer of orgSitesRes) {
if (
peer.orgId == null ||
!peer.endpoint ||
!peer.subnet ||
!peer.listenPort
) {
continue;
}
let arr = sitesByOrg.get(peer.orgId);
if (!arr) {
arr = [];
sitesByOrg.set(peer.orgId, arr);
}
arr.push(peer);
}
// Index client-site associations by siteId for O(1) lookup per site.
const clientSitesBySite = new Map<number, typeof clientSitesRes>();
for (const cs of clientSitesRes) {
let arr = clientSitesBySite.get(cs.siteId);
if (!arr) {
arr = [];
clientSitesBySite.set(cs.siteId, arr);
}
arr.push(cs);
}
// Initialize mappings object for multi-peer support
const mappings: { [key: string]: ProxyMapping } = {};
// Process each site
for (const site of sitesRes) {
if (!site.endpoint || !site.subnet || !site.listenPort) {
continue;
// Track destinations per endpoint to deduplicate in O(1).
const seen = new Map<string, Set<string>>();
const addDestination = (endpoint: string, dest: PeerDestination) => {
let destSet = seen.get(endpoint);
if (!destSet) {
destSet = new Set();
seen.set(endpoint, destSet);
mappings[endpoint] = { destinations: [] };
}
// Find all clients associated with this site through clientSites
const clientSitesRes = await db
.select()
.from(clientSitesAssociationsCache)
.where(eq(clientSitesAssociationsCache.siteId, site.siteId));
for (const clientSite of clientSitesRes) {
if (!clientSite.endpoint) {
continue;
}
// Add this site as a destination for the client
if (!mappings[clientSite.endpoint]) {
mappings[clientSite.endpoint] = { destinations: [] };
}
// Add site as a destination for this client
const destination: PeerDestination = {
destinationIP: site.subnet.split("/")[0],
destinationPort: site.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
};
// Check if this destination is already in the array to avoid duplicates
const isDuplicate = mappings[clientSite.endpoint].destinations.some(
(dest) =>
dest.destinationIP === destination.destinationIP &&
dest.destinationPort === destination.destinationPort
);
if (!isDuplicate) {
mappings[clientSite.endpoint].destinations.push(destination);
}
const key = `${dest.destinationIP}:${dest.destinationPort}`;
if (!destSet.has(key)) {
destSet.add(key);
mappings[endpoint].destinations.push(dest);
}
};
// Also handle site-to-site communication (all sites in the same org)
if (site.orgId) {
const orgSites = await db
.select()
.from(sites)
.where(eq(sites.orgId, site.orgId));
// Process each site using the pre-fetched data.
for (const site of validSites) {
const siteDestination: PeerDestination = {
destinationIP: site.subnet!.split("/")[0],
destinationPort: site.listenPort! || 1 // this satisfies gerbil for now but should be reevaluated
};
for (const peer of orgSites) {
// Skip self
if (
peer.siteId === site.siteId ||
!peer.endpoint ||
!peer.subnet ||
!peer.listenPort
) {
// Add this site as a destination for each associated client.
const clientSites = clientSitesBySite.get(site.siteId);
if (clientSites) {
for (const clientSite of clientSites) {
if (!clientSite.endpoint) {
continue;
}
addDestination(clientSite.endpoint, siteDestination);
}
}
// Add peer site as a destination for this site
if (!mappings[site.endpoint]) {
mappings[site.endpoint] = { destinations: [] };
}
const destination: PeerDestination = {
destinationIP: peer.subnet.split("/")[0],
destinationPort: peer.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
};
// Check for duplicates
const isDuplicate = mappings[site.endpoint].destinations.some(
(dest) =>
dest.destinationIP === destination.destinationIP &&
dest.destinationPort === destination.destinationPort
);
if (!isDuplicate) {
mappings[site.endpoint].destinations.push(destination);
// Site-to-site communication (all sites in the same org).
if (site.orgId != null) {
const peers = sitesByOrg.get(site.orgId);
if (peers) {
for (const peer of peers) {
if (peer.siteId === site.siteId) {
continue;
}
addDestination(site.endpoint!, {
destinationIP: peer.subnet!.split("/")[0],
destinationPort: peer.listenPort! || 1 // this satisfies gerbil for now but should be reevaluated
});
}
}
}

View File

@@ -11,7 +11,7 @@ import {
ExitNode
} from "@server/db";
import { db } from "@server/db";
import { eq, and } from "drizzle-orm";
import { eq, and, inArray } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
@@ -185,16 +185,20 @@ export async function updateAndGenerateEndpointDestinations(
const sitesOnExitNode = await db
.select({
siteId: sites.siteId,
newtId: newts.newtId,
subnet: sites.subnet,
listenPort: sites.listenPort,
publicKey: sites.publicKey,
endpoint: clientSitesAssociationsCache.endpoint
endpoint: clientSitesAssociationsCache.endpoint,
isRelayed: clientSitesAssociationsCache.isRelayed,
isJitMode: clientSitesAssociationsCache.isJitMode
})
.from(sites)
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.innerJoin(newts, eq(sites.siteId, newts.siteId))
.where(
and(
eq(sites.exitNodeId, exitNode.exitNodeId),
@@ -202,24 +206,36 @@ export async function updateAndGenerateEndpointDestinations(
)
);
// Update clientSites for each site on this exit node
// Format the endpoint properly for both IPv4 and IPv6
const formattedEndpoint = formatEndpoint(ip, port);
// Determine which rows actually need updating and whether the endpoint
// (as opposed to only the publicKey) changed for any of them.
const siteIdsToUpdate: number[] = [];
const sitesWithNewtsToUpdate: { siteId: number; newtId: string }[] = [];
let endpointChanged = false;
for (const site of sitesOnExitNode) {
// logger.debug(
// `Updating site ${site.siteId} on exit node ${exitNode.exitNodeId}`
// );
// Format the endpoint properly for both IPv4 and IPv6
const formattedEndpoint = formatEndpoint(ip, port);
// if the public key or endpoint has changed, update it otherwise continue
if (
site.endpoint === formattedEndpoint &&
site.publicKey === publicKey
) {
continue;
}
siteIdsToUpdate.push(site.siteId);
if (!site.isRelayed && !site.isJitMode) {
sitesWithNewtsToUpdate.push({
siteId: site.siteId,
newtId: site.newtId
});
}
if (site.endpoint !== formattedEndpoint) {
endpointChanged = true;
}
}
const [updatedClientSitesAssociationsCache] = await db
if (siteIdsToUpdate.length > 0) {
// Single bulk update for all affected rows for this client on this exit node
await db
.update(clientSitesAssociationsCache)
.set({
endpoint: formattedEndpoint,
@@ -228,24 +244,30 @@ export async function updateAndGenerateEndpointDestinations(
.where(
and(
eq(clientSitesAssociationsCache.clientId, olm.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
inArray(
clientSitesAssociationsCache.siteId,
siteIdsToUpdate
)
)
)
.returning();
);
if (
updatedClientSitesAssociationsCache.endpoint !==
site.endpoint && // this is the endpoint from the join table not the site
updatedClient.pubKey === publicKey // only trigger if the client's public key matches the current public key which means it has registered so we dont prematurely send the update
) {
// Only trigger downstream peer updates once per hole punch: the
// endpoint is the same for every site on this exit node, and
// handleClientEndpointChange already fans out to all connected
// sites for this client.
if (endpointChanged && updatedClient.pubKey === publicKey) {
logger.info(
`ClientSitesAssociationsCache for client ${olm.clientId} and site ${site.siteId} endpoint changed from ${site.endpoint} to ${updatedClientSitesAssociationsCache.endpoint}`
`ClientSitesAssociationsCache for client ${olm.clientId} endpoint changed to ${formattedEndpoint} for ${siteIdsToUpdate.length} site(s) on exit node ${exitNode.exitNodeId}`
);
// Handle any additional logic for endpoint change
handleClientEndpointChange(
sitesWithNewtsToUpdate,
olm.clientId,
updatedClientSitesAssociationsCache.endpoint!
);
formattedEndpoint
).catch((error) => {
logger.error(
`Failed to handle client endpoint change for client ${olm.clientId}: ${error}`
);
});
}
}
@@ -336,59 +358,14 @@ export async function updateAndGenerateEndpointDestinations(
`Site ${newt.siteId} endpoint changed from ${site.endpoint} to ${updatedSite.endpoint}`
);
// Handle any additional logic for endpoint change
handleSiteEndpointChange(newt.siteId, updatedSite.endpoint!);
handleSiteEndpointChange(newt.siteId, updatedSite.endpoint!).catch(
(error) => {
logger.error(
`Failed to handle site endpoint change for site ${newt.siteId}: ${error}`
);
}
);
}
// if (!updatedSite || !updatedSite.subnet) {
// logger.warn(`Site not found: ${newt.siteId}`);
// throw new Error("Site not found");
// }
// Find all clients that connect to this site
// const sitesClientPairs = await db
// .select()
// .from(clientSites)
// .where(eq(clientSites.siteId, newt.siteId));
// THE NEWT IS NOT SENDING RAW WG TO THE GERBIL SO IDK IF WE REALLY NEED THIS - REMOVING
// Get client details for each client
// for (const pair of sitesClientPairs) {
// const [client] = await db
// .select()
// .from(clients)
// .where(eq(clients.clientId, pair.clientId));
// if (client && client.endpoint) {
// const [host, portStr] = client.endpoint.split(':');
// if (host && portStr) {
// destinations.push({
// destinationIP: host,
// destinationPort: parseInt(portStr, 10)
// });
// }
// }
// }
// If this is a newt/site, also add other sites in the same org
// if (updatedSite.orgId) {
// const orgSites = await db
// .select()
// .from(sites)
// .where(eq(sites.orgId, updatedSite.orgId));
// for (const site of orgSites) {
// // Don't add the current site to the destinations
// if (site.siteId !== currentSiteId && site.subnet && site.endpoint && site.listenPort) {
// const [host, portStr] = site.endpoint.split(':');
// if (host && portStr) {
// destinations.push({
// destinationIP: host,
// destinationPort: site.listenPort
// });
// }
// }
// }
// }
}
return destinations;
}
@@ -408,12 +385,14 @@ async function handleSiteEndpointChange(siteId: number, newEndpoint: string) {
return;
}
// Get all non-relayed clients connected to this site
// Get all non-relayed and not jit clients connected to this site
const connectedClients = await db
.select({
online: clients.online,
clientId: clients.clientId,
olmId: olms.olmId,
isRelayed: clientSitesAssociationsCache.isRelayed
isRelayed: clientSitesAssociationsCache.isRelayed,
isJitMode: clientSitesAssociationsCache.isJitMode
})
.from(clientSitesAssociationsCache)
.innerJoin(
@@ -423,32 +402,36 @@ async function handleSiteEndpointChange(siteId: number, newEndpoint: string) {
.innerJoin(olms, eq(olms.clientId, clients.clientId))
.where(
and(
eq(clients.online, true), // the client has to be online or it does not matter...
eq(clientSitesAssociationsCache.siteId, siteId),
eq(clientSitesAssociationsCache.isRelayed, false)
eq(clientSitesAssociationsCache.isRelayed, false),
eq(clientSitesAssociationsCache.isJitMode, false)
)
);
// Update each non-relayed client with the new site endpoint
for (const client of connectedClients) {
try {
await updateOlmPeer(
client.clientId,
{
siteId: siteId,
publicKey: site.publicKey,
endpoint: newEndpoint
},
client.olmId
);
logger.debug(
`Updated client ${client.clientId} with new site ${siteId} endpoint: ${newEndpoint}`
);
} catch (error) {
logger.error(
`Failed to update client ${client.clientId} with new site endpoint: ${error}`
);
}
}
// Update each non-relayed client with the new site endpoint (in parallel)
await Promise.allSettled(
connectedClients.map(async (client) => {
try {
await updateOlmPeer(
client.clientId,
{
siteId: siteId,
publicKey: site.publicKey!,
endpoint: newEndpoint
},
client.olmId
);
logger.debug(
`Updated client ${client.clientId} with new site ${siteId} endpoint: ${newEndpoint}`
);
} catch (error) {
logger.error(
`Failed to update client ${client.clientId} with new site endpoint: ${error}`
);
}
})
);
} catch (error) {
logger.error(
`Error handling site endpoint change for site ${siteId}: ${error}`
@@ -457,10 +440,11 @@ async function handleSiteEndpointChange(siteId: number, newEndpoint: string) {
}
async function handleClientEndpointChange(
sitesWithNewtsToUpdate: { siteId: number; newtId: string }[],
clientId: number,
newEndpoint: string
) {
// Alert all sites connected to this client that the endpoint has changed (only if NOT relayed)
// Alert all sites connected to this client that the endpoint has changed (only if NOT relayed and NOT JIT MODE)
try {
// Get client details
const [client] = await db
@@ -474,58 +458,42 @@ async function handleClientEndpointChange(
return;
}
// Get all non-relayed sites connected to this client
const connectedSites = await db
.select({
siteId: sites.siteId,
newtId: newts.newtId,
isRelayed: clientSitesAssociationsCache.isRelayed,
subnet: clients.subnet
})
.from(clientSitesAssociationsCache)
.innerJoin(
sites,
eq(clientSitesAssociationsCache.siteId, sites.siteId)
)
.innerJoin(newts, eq(newts.siteId, sites.siteId))
.innerJoin(
clients,
eq(clientSitesAssociationsCache.clientId, clients.clientId)
)
.where(
and(
eq(clientSitesAssociationsCache.clientId, clientId),
eq(clientSitesAssociationsCache.isRelayed, false)
)
if (sitesWithNewtsToUpdate.length > 250) {
logger.warn(
`Client ${clientId} has ${sitesWithNewtsToUpdate.length} connected sites so the client will be in jit mode anyway, skipping endpoint updates`
);
return;
}
// Update each non-relayed site with the new client endpoint
for (const siteData of connectedSites) {
try {
if (!siteData.subnet) {
// Update each non-relayed site with the new client endpoint (in parallel)
await Promise.allSettled(
sitesWithNewtsToUpdate.map(async ({ siteId, newtId }) => {
if (!client.pubKey) {
logger.warn(
`Client ${clientId} has no subnet, skipping update for site ${siteData.siteId}`
`Client ${clientId} has no public key, skipping update for site ${siteId}`
);
continue;
return;
}
await updateNewtPeer(
siteData.siteId,
client.pubKey,
{
endpoint: newEndpoint
},
siteData.newtId
);
logger.debug(
`Updated site ${siteData.siteId} with new client ${clientId} endpoint: ${newEndpoint}`
);
} catch (error) {
logger.error(
`Failed to update site ${siteData.siteId} with new client endpoint: ${error}`
);
}
}
try {
await updateNewtPeer(
siteId,
client.pubKey,
{
endpoint: newEndpoint
},
newtId
);
logger.debug(
`Updated site ${siteId} with new client ${clientId} endpoint: ${newEndpoint}`
);
} catch (error) {
logger.error(
`Failed to update site ${siteId} with new client endpoint: ${error}`
);
}
})
);
} catch (error) {
logger.error(
`Error handling client endpoint change for client ${clientId}: ${error}`

View File

@@ -5,6 +5,7 @@ import {
db,
exitNodes,
networks,
SiteResource,
siteNetworks,
siteResources,
sites
@@ -15,7 +16,7 @@ import {
generateRemoteSubnets
} from "@server/lib/ip";
import logger from "@server/logger";
import { and, eq } from "drizzle-orm";
import { eq, inArray } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import config from "@server/lib/config";
@@ -27,11 +28,11 @@ export async function buildSiteConfigurationForOlmClient(
) {
const siteConfigurations: {
siteId: number;
name?: string
endpoint?: string
publicKey?: string
serverIP?: string | null
serverPort?: number | null
name?: string;
endpoint?: string;
publicKey?: string;
serverIP?: string | null;
serverPort?: number | null;
remoteSubnets?: string[];
aliases: Alias[];
}[] = [];
@@ -46,50 +47,79 @@ export async function buildSiteConfigurationForOlmClient(
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
if (sitesData.length === 0) {
return siteConfigurations;
}
// Batch-fetch every site resource this client has access to across ALL sites
// in a single query, then group by siteId in memory. This avoids issuing one
// query per site (which would be N round-trips for N sites).
const allClientSiteResources = await db
.select({
siteResource: siteResources,
siteId: siteNetworks.siteId
})
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.innerJoin(networks, eq(siteResources.networkId, networks.networkId))
.innerJoin(siteNetworks, eq(networks.networkId, siteNetworks.networkId))
.where(
eq(clientSiteResourcesAssociationsCache.clientId, client.clientId)
);
const siteResourcesBySiteId = new Map<number, SiteResource[]>();
for (const row of allClientSiteResources) {
const arr = siteResourcesBySiteId.get(row.siteId);
if (arr) {
arr.push(row.siteResource);
} else {
siteResourcesBySiteId.set(row.siteId, [row.siteResource]);
}
}
// Batch-fetch exit nodes for all sites in one query (only needed in relay mode).
const exitNodesById = new Map<number, typeof exitNodes.$inferSelect>();
if (!jitMode && relay) {
const exitNodeIds = Array.from(
new Set(
sitesData
.map(({ sites: s }) => s.exitNodeId)
.filter((id): id is number => id != null)
)
);
if (exitNodeIds.length > 0) {
const nodes = await db
.select()
.from(exitNodes)
.where(inArray(exitNodes.exitNodeId, exitNodeIds));
for (const n of nodes) {
exitNodesById.set(n.exitNodeId, n);
}
}
}
const clientsStartPort = config.getRawConfig().gerbil.clients_start_port;
const peerOps: Promise<unknown>[] = [];
// Process each site
for (const {
sites: site,
clientSitesAssociationsCache: association
} of sitesData) {
const allSiteResources = await db // only get the site resources that this client has access to
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.innerJoin(
networks,
eq(siteResources.networkId, networks.networkId)
)
.innerJoin(
siteNetworks,
eq(networks.networkId, siteNetworks.networkId)
)
.where(
and(
eq(siteNetworks.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
const allSiteResources = siteResourcesBySiteId.get(site.siteId) ?? [];
if (jitMode) {
// Add site configuration to the array
siteConfigurations.push({
siteId: site.siteId,
// remoteSubnets: generateRemoteSubnets(
// allSiteResources.map(({ siteResources }) => siteResources)
// ),
aliases: generateAliasConfig(
allSiteResources.map(({ siteResources }) => siteResources)
)
// remoteSubnets: generateRemoteSubnets(allSiteResources),
aliases: generateAliasConfig(allSiteResources)
});
continue;
}
@@ -109,10 +139,9 @@ export async function buildSiteConfigurationForOlmClient(
continue;
}
if (!site.publicKey || site.publicKey == "") { // the site is not ready to accept new peers
logger.warn(
`Site ${site.siteId} has no public key, skipping`
);
if (!site.publicKey || site.publicKey == "") {
// the site is not ready to accept new peers
logger.warn(`Site ${site.siteId} has no public key, skipping`);
continue;
}
@@ -128,7 +157,7 @@ export async function buildSiteConfigurationForOlmClient(
logger.info(
`Public key mismatch. Deleting old peer from site ${site.siteId}...`
);
await deletePeer(site.siteId, client.pubKey!);
peerOps.push(deletePeer(site.siteId, client.pubKey!));
}
if (!site.subnet) {
@@ -136,27 +165,19 @@ export async function buildSiteConfigurationForOlmClient(
continue;
}
const [clientSite] = await db
.select()
.from(clientSitesAssociationsCache)
.where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
)
.limit(1);
// Add the peer to the exit node for this site
if (clientSite.endpoint && publicKey) {
// Add the peer to the exit node for this site. The endpoint comes from
// the already-joined association row above, so no extra query needed.
if (association.endpoint && publicKey) {
logger.info(
`Adding peer ${publicKey} to site ${site.siteId} with endpoint ${clientSite.endpoint}`
`Adding peer ${publicKey} to site ${site.siteId} with endpoint ${association.endpoint}`
);
peerOps.push(
addPeer(site.siteId, {
publicKey: publicKey,
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
endpoint: relay ? "" : association.endpoint
})
);
await addPeer(site.siteId, {
publicKey: publicKey,
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
endpoint: relay ? "" : clientSite.endpoint
});
} else {
logger.warn(
`Client ${client.clientId} has no endpoint, skipping peer addition`
@@ -165,16 +186,12 @@ export async function buildSiteConfigurationForOlmClient(
let relayEndpoint: string | undefined = undefined;
if (relay) {
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1);
const exitNode = exitNodesById.get(site.exitNodeId);
if (!exitNode) {
logger.warn(`Exit node not found for site ${site.siteId}`);
continue;
}
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
relayEndpoint = `${exitNode.endpoint}:${clientsStartPort}`;
}
// Add site configuration to the array
@@ -186,12 +203,16 @@ export async function buildSiteConfigurationForOlmClient(
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort,
remoteSubnets: generateRemoteSubnets(
allSiteResources.map(({ siteResources }) => siteResources)
),
aliases: generateAliasConfig(
allSiteResources.map(({ siteResources }) => siteResources)
)
remoteSubnets: generateRemoteSubnets(allSiteResources),
aliases: generateAliasConfig(allSiteResources)
});
}
// Run all peer add/delete operations concurrently rather than serially per
// site, so total time is bounded by the slowest call instead of the sum.
if (peerOps.length > 0) {
Promise.allSettled(peerOps).catch((err) => {
logger.error("Error processing peer operations: ", err);
});
}

View File

@@ -8,7 +8,7 @@ import {
ExitNode,
exitNodes,
sites,
clientSitesAssociationsCache,
clientSitesAssociationsCache
} from "@server/db";
import { olms } from "@server/db";
import HttpCode from "@server/types/HttpCode";
@@ -28,6 +28,7 @@ import { verifyPassword } from "@server/auth/password";
import logger from "@server/logger";
import config from "@server/lib/config";
import { APP_VERSION } from "@server/lib/consts";
import { build } from "@server/build";
export const olmGetTokenBodySchema = z.object({
olmId: z.string(),
@@ -220,6 +221,22 @@ export async function getOlmToken(
)
.where(eq(clientSitesAssociationsCache.clientId, clientIdToUse!));
if (clientSites.length > 250 && build == "saas") {
// set all of the cache rows isJitMode to true
await db
.update(clientSitesAssociationsCache)
.set({ isJitMode: true })
.where(
and(
eq(
clientSitesAssociationsCache.clientId,
clientIdToUse!
),
eq(clientSitesAssociationsCache.isJitMode, false)
)
);
}
// Extract unique exit node IDs
const exitNodeIds = Array.from(
new Set(

View File

@@ -1,4 +1,4 @@
import { db, orgs } from "@server/db";
import { db, orgs, primaryDb } from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import {
clients,
@@ -7,7 +7,7 @@ import {
olms,
sites
} from "@server/db";
import { count, eq } from "drizzle-orm";
import { and, count, eq, ne, or } from "drizzle-orm";
import logger from "@server/logger";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { validateSessionToken } from "@server/auth/sessions/app";
@@ -81,7 +81,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
.where(eq(olms.olmId, olm.olmId));
}
const [client] = await db
const [client] = await primaryDb // read from the primary here so there is no latency with the last update on the holepunch
.select()
.from(clients)
.where(eq(clients.clientId, olm.clientId))
@@ -98,7 +98,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (client.blocked) {
logger.debug(
`[handleOlmRegisterMessage] Client ${client.clientId} is blocked. Ignoring register.`,
{ orgId: client.orgId }
{ orgId: client.orgId, clientId: client.clientId }
);
sendOlmError(OlmErrorCodes.CLIENT_BLOCKED, olm.olmId);
return;
@@ -107,7 +107,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (client.approvalState == "pending") {
logger.debug(
`[handleOlmRegisterMessage] Client ${client.clientId} approval is pending. Ignoring register.`,
{ orgId: client.orgId }
{ orgId: client.orgId, clientId: client.clientId }
);
sendOlmError(OlmErrorCodes.CLIENT_PENDING, olm.olmId);
return;
@@ -136,7 +136,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (!org) {
logger.warn("[handleOlmRegisterMessage] Org not found", {
orgId: client.orgId
orgId: client.orgId,
clientId: client.clientId
});
sendOlmError(OlmErrorCodes.ORG_NOT_FOUND, olm.olmId);
return;
@@ -145,7 +146,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (orgId) {
if (!olm.userId) {
logger.warn("[handleOlmRegisterMessage] Olm has no user ID", {
orgId: client.orgId
orgId: client.orgId,
clientId: client.clientId
});
sendOlmError(OlmErrorCodes.USER_ID_NOT_FOUND, olm.olmId);
return;
@@ -156,7 +158,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (!userSession || !user) {
logger.warn(
"[handleOlmRegisterMessage] Invalid user session for olm register",
{ orgId: client.orgId }
{ orgId: client.orgId, clientId: client.clientId }
);
sendOlmError(OlmErrorCodes.INVALID_USER_SESSION, olm.olmId);
return;
@@ -164,7 +166,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (user.userId !== olm.userId) {
logger.warn(
"[handleOlmRegisterMessage] User ID mismatch for olm register",
{ orgId: client.orgId }
{ orgId: client.orgId, clientId: client.clientId }
);
sendOlmError(OlmErrorCodes.USER_ID_MISMATCH, olm.olmId);
return;
@@ -182,13 +184,14 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.debug("[handleOlmRegisterMessage] Policy check result", {
orgId: client.orgId,
clientId: client.clientId,
policyCheck
});
if (policyCheck?.error) {
logger.error(
`[handleOlmRegisterMessage] Error checking access policies for olm user ${olm.userId} in org ${orgId}: ${policyCheck?.error}`,
{ orgId: client.orgId }
{ orgId: client.orgId, clientId: client.clientId }
);
sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId);
return;
@@ -197,7 +200,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (policyCheck.policies?.passwordAge?.compliant === false) {
logger.warn(
`[handleOlmRegisterMessage] Olm user ${olm.userId} has non-compliant password age for org ${orgId}`,
{ orgId: client.orgId }
{ orgId: client.orgId, clientId: client.clientId }
);
sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_PASSWORD_EXPIRED,
@@ -209,7 +212,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
) {
logger.warn(
`[handleOlmRegisterMessage] Olm user ${olm.userId} has non-compliant session length for org ${orgId}`,
{ orgId: client.orgId }
{ orgId: client.orgId, clientId: client.clientId }
);
sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_SESSION_EXPIRED,
@@ -219,7 +222,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
} else if (policyCheck.policies?.requiredTwoFactor === false) {
logger.warn(
`[handleOlmRegisterMessage] Olm user ${olm.userId} does not have 2FA enabled for org ${orgId}`,
{ orgId: client.orgId }
{ orgId: client.orgId, clientId: client.clientId }
);
sendOlmError(
OlmErrorCodes.ORG_ACCESS_POLICY_2FA_REQUIRED,
@@ -229,7 +232,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
} else if (!policyCheck.allowed) {
logger.warn(
`[handleOlmRegisterMessage] Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}`,
{ orgId: client.orgId }
{ orgId: client.orgId, clientId: client.clientId }
);
sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId);
return;
@@ -253,7 +256,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
// Prepare an array to store site configurations
logger.debug(
`[handleOlmRegisterMessage] Found ${sitesCount} sites for client ${client.clientId}`,
{ orgId: client.orgId }
{ orgId: client.orgId, clientId: client.clientId }
);
let jitMode = false;
@@ -263,19 +266,20 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
// If we have too many sites we need to drop into fully JIT mode by not sending any of the sites
logger.info(
`[handleOlmRegisterMessage] Too many sites (${sitesCount}), dropping into JIT mode`,
{ orgId: client.orgId }
{ orgId: client.orgId, clientId: client.clientId }
);
jitMode = true;
}
logger.debug(
`[handleOlmRegisterMessage] Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`,
{ orgId: client.orgId }
{ orgId: client.orgId, clientId: client.clientId }
);
if (!publicKey) {
logger.warn("[handleOlmRegisterMessage] Public key not provided", {
orgId: client.orgId
orgId: client.orgId,
clientId: client.clientId
});
return;
}
@@ -283,7 +287,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (client.pubKey !== publicKey || client.archived) {
logger.info(
"[handleOlmRegisterMessage] Public key mismatch. Updating public key and clearing session info...",
{ orgId: client.orgId }
{ orgId: client.orgId, clientId: client.clientId }
);
// Update the client's public key
await db
@@ -301,7 +305,18 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
isRelayed: relay == true,
isJitMode: jitMode
})
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
.where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
or(
ne(
clientSitesAssociationsCache.isRelayed,
relay == true
),
ne(clientSitesAssociationsCache.isJitMode, jitMode)
)
)
);
}
// this prevents us from accepting a register from an olm that has not hole punched yet.
@@ -310,7 +325,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) {
logger.warn(
`[handleOlmRegisterMessage] Client last hole punch is too old and we have sites to send; skipping this register. The client is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?`,
{ orgId: client.orgId }
{ orgId: client.orgId, clientId: client.clientId }
);
return;
}

View File

@@ -17,7 +17,7 @@ import { initPeerAddHandshake } from "./peers";
export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
context
) => {
logger.info("Handling register olm message!");
logger.info("Handle Olm Server Init Add Peer Handshake Message");
const { message, client: c, sendToClient } = context;
const olm = c as Olm;

View File

@@ -9,16 +9,50 @@ import {
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger";
import { eq, inArray } from "drizzle-orm";
import { count, eq, inArray } from "drizzle-orm";
import config from "@server/lib/config";
import { canCompress } from "@server/lib/clientVersionChecks";
import { build } from "@server/build";
export async function sendOlmSyncMessage(olm: Olm, client: Client) {
// Get all sites data
const sitesCountResult = await db
.select({ count: count() })
.from(sites)
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract the count value from the result array
const sitesCount =
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations
logger.debug(
`[handleOlmRegisterMessage] Found ${sitesCount} sites for client ${client.clientId}`,
{ orgId: client.orgId }
);
let jitMode = false;
if (sitesCount > 250 && build == "saas") {
// THIS IS THE MAX ON THE BUSINESS TIER
// we have too many sites
// If we have too many sites we need to drop into fully JIT mode by not sending any of the sites
logger.info(
`[handleOlmRegisterMessage] Too many sites (${sitesCount}), dropping into JIT mode`,
{ orgId: client.orgId }
);
jitMode = true;
}
// NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT
const siteConfigurations = await buildSiteConfigurationForOlmClient(
client,
client.pubKey,
false
false,
jitMode
);
// Get all exit nodes from sites where the client has peers
@@ -82,7 +116,6 @@ export async function sendOlmSyncMessage(olm: Olm, client: Client) {
exitNodes: exitNodesData
}
},
{
compress: canCompress(olm.version, "olm")
}

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express";
import { db } from "@server/db";
import { and, eq, or, inArray } from "drizzle-orm";
import { db, DB_TYPE } from "@server/db";
import { and, eq, or, inArray, sql } from "drizzle-orm";
import {
resources,
userResources,
@@ -12,7 +12,9 @@ import {
resourceWhitelist,
siteResources,
userSiteResources,
roleSiteResources
roleSiteResources,
siteNetworks,
sites
} from "@server/db";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
@@ -156,9 +158,24 @@ export async function getUserResources(
enabled: boolean;
alias: string | null;
aliasAddress: string | null;
tcpPortRangeString: string | null;
udpPortRangeString: string | null;
disableIcmp: boolean | null;
siteIds: number[];
siteNames: string[];
siteNiceIds: string[];
siteAddresses: (string | null)[];
siteOnlines: boolean[];
}> = [];
if (accessibleSiteResourceIds.length > 0) {
siteResourcesData = await db
const aggCol = <T>(column: any) => {
if (DB_TYPE === "sqlite") {
return sql<T>`json_group_array(${column})`;
}
return sql<T>`COALESCE(array_agg(${column}) FILTER (WHERE ${sites.siteId} IS NOT NULL), '{}')`;
};
const siteResourcesRaw = await db
.select({
siteResourceId: siteResources.siteResourceId,
name: siteResources.name,
@@ -170,9 +187,22 @@ export async function getUserResources(
fullDomain: siteResources.fullDomain,
enabled: siteResources.enabled,
alias: siteResources.alias,
aliasAddress: siteResources.aliasAddress
aliasAddress: siteResources.aliasAddress,
tcpPortRangeString: siteResources.tcpPortRangeString,
udpPortRangeString: siteResources.udpPortRangeString,
disableIcmp: siteResources.disableIcmp,
siteIds: aggCol<number[]>(sites.siteId),
siteNames: aggCol<string[]>(sites.name),
siteNiceIds: aggCol<string[]>(sites.niceId),
siteAddresses: aggCol<(string | null)[]>(sites.address),
siteOnlines: aggCol<boolean[]>(sites.online)
})
.from(siteResources)
.leftJoin(
siteNetworks,
eq(siteResources.networkId, siteNetworks.networkId)
)
.leftJoin(sites, eq(siteNetworks.siteId, sites.siteId))
.where(
and(
inArray(
@@ -182,7 +212,55 @@ export async function getUserResources(
eq(siteResources.orgId, orgId),
eq(siteResources.enabled, true)
)
);
)
.groupBy(siteResources.siteResourceId);
siteResourcesData = siteResourcesRaw.map((row: any) => {
if (DB_TYPE !== "sqlite") {
return row;
}
const siteIdsRaw = JSON.parse(row.siteIds) as (number | null)[];
const siteNamesRaw = JSON.parse(row.siteNames) as (
| string
| null
)[];
const siteNiceIdsRaw = JSON.parse(row.siteNiceIds) as (
| string
| null
)[];
const siteAddressesRaw = JSON.parse(row.siteAddresses) as (
| string
| null
)[];
const siteOnlinesRaw = JSON.parse(row.siteOnlines) as (
| 0
| 1
| null
)[];
const siteIds: number[] = [];
const siteNames: string[] = [];
const siteNiceIds: string[] = [];
const siteAddresses: (string | null)[] = [];
const siteOnlines: boolean[] = [];
for (let i = 0; i < siteIdsRaw.length; i++) {
if (siteIdsRaw[i] == null) continue;
siteIds.push(siteIdsRaw[i] as number);
siteNames.push((siteNamesRaw[i] ?? "") as string);
siteNiceIds.push((siteNiceIdsRaw[i] ?? "") as string);
siteAddresses.push(siteAddressesRaw[i] ?? null);
siteOnlines.push(siteOnlinesRaw[i] === 1);
}
return {
...row,
siteIds,
siteNames,
siteNiceIds,
siteAddresses,
siteOnlines
};
});
}
// Check for password, pincode, and whitelist protection for each resource
@@ -260,6 +338,14 @@ export async function getUserResources(
enabled: siteResource.enabled,
alias: siteResource.alias,
aliasAddress: siteResource.aliasAddress,
tcpPortRangeString: siteResource.tcpPortRangeString,
udpPortRangeString: siteResource.udpPortRangeString,
disableIcmp: siteResource.disableIcmp,
siteIds: siteResource.siteIds,
siteNames: siteResource.siteNames,
siteNiceIds: siteResource.siteNiceIds,
siteAddresses: siteResource.siteAddresses,
siteOnlines: siteResource.siteOnlines,
type: "site" as const
};
});
@@ -302,11 +388,19 @@ export type GetUserResourcesResponse = {
destination: string;
mode: string;
protocol: string | null;
tcpPortRangeString: string | null;
udpPortRangeString: string | null;
disableIcmp: boolean | null;
ssl: boolean;
fullDomain: string | null;
enabled: boolean;
alias: string | null;
aliasAddress: string | null;
siteIds: number[];
siteNames: string[];
siteNiceIds: string[];
siteAddresses: (string | null)[];
siteOnlines: boolean[];
type: "site";
}>;
};