diff --git a/server/routers/newt/handleGetConfigMessage.ts b/server/routers/newt/handleGetConfigMessage.ts index bfe14ec5..927fed50 100644 --- a/server/routers/newt/handleGetConfigMessage.ts +++ b/server/routers/newt/handleGetConfigMessage.ts @@ -7,7 +7,8 @@ import { ExitNode, exitNodes, siteResources, - clientSiteResourcesAssociationsCache + clientSiteResourcesAssociationsCache, + Site } from "@server/db"; import { clients, clientSitesAssociationsCache, Newt, sites } from "@server/db"; import { eq } from "drizzle-orm"; @@ -130,6 +131,38 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { } } + const { peers, targets } = await buildClientConfigurationForNewtClient( + site, + exitNode + ); + + // Build the configuration response + const configResponse = { + ipAddress: site.address, + peers, + targets + }; + + logger.debug("Sending config: ", configResponse); + + return { + message: { + type: "newt/wg/receive-config", + data: { + ...configResponse + } + }, + broadcast: false, + excludeSender: false, + }; +}; + +export async function buildClientConfigurationForNewtClient( + site: Site, + exitNode?: ExitNode +) { + const siteId = site.siteId; + // Get all clients connected to this site const clientsRes = await db .select() @@ -278,22 +311,8 @@ export const handleGetConfigMessage: MessageHandler = async (context) => { targetsToSend.push(...resourceTargets); } - // Build the configuration response - const configResponse = { - ipAddress: site.address, + return { peers: validPeers, targets: targetsToSend }; - - logger.debug("Sending config: ", configResponse); - return { - message: { - type: "newt/wg/receive-config", - data: { - ...configResponse - } - }, - broadcast: false, - excludeSender: false - }; -}; +} diff --git a/server/routers/newt/handleNewtPingMessage.ts b/server/routers/newt/handleNewtPingMessage.ts new file mode 100644 index 00000000..e7dea7ce --- /dev/null +++ b/server/routers/newt/handleNewtPingMessage.ts @@ -0,0 +1,141 @@ +import { db } from "@server/db"; +import { disconnectClient } from "#dynamic/routers/ws"; +import { getClientConfigVersion, MessageHandler } from "@server/routers/ws"; +import { clients, Newt } from "@server/db"; +import { eq, lt, isNull, and, or } from "drizzle-orm"; +import logger from "@server/logger"; +import { validateSessionToken } from "@server/auth/sessions/app"; +import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy"; +import { sendTerminateClient } from "../client/terminate"; +import { encodeHexLowerCase } from "@oslojs/encoding"; +import { sha256 } from "@oslojs/crypto/sha2"; + +// Track if the offline checker interval is running +// let offlineCheckerInterval: NodeJS.Timeout | null = null; +// const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds +// const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes + +/** + * Starts the background interval that checks for clients that haven't pinged recently + * and marks them as offline + */ +// export const startNewtOfflineChecker = (): void => { +// if (offlineCheckerInterval) { +// return; // Already running +// } + +// offlineCheckerInterval = setInterval(async () => { +// try { +// const twoMinutesAgo = Math.floor( +// (Date.now() - OFFLINE_THRESHOLD_MS) / 1000 +// ); + +// // TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING + +// // Find clients that haven't pinged in the last 2 minutes and mark them as offline +// const offlineClients = await db +// .update(clients) +// .set({ online: false }) +// .where( +// and( +// eq(clients.online, true), +// or( +// lt(clients.lastPing, twoMinutesAgo), +// isNull(clients.lastPing) +// ) +// ) +// ) +// .returning(); + +// for (const offlineClient of offlineClients) { +// logger.info( +// `Kicking offline newt client ${offlineClient.clientId} due to inactivity` +// ); + +// if (!offlineClient.newtId) { +// logger.warn( +// `Offline client ${offlineClient.clientId} has no newtId, cannot disconnect` +// ); +// continue; +// } + +// // Send a disconnect message to the client if connected +// try { +// await sendTerminateClient( +// offlineClient.clientId, +// offlineClient.newtId +// ); // terminate first +// // wait a moment to ensure the message is sent +// await new Promise((resolve) => setTimeout(resolve, 1000)); +// await disconnectClient(offlineClient.newtId); +// } catch (error) { +// logger.error( +// `Error sending disconnect to offline newt ${offlineClient.clientId}`, +// { error } +// ); +// } +// } +// } catch (error) { +// logger.error("Error in offline checker interval", { error }); +// } +// }, OFFLINE_CHECK_INTERVAL); + +// logger.debug("Started offline checker interval"); +// }; + +/** + * Stops the background interval that checks for offline clients + */ +// export const stopNewtOfflineChecker = (): void => { +// if (offlineCheckerInterval) { +// clearInterval(offlineCheckerInterval); +// offlineCheckerInterval = null; +// logger.info("Stopped offline checker interval"); +// } +// }; + +/** + * Handles ping messages from clients and responds with pong + */ +export const handleNewtPingMessage: MessageHandler = async (context) => { + const { message, client: c, sendToClient } = context; + const newt = c as Newt; + + if (!newt) { + logger.warn("Newt not found"); + return; + } + + // get the version + const configVersion = await getClientConfigVersion(newt.newtId); + + if (message.configVersion && configVersion != message.configVersion) { + logger.warn(`Newt ping with outdated config version: ${message.configVersion} (current: ${configVersion})`); + + // TODO: sync the client + } + + // try { + // // Update the client's last ping timestamp + // await db + // .update(clients) + // .set({ + // lastPing: Math.floor(Date.now() / 1000), + // online: true + // }) + // .where(eq(clients.clientId, newt.clientId)); + // } catch (error) { + // logger.error("Error handling ping message", { error }); + // } + + return { + message: { + type: "pong", + data: { + timestamp: new Date().toISOString() + } + }, + broadcast: false, + excludeSender: false + }; +}; diff --git a/server/routers/newt/handleNewtRegisterMessage.ts b/server/routers/newt/handleNewtRegisterMessage.ts index c7f2131e..28f6e64a 100644 --- a/server/routers/newt/handleNewtRegisterMessage.ts +++ b/server/routers/newt/handleNewtRegisterMessage.ts @@ -233,6 +233,35 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => { .where(eq(newts.newtId, newt.newtId)); } + const { tcpTargets, udpTargets, validHealthCheckTargets } = + await buildTargetConfigurationForNewtClient(siteId); + + logger.debug( + `Sending health check targets to newt ${newt.newtId}: ${JSON.stringify(validHealthCheckTargets)}` + ); + + return { + message: { + type: "newt/wg/connect", + data: { + endpoint: `${exitNode.endpoint}:${exitNode.listenPort}`, + relayPort: config.getRawConfig().gerbil.clients_start_port, + publicKey: exitNode.publicKey, + serverIP: exitNode.address.split("/")[0], + tunnelIP: siteSubnet.split("/")[0], + targets: { + udp: udpTargets, + tcp: tcpTargets + }, + healthCheckTargets: validHealthCheckTargets + } + }, + broadcast: false, // Send to all clients + excludeSender: false // Include sender in broadcast + }; +}; + +export async function buildTargetConfigurationForNewtClient(siteId: number) { // Get all enabled targets with their resource protocol information const allTargets = await db .select({ @@ -337,30 +366,12 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => { (target) => target !== null ); - logger.debug( - `Sending health check targets to newt ${newt.newtId}: ${JSON.stringify(validHealthCheckTargets)}` - ); - return { - message: { - type: "newt/wg/connect", - data: { - endpoint: `${exitNode.endpoint}:${exitNode.listenPort}`, - relayPort: config.getRawConfig().gerbil.clients_start_port, - publicKey: exitNode.publicKey, - serverIP: exitNode.address.split("/")[0], - tunnelIP: siteSubnet.split("/")[0], - targets: { - udp: udpTargets, - tcp: tcpTargets - }, - healthCheckTargets: validHealthCheckTargets - } - }, - broadcast: false, // Send to all clients - excludeSender: false // Include sender in broadcast + validHealthCheckTargets, + tcpTargets, + udpTargets }; -}; +} async function getUniqueSubnetForSite( exitNode: ExitNode, diff --git a/server/routers/newt/index.ts b/server/routers/newt/index.ts index 6b17f324..8ff1b61a 100644 --- a/server/routers/newt/index.ts +++ b/server/routers/newt/index.ts @@ -6,3 +6,4 @@ export * from "./handleGetConfigMessage"; export * from "./handleSocketMessages"; export * from "./handleNewtPingRequestMessage"; export * from "./handleApplyBlueprintMessage"; +export * from "./handleNewtPingMessage"; diff --git a/server/routers/olm/handleOlmPingMessage.ts b/server/routers/olm/handleOlmPingMessage.ts index 0fa490c8..46b071c9 100644 --- a/server/routers/olm/handleOlmPingMessage.ts +++ b/server/routers/olm/handleOlmPingMessage.ts @@ -1,6 +1,6 @@ import { db } from "@server/db"; import { disconnectClient } from "#dynamic/routers/ws"; -import { MessageHandler } from "@server/routers/ws"; +import { getClientConfigVersion, MessageHandler } from "@server/routers/ws"; import { clients, Olm } from "@server/db"; import { eq, lt, isNull, and, or } from "drizzle-orm"; import logger from "@server/logger"; @@ -108,6 +108,15 @@ export const handleOlmPingMessage: MessageHandler = async (context) => { return; } + // get the version + const configVersion = await getClientConfigVersion(olm.olmId); + + if (message.configVersion && configVersion != message.configVersion) { + logger.warn(`Olm ping with outdated config version: ${message.configVersion} (current: ${configVersion})`); + + // TODO: sync the client + } + if (olm.userId) { // we need to check a user token to make sure its still valid const { session: userSession, user } = diff --git a/server/routers/olm/handleOlmRegisterMessage.ts b/server/routers/olm/handleOlmRegisterMessage.ts index 0f71ee8b..a662383a 100644 --- a/server/routers/olm/handleOlmRegisterMessage.ts +++ b/server/routers/olm/handleOlmRegisterMessage.ts @@ -1,4 +1,5 @@ import { + Client, clientSiteResourcesAssociationsCache, db, orgs, @@ -13,7 +14,7 @@ import { olms, sites } from "@server/db"; -import { and, eq, inArray, isNull } from "drizzle-orm"; +import { and, count, eq, inArray, isNull } from "drizzle-orm"; import { addPeer, deletePeer } from "../newt/peers"; import logger from "@server/logger"; import { generateAliasConfig } from "@server/lib/ip"; @@ -144,6 +145,64 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); } + // Get all sites data + const sitesCountResult = await db + .select({ count: count() }) + .from(sites) + .innerJoin( + clientSitesAssociationsCache, + eq(sites.siteId, clientSitesAssociationsCache.siteId) + ) + .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); + + // Extract the count value from the result array + const sitesCount = sitesCountResult.length > 0 ? sitesCountResult[0].count : 0; + + // Prepare an array to store site configurations + logger.debug( + `Found ${sitesCount} sites for client ${client.clientId}` + ); + + // this prevents us from accepting a register from an olm that has not hole punched yet. + // the olm will pump the register so we can keep checking + // TODO: I still think there is a better way to do this rather than locking it out here but ??? + if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) { + logger.warn( + "Client last hole punch is too old and we have sites to send; skipping this register" + ); + return; + } + + const siteConfigurations = await buildSiteConfigurationForOlmClient(client, publicKey, relay); + + // REMOVED THIS SO IT CREATES THE INTERFACE AND JUST WAITS FOR THE SITES + // if (siteConfigurations.length === 0) { + // logger.warn("No valid site configurations found"); + // return; + // } + + // Return connect message with all site configurations + return { + message: { + type: "olm/wg/connect", + data: { + sites: siteConfigurations, + tunnelIP: client.subnet, + utilitySubnet: org.utilitySubnet + } + }, + broadcast: false, + excludeSender: false + }; +}; + +export async function buildSiteConfigurationForOlmClient( + client: Client, + publicKey: string, + relay: boolean +) { + const siteConfigurations = []; + // Get all sites data const sitesData = await db .select() @@ -154,22 +213,6 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { ) .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); - // Prepare an array to store site configurations - const siteConfigurations = []; - logger.debug( - `Found ${sitesData.length} sites for client ${client.clientId}` - ); - - // this prevents us from accepting a register from an olm that has not hole punched yet. - // the olm will pump the register so we can keep checking - // TODO: I still think there is a better way to do this rather than locking it out here but ??? - if (now - (client.lastHolePunch || 0) > 5 && sitesData.length > 0) { - logger.warn( - "Client last hole punch is too old and we have sites to send; skipping this register" - ); - return; - } - // Process each site for (const { sites: site, @@ -289,23 +332,5 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => { }); } - // REMOVED THIS SO IT CREATES THE INTERFACE AND JUST WAITS FOR THE SITES - // if (siteConfigurations.length === 0) { - // logger.warn("No valid site configurations found"); - // return; - // } - - // Return connect message with all site configurations - return { - message: { - type: "olm/wg/connect", - data: { - sites: siteConfigurations, - tunnelIP: client.subnet, - utilitySubnet: org.utilitySubnet - } - }, - broadcast: false, - excludeSender: false - }; -}; + return siteConfigurations; +} diff --git a/server/routers/ws/messageHandlers.ts b/server/routers/ws/messageHandlers.ts index acd1aef0..bcf0b4dc 100644 --- a/server/routers/ws/messageHandlers.ts +++ b/server/routers/ws/messageHandlers.ts @@ -5,7 +5,8 @@ import { handleDockerStatusMessage, handleDockerContainersMessage, handleNewtPingRequestMessage, - handleApplyBlueprintMessage + handleApplyBlueprintMessage, + handleNewtPingMessage } from "../newt"; import { handleOlmRegisterMessage, @@ -24,6 +25,7 @@ export const messageHandlers: Record = { "olm/wg/relay": handleOlmRelayMessage, "olm/wg/unrelay": handleOlmUnRelayMessage, "olm/ping": handleOlmPingMessage, + "newt/ping": handleNewtPingMessage, "newt/wg/register": handleNewtRegisterMessage, "newt/wg/get-config": handleGetConfigMessage, "newt/receive-bandwidth": handleReceiveBandwidthMessage, diff --git a/server/routers/ws/types.ts b/server/routers/ws/types.ts index 5cca3c09..81d3bd49 100644 --- a/server/routers/ws/types.ts +++ b/server/routers/ws/types.ts @@ -52,7 +52,11 @@ export interface HandlerContext { senderWs: WebSocket; client: Newt | Olm | RemoteExitNode | undefined; clientType: ClientType; - sendToClient: (clientId: string, message: WSMessage, options?: SendMessageOptions) => Promise; + sendToClient: ( + clientId: string, + message: WSMessage, + options?: SendMessageOptions + ) => Promise; broadcastToAllExcept: ( message: WSMessage, excludeClientId?: string, diff --git a/server/routers/ws/ws.ts b/server/routers/ws/ws.ts index f707848c..063202db 100644 --- a/server/routers/ws/ws.ts +++ b/server/routers/ws/ws.ts @@ -95,7 +95,7 @@ const sendToClientLocal = async ( if (!clients || clients.length === 0) { return false; } - + // Increment config version if requested if (options.incrementConfigVersion) { const currentVersion = clientConfigVersions.get(clientId) || 0; @@ -106,14 +106,14 @@ const sendToClientLocal = async ( client.configVersion = newVersion; }); } - + // Include config version in message const configVersion = clientConfigVersions.get(clientId) || 0; const messageWithVersion = { ...message, configVersion }; - + const messageString = JSON.stringify(messageWithVersion); clients.forEach((client) => { if (client.readyState === WebSocket.OPEN) { @@ -189,7 +189,7 @@ const hasActiveConnections = async (clientId: string): Promise => { }; // Get the current config version for a client -const getClientConfigVersion = (clientId: string): number => { +const getClientConfigVersion = async (clientId: string): Promise => { return clientConfigVersions.get(clientId) || 0; };