diff --git a/server/routers/olm/getOlmToken.ts b/server/routers/olm/getOlmToken.ts index 51cdc8b4e..f7fdb81a8 100644 --- a/server/routers/olm/getOlmToken.ts +++ b/server/routers/olm/getOlmToken.ts @@ -13,7 +13,7 @@ import { import { olms } from "@server/db"; import HttpCode from "@server/types/HttpCode"; import response from "@server/lib/response"; -import { and, eq, inArray } from "drizzle-orm"; +import { and, count, eq, inArray } from "drizzle-orm"; import { NextFunction, Request, Response } from "express"; import createHttpError from "http-errors"; import { z } from "zod"; @@ -24,6 +24,7 @@ import { EXPIRES } from "@server/auth/sessions/olm"; import { getOrCreateCachedToken } from "#dynamic/lib/tokenCache"; +import { listExitNodes } from "#dynamic/lib/exitNodes"; import { verifyPassword } from "@server/auth/password"; import logger from "@server/logger"; import config from "@server/lib/config"; @@ -150,6 +151,7 @@ export async function getOlmToken( ); let clientIdToUse; + let orgIdToUse: string; if (orgId) { // we did provide the org const [client] = await db @@ -183,6 +185,7 @@ export async function getOlmToken( } clientIdToUse = client.clientId; + orgIdToUse = orgId; } else { if (!existingOlm.clientId) { return next( @@ -209,6 +212,7 @@ export async function getOlmToken( } clientIdToUse = client.clientId; + orgIdToUse = client.orgId; } // Get all exit nodes from sites where the client has peers @@ -265,7 +269,7 @@ export async function getOlmToken( } } - const exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => { + let exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => { return { publicKey: exitNode.publicKey, relayPort: config.getRawConfig().gerbil.clients_start_port, @@ -274,6 +278,73 @@ export async function getOlmToken( }; }); + // If no exit nodes were found for the client's sites, fall back to + // finding an available node in the same region (as newt does on ping). + if (exitNodesHpData.length === 0) { + logger.debug( + `No exit nodes found for olm ${olmId} client sites; falling back to region node selection` + ); + const fallbackNodes = await listExitNodes(orgIdToUse!, true); + + const weightedNodes = await Promise.all( + fallbackNodes.map(async (node) => { + let weight = 1; + const maxConnections = node.maxConnections; + if ( + maxConnections !== null && + maxConnections !== undefined + ) { + const [currentConnections] = await db + .select({ count: count() }) + .from(sites) + .where( + and( + eq(sites.exitNodeId, node.exitNodeId), + eq(sites.online, true) + ) + ); + if (currentConnections.count >= maxConnections) { + return null; + } + weight = + (maxConnections - currentConnections.count) / + maxConnections; + } + return { node, weight }; + }) + ); + + const availableNodes = weightedNodes + .filter( + ( + n + ): n is { + node: (typeof fallbackNodes)[0]; + weight: number; + } => n !== null + ) + .sort((a, b) => b.weight - a.weight); + + if (availableNodes.length > 0) { + const best = availableNodes[0].node; + exitNodesHpData = [ + { + publicKey: best.publicKey, + relayPort: + config.getRawConfig().gerbil.clients_start_port, + endpoint: best.endpoint, + siteIds: [] + // it should still HP without the site ids but it will get stuck in the client + // if a site is removed or something because its not tied to a site which is okay for the session + } + ]; + } else { + logger.warn( + `No available fallback exit nodes found for olm ${olmId}` + ); + } + } + logger.debug("Token created successfully"); return response<{