From 3debc6c8d3ba00ed92da296daa854409af125db8 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 16 Feb 2026 20:29:55 -0800 Subject: [PATCH] Add round trip tracking for any message --- server/db/pg/schema/schema.ts | 14 +- server/db/sqlite/schema/schema.ts | 16 +- server/private/routers/ssh/signSshKey.ts | 177 ++++++++++++++++---- server/routers/external.ts | 3 + server/routers/ws/checkRoundTripMessage.ts | 85 ++++++++++ server/routers/ws/handleRoundTripMessage.ts | 49 ++++++ server/routers/ws/index.ts | 1 + server/routers/ws/messageHandlers.ts | 4 +- 8 files changed, 317 insertions(+), 32 deletions(-) create mode 100644 server/routers/ws/checkRoundTripMessage.ts create mode 100644 server/routers/ws/handleRoundTripMessage.ts diff --git a/server/db/pg/schema/schema.ts b/server/db/pg/schema/schema.ts index 4188d894..ca46e207 100644 --- a/server/db/pg/schema/schema.ts +++ b/server/db/pg/schema/schema.ts @@ -330,7 +330,8 @@ export const userOrgs = pgTable("userOrgs", { .notNull() .references(() => roles.roleId), isOwner: boolean("isOwner").notNull().default(false), - autoProvisioned: boolean("autoProvisioned").default(false) + autoProvisioned: boolean("autoProvisioned").default(false), + pamUsername: varchar("pamUsername") // cleaned username for ssh and such }); export const emailVerificationCodes = pgTable("emailVerificationCodes", { @@ -986,6 +987,16 @@ export const deviceWebAuthCodes = pgTable("deviceWebAuthCodes", { }) }); +export const roundTripMessageTracker = pgTable("roundTripMessageTracker", { + messageId: serial("messageId").primaryKey(), + wsClientId: varchar("clientId"), + messageType: varchar("messageType"), + sentAt: bigint("sentAt", { mode: "number" }).notNull(), + receivedAt: bigint("receivedAt", { mode: "number" }), + error: text("error"), + complete: boolean("complete").notNull().default(false) +}); + export type Org = InferSelectModel; export type User = InferSelectModel; export type Site = InferSelectModel; @@ -1046,3 +1057,4 @@ export type SecurityKey = InferSelectModel; export type WebauthnChallenge = InferSelectModel; export type DeviceWebAuthCode = InferSelectModel; export type RequestAuditLog = InferSelectModel; +export type RoundTripMessageTracker = InferSelectModel; diff --git a/server/db/sqlite/schema/schema.ts b/server/db/sqlite/schema/schema.ts index 6d60ec68..ce08dea1 100644 --- a/server/db/sqlite/schema/schema.ts +++ b/server/db/sqlite/schema/schema.ts @@ -635,7 +635,8 @@ export const userOrgs = sqliteTable("userOrgs", { isOwner: integer("isOwner", { mode: "boolean" }).notNull().default(false), autoProvisioned: integer("autoProvisioned", { mode: "boolean" - }).default(false) + }).default(false), + pamUsername: text("pamUsername") // cleaned username for ssh and such }); export const emailVerificationCodes = sqliteTable("emailVerificationCodes", { @@ -1077,6 +1078,16 @@ export const deviceWebAuthCodes = sqliteTable("deviceWebAuthCodes", { }) }); +export const roundTripMessageTracker = sqliteTable("roundTripMessageTracker", { + messageId: integer("messageId").primaryKey({ autoIncrement: true }), + wsClientId: text("clientId"), + messageType: text("messageType"), + sentAt: integer("sentAt").notNull(), + receivedAt: integer("receivedAt"), + error: text("error"), + complete: integer("complete", { mode: "boolean" }).notNull().default(false) +}); + export type Org = InferSelectModel; export type User = InferSelectModel; export type Site = InferSelectModel; @@ -1138,3 +1149,6 @@ export type SecurityKey = InferSelectModel; export type WebauthnChallenge = InferSelectModel; export type RequestAuditLog = InferSelectModel; export type DeviceWebAuthCode = InferSelectModel; +export type RoundTripMessageTracker = InferSelectModel< + typeof roundTripMessageTracker +>; diff --git a/server/private/routers/ssh/signSshKey.ts b/server/private/routers/ssh/signSshKey.ts index 593a83bb..378c3576 100644 --- a/server/private/routers/ssh/signSshKey.ts +++ b/server/private/routers/ssh/signSshKey.ts @@ -13,7 +13,7 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db, orgs, siteResources } from "@server/db"; +import { db, newts, orgs, roundTripMessageTracker, siteResources, sites, userOrgs } from "@server/db"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; @@ -24,6 +24,7 @@ import { eq, or, and } from "drizzle-orm"; import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource"; import { signPublicKey, getOrgCAKeys } from "#private/lib/sshCA"; import config from "@server/lib/config"; +import { sendToClient } from "#dynamic/routers/ws"; const paramsSchema = z.strictObject({ orgId: z.string().nonempty() @@ -49,6 +50,7 @@ const bodySchema = z export type SignSshKeyResponse = { certificate: string; + messageId: number; sshUsername: string; sshHost: string; resourceId: number; @@ -118,6 +120,104 @@ export async function signSshKey( ); } + const [userOrg] = await db + .select() + .from(userOrgs) + .where(and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId))) + .limit(1); + + if (!userOrg) { + return next( + createHttpError( + HttpCode.FORBIDDEN, + "User does not belong to the specified organization" + ) + ); + } + + let usernameToUse; + if (!userOrg.pamUsername) { + if (req.user?.email) { + // Extract username from email (first part before @) + usernameToUse = req.user?.email.split("@")[0]; + if (!usernameToUse) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Unable to extract username from email" + ) + ); + } + } else if (req.user?.username) { + usernameToUse = req.user.username; + // We need to clean out any spaces or special characters from the username to ensure it's valid for SSH certificates + usernameToUse = usernameToUse.replace(/[^a-zA-Z0-9_-]/g, ""); + if (!usernameToUse) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Username is not valid for SSH certificate" + ) + ); + } + } else { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "User does not have a valid email or username for SSH certificate" + ) + ); + } + + // check if we have a existing user in this org with the same + const [existingUserWithSameName] = await db + .select() + .from(userOrgs) + .where( + and( + eq(userOrgs.orgId, orgId), + eq(userOrgs.pamUsername, usernameToUse) + ) + ) + .limit(1); + + if (existingUserWithSameName) { + let foundUniqueUsername = false; + for (let attempt = 0; attempt < 20; attempt++) { + const randomNum = Math.floor(Math.random() * 101); // 0 to 100 + const candidateUsername = `${usernameToUse}${randomNum}`; + + const [existingUser] = await db + .select() + .from(userOrgs) + .where( + and( + eq(userOrgs.orgId, orgId), + eq(userOrgs.pamUsername, candidateUsername) + ) + ) + .limit(1); + + if (!existingUser) { + usernameToUse = candidateUsername; + foundUniqueUsername = true; + break; + } + } + + if (!foundUniqueUsername) { + return next( + createHttpError( + HttpCode.CONFLICT, + "Unable to generate a unique username for SSH certificate" + ) + ); + } + } + } else { + usernameToUse = userOrg.pamUsername; + } + // Get and decrypt the org's CA keys const caKeys = await getOrgCAKeys( orgId, @@ -201,35 +301,18 @@ export async function signSshKey( ); } - let usernameToUse; - if (req.user?.email) { - // Extract username from email (first part before @) - usernameToUse = req.user?.email.split("@")[0]; - if (!usernameToUse) { - return next( - createHttpError( - HttpCode.BAD_REQUEST, - "Unable to extract username from email" - ) - ); - } - } else if (req.user?.username) { - usernameToUse = req.user.username; - // We need to clean out any spaces or special characters from the username to ensure it's valid for SSH certificates - usernameToUse = usernameToUse.replace(/[^a-zA-Z0-9_-]/g, ""); - if (!usernameToUse) { - return next( - createHttpError( - HttpCode.BAD_REQUEST, - "Username is not valid for SSH certificate" - ) - ); - } - } else { + // get the site + const [newt] = await db + .select() + .from(newts) + .where(eq(newts.siteId, resource.siteId)) + .limit(1); + + if (!newt) { return next( createHttpError( - HttpCode.BAD_REQUEST, - "User does not have a valid email or username for SSH certificate" + HttpCode.INTERNAL_SERVER_ERROR, + "Site associated with resource not found" ) ); } @@ -240,17 +323,53 @@ export async function signSshKey( const validFor = 300n; const cert = signPublicKey(caKeys.privateKeyPem, publicKey, { - keyId: `${usernameToUse}@${orgId}`, + keyId: `${usernameToUse}@${resource.niceId}`, validPrincipals: [usernameToUse, resource.niceId], validAfter: now - 60n, // Start 1 min ago for clock skew validBefore: now + validFor }); + const [message] = await db + .insert(roundTripMessageTracker) + .values({ + wsClientId: newt.newtId, + messageType: `newt/pam/connection`, + sentAt: Math.floor(Date.now() / 1000), + }) + .returning(); + + if (!message) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Failed to create message tracker entry" + ) + ); + } + + await sendToClient(newt.newtId, { + type: `newt/pam/connection`, + data: { + messageId: message.messageId, + orgId: orgId, + agentPort: 8080, + agentHost: resource.destination, + caCert: publicKey, + username: usernameToUse, + niceId: resource.niceId, + metadata: { + sudo: true, + homedir: true + } + } + }); + const expiresIn = Number(validFor); // seconds return response(res, { data: { certificate: cert.certificate, + messageId: message.messageId, sshUsername: usernameToUse, sshHost: resource.destination, resourceId: resource.siteResourceId, diff --git a/server/routers/external.ts b/server/routers/external.ts index 5d25e898..a9d075a6 100644 --- a/server/routers/external.ts +++ b/server/routers/external.ts @@ -50,6 +50,7 @@ import createHttpError from "http-errors"; import { build } from "@server/build"; import { createStore } from "#dynamic/lib/rateLimitStore"; import { logActionAudit } from "#dynamic/middlewares"; +import { checkRoundTripMessage } from "./ws"; // Root routes export const unauthenticated = Router(); @@ -1123,6 +1124,8 @@ authenticated.get( blueprints.getBlueprint ); +authenticated.get("/ws/round-trip-message/:messageId", checkRoundTripMessage); + // Auth routes export const authRouter = Router(); unauthenticated.use("/auth", authRouter); diff --git a/server/routers/ws/checkRoundTripMessage.ts b/server/routers/ws/checkRoundTripMessage.ts new file mode 100644 index 00000000..9c832db5 --- /dev/null +++ b/server/routers/ws/checkRoundTripMessage.ts @@ -0,0 +1,85 @@ +import { Request, Response, NextFunction } from "express"; +import { z } from "zod"; +import { db, roundTripMessageTracker } from "@server/db"; +import response from "@server/lib/response"; +import HttpCode from "@server/types/HttpCode"; +import createHttpError from "http-errors"; +import logger from "@server/logger"; +import { fromError } from "zod-validation-error"; +import { eq } from "drizzle-orm"; +import { OpenAPITags, registry } from "@server/openApi"; + +const checkRoundTripMessageParamsSchema = z + .object({ + messageId: z + .string() + .transform(Number) + .pipe(z.number().int().positive()) + }) + .strict(); + +// registry.registerPath({ +// method: "get", +// path: "/ws/round-trip-message/{messageId}", +// description: +// "Check if a round trip message has been completed by checking the roundTripMessageTracker table", +// tags: [OpenAPITags.WebSocket], +// request: { +// params: checkRoundTripMessageParamsSchema +// }, +// responses: {} +// }); + +export async function checkRoundTripMessage( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + const parsedParams = checkRoundTripMessageParamsSchema.safeParse( + req.params + ); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error).toString() + ) + ); + } + + const { messageId } = parsedParams.data; + + // Get the round trip message from the tracker + const [message] = await db + .select() + .from(roundTripMessageTracker) + .where(eq(roundTripMessageTracker.messageId, messageId)) + .limit(1); + + if (!message) { + return next( + createHttpError(HttpCode.NOT_FOUND, "Message not found") + ); + } + + return response(res, { + data: { + messageId: message.messageId, + complete: message.complete, + sentAt: message.sentAt, + receivedAt: message.receivedAt, + error: message.error, + }, + success: true, + error: false, + message: "Round trip message status retrieved successfully", + status: HttpCode.OK + }); + } catch (error) { + logger.error(error); + return next( + createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred") + ); + } +} diff --git a/server/routers/ws/handleRoundTripMessage.ts b/server/routers/ws/handleRoundTripMessage.ts new file mode 100644 index 00000000..ed5d0773 --- /dev/null +++ b/server/routers/ws/handleRoundTripMessage.ts @@ -0,0 +1,49 @@ +import { db, roundTripMessageTracker } from "@server/db"; +import { MessageHandler } from "@server/routers/ws"; +import { eq } from "drizzle-orm"; +import logger from "@server/logger"; + +interface RoundTripCompleteMessage { + messageId: number; + complete: boolean; + error?: string; +} + +export const handleRoundTripMessage: MessageHandler = async ( + context +) => { + const { message, client: c } = context; + + logger.info("Handling round trip message"); + + const data = message.data as RoundTripCompleteMessage; + + try { + const { messageId, complete, error } = data; + + if (!messageId) { + logger.error("Round trip message missing messageId"); + return; + } + + // Update the roundTripMessageTracker with completion status + await db + .update(roundTripMessageTracker) + .set({ + complete: complete, + receivedAt: Math.floor(Date.now() / 1000), + error: error || null + }) + .where(eq(roundTripMessageTracker.messageId, messageId)); + + logger.info(`Round trip message ${messageId} marked as complete: ${complete}`); + + if (error) { + logger.warn(`Round trip message ${messageId} completed with error: ${error}`); + } + } catch (error) { + logger.error("Error processing round trip message:", error); + } + + return; +}; diff --git a/server/routers/ws/index.ts b/server/routers/ws/index.ts index b580b369..f5b4e2e4 100644 --- a/server/routers/ws/index.ts +++ b/server/routers/ws/index.ts @@ -1,2 +1,3 @@ export * from "./ws"; export * from "./types"; +export * from "./checkRoundTripMessage"; diff --git a/server/routers/ws/messageHandlers.ts b/server/routers/ws/messageHandlers.ts index 45c62e6c..9a14344a 100644 --- a/server/routers/ws/messageHandlers.ts +++ b/server/routers/ws/messageHandlers.ts @@ -18,6 +18,7 @@ import { handleOlmDisconnecingMessage } from "../olm"; import { handleHealthcheckStatusMessage } from "../target"; +import { handleRoundTripMessage } from "./handleRoundTripMessage"; import { MessageHandler } from "./types"; export const messageHandlers: Record = { @@ -35,7 +36,8 @@ export const messageHandlers: Record = { "newt/socket/containers": handleDockerContainersMessage, "newt/ping/request": handleNewtPingRequestMessage, "newt/blueprint/apply": handleApplyBlueprintMessage, - "newt/healthcheck/status": handleHealthcheckStatusMessage + "newt/healthcheck/status": handleHealthcheckStatusMessage, + "ws/round-trip/complete": handleRoundTripMessage }; startOlmOfflineChecker(); // this is to handle the offline check for olms