Merge pull request #3073 from fosrl/dev

Further optimizations
This commit is contained in:
Owen Schwartz
2026-05-14 12:00:25 -07:00
committed by GitHub
3 changed files with 62 additions and 27 deletions

View File

@@ -411,12 +411,14 @@ async function handleSiteEndpointChange(siteId: number, newEndpoint: string) {
return; return;
} }
// Get all non-relayed clients connected to this site // Get all non-relayed and not jit clients connected to this site
const connectedClients = await db const connectedClients = await db
.select({ .select({
online: clients.online,
clientId: clients.clientId, clientId: clients.clientId,
olmId: olms.olmId, olmId: olms.olmId,
isRelayed: clientSitesAssociationsCache.isRelayed isRelayed: clientSitesAssociationsCache.isRelayed,
isJitMode: clientSitesAssociationsCache.isJitMode
}) })
.from(clientSitesAssociationsCache) .from(clientSitesAssociationsCache)
.innerJoin( .innerJoin(
@@ -426,32 +428,36 @@ async function handleSiteEndpointChange(siteId: number, newEndpoint: string) {
.innerJoin(olms, eq(olms.clientId, clients.clientId)) .innerJoin(olms, eq(olms.clientId, clients.clientId))
.where( .where(
and( and(
eq(clients.online, true), // the client has to be online or it does not matter...
eq(clientSitesAssociationsCache.siteId, siteId), eq(clientSitesAssociationsCache.siteId, siteId),
eq(clientSitesAssociationsCache.isRelayed, false) eq(clientSitesAssociationsCache.isRelayed, false),
eq(clientSitesAssociationsCache.isJitMode, false)
) )
); );
// Update each non-relayed client with the new site endpoint // Update each non-relayed client with the new site endpoint (in parallel)
for (const client of connectedClients) { await Promise.allSettled(
try { connectedClients.map(async (client) => {
await updateOlmPeer( try {
client.clientId, await updateOlmPeer(
{ client.clientId,
siteId: siteId, {
publicKey: site.publicKey, siteId: siteId,
endpoint: newEndpoint publicKey: site.publicKey!,
}, endpoint: newEndpoint
client.olmId },
); client.olmId
logger.debug( );
`Updated client ${client.clientId} with new site ${siteId} endpoint: ${newEndpoint}` logger.debug(
); `Updated client ${client.clientId} with new site ${siteId} endpoint: ${newEndpoint}`
} catch (error) { );
logger.error( } catch (error) {
`Failed to update client ${client.clientId} with new site endpoint: ${error}` logger.error(
); `Failed to update client ${client.clientId} with new site endpoint: ${error}`
} );
} }
})
);
} catch (error) { } catch (error) {
logger.error( logger.error(
`Error handling site endpoint change for site ${siteId}: ${error}` `Error handling site endpoint change for site ${siteId}: ${error}`
@@ -498,6 +504,7 @@ async function handleClientEndpointChange( // TODO: I THINK WE DONT NEED TO HIT
) )
.where( .where(
and( and(
eq(sites.online, true), // the site has to be online or it does not matter...
eq(clientSitesAssociationsCache.clientId, clientId), eq(clientSitesAssociationsCache.clientId, clientId),
eq(clientSitesAssociationsCache.isRelayed, false), eq(clientSitesAssociationsCache.isRelayed, false),
eq(clientSitesAssociationsCache.isJitMode, false) eq(clientSitesAssociationsCache.isJitMode, false)

View File

@@ -8,7 +8,7 @@ import {
ExitNode, ExitNode,
exitNodes, exitNodes,
sites, sites,
clientSitesAssociationsCache, clientSitesAssociationsCache
} from "@server/db"; } from "@server/db";
import { olms } from "@server/db"; import { olms } from "@server/db";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -28,6 +28,7 @@ import { verifyPassword } from "@server/auth/password";
import logger from "@server/logger"; import logger from "@server/logger";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { APP_VERSION } from "@server/lib/consts"; import { APP_VERSION } from "@server/lib/consts";
import { build } from "@server/build";
export const olmGetTokenBodySchema = z.object({ export const olmGetTokenBodySchema = z.object({
olmId: z.string(), olmId: z.string(),
@@ -220,6 +221,22 @@ export async function getOlmToken(
) )
.where(eq(clientSitesAssociationsCache.clientId, clientIdToUse!)); .where(eq(clientSitesAssociationsCache.clientId, clientIdToUse!));
if (clientSites.length > 250 && build == "saas") {
// set all of the cache rows isJitMode to true
await db
.update(clientSitesAssociationsCache)
.set({ isJitMode: true })
.where(
and(
eq(
clientSitesAssociationsCache.clientId,
clientIdToUse!
),
eq(clientSitesAssociationsCache.isJitMode, false)
)
);
}
// Extract unique exit node IDs // Extract unique exit node IDs
const exitNodeIds = Array.from( const exitNodeIds = Array.from(
new Set( new Set(

View File

@@ -7,7 +7,7 @@ import {
olms, olms,
sites sites
} from "@server/db"; } from "@server/db";
import { count, eq } from "drizzle-orm"; import { and, count, eq, ne, or } from "drizzle-orm";
import logger from "@server/logger"; import logger from "@server/logger";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy"; import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { validateSessionToken } from "@server/auth/sessions/app"; import { validateSessionToken } from "@server/auth/sessions/app";
@@ -301,7 +301,18 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
isRelayed: relay == true, isRelayed: relay == true,
isJitMode: jitMode isJitMode: jitMode
}) })
.where(eq(clientSitesAssociationsCache.clientId, client.clientId)); .where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
or(
ne(
clientSitesAssociationsCache.isRelayed,
relay == true
),
ne(clientSitesAssociationsCache.isJitMode, jitMode)
)
)
);
} }
// this prevents us from accepting a register from an olm that has not hole punched yet. // this prevents us from accepting a register from an olm that has not hole punched yet.