Add round trip tracking for any message

This commit is contained in:
Owen
2026-02-16 20:29:55 -08:00
parent 5092eb58fb
commit 3debc6c8d3
8 changed files with 317 additions and 32 deletions

View File

@@ -330,7 +330,8 @@ export const userOrgs = pgTable("userOrgs", {
.notNull()
.references(() => roles.roleId),
isOwner: boolean("isOwner").notNull().default(false),
autoProvisioned: boolean("autoProvisioned").default(false)
autoProvisioned: boolean("autoProvisioned").default(false),
pamUsername: varchar("pamUsername") // cleaned username for ssh and such
});
export const emailVerificationCodes = pgTable("emailVerificationCodes", {
@@ -986,6 +987,16 @@ export const deviceWebAuthCodes = pgTable("deviceWebAuthCodes", {
})
});
export const roundTripMessageTracker = pgTable("roundTripMessageTracker", {
messageId: serial("messageId").primaryKey(),
wsClientId: varchar("clientId"),
messageType: varchar("messageType"),
sentAt: bigint("sentAt", { mode: "number" }).notNull(),
receivedAt: bigint("receivedAt", { mode: "number" }),
error: text("error"),
complete: boolean("complete").notNull().default(false)
});
export type Org = InferSelectModel<typeof orgs>;
export type User = InferSelectModel<typeof users>;
export type Site = InferSelectModel<typeof sites>;
@@ -1046,3 +1057,4 @@ export type SecurityKey = InferSelectModel<typeof securityKeys>;
export type WebauthnChallenge = InferSelectModel<typeof webauthnChallenge>;
export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>;
export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;
export type RoundTripMessageTracker = InferSelectModel<typeof roundTripMessageTracker>;

View File

@@ -635,7 +635,8 @@ export const userOrgs = sqliteTable("userOrgs", {
isOwner: integer("isOwner", { mode: "boolean" }).notNull().default(false),
autoProvisioned: integer("autoProvisioned", {
mode: "boolean"
}).default(false)
}).default(false),
pamUsername: text("pamUsername") // cleaned username for ssh and such
});
export const emailVerificationCodes = sqliteTable("emailVerificationCodes", {
@@ -1077,6 +1078,16 @@ export const deviceWebAuthCodes = sqliteTable("deviceWebAuthCodes", {
})
});
export const roundTripMessageTracker = sqliteTable("roundTripMessageTracker", {
messageId: integer("messageId").primaryKey({ autoIncrement: true }),
wsClientId: text("clientId"),
messageType: text("messageType"),
sentAt: integer("sentAt").notNull(),
receivedAt: integer("receivedAt"),
error: text("error"),
complete: integer("complete", { mode: "boolean" }).notNull().default(false)
});
export type Org = InferSelectModel<typeof orgs>;
export type User = InferSelectModel<typeof users>;
export type Site = InferSelectModel<typeof sites>;
@@ -1138,3 +1149,6 @@ export type SecurityKey = InferSelectModel<typeof securityKeys>;
export type WebauthnChallenge = InferSelectModel<typeof webauthnChallenge>;
export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;
export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>;
export type RoundTripMessageTracker = InferSelectModel<
typeof roundTripMessageTracker
>;

View File

@@ -13,7 +13,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, orgs, siteResources } from "@server/db";
import { db, newts, orgs, roundTripMessageTracker, siteResources, sites, userOrgs } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -24,6 +24,7 @@ import { eq, or, and } from "drizzle-orm";
import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource";
import { signPublicKey, getOrgCAKeys } from "#private/lib/sshCA";
import config from "@server/lib/config";
import { sendToClient } from "#dynamic/routers/ws";
const paramsSchema = z.strictObject({
orgId: z.string().nonempty()
@@ -49,6 +50,7 @@ const bodySchema = z
export type SignSshKeyResponse = {
certificate: string;
messageId: number;
sshUsername: string;
sshHost: string;
resourceId: number;
@@ -118,6 +120,104 @@ export async function signSshKey(
);
}
const [userOrg] = await db
.select()
.from(userOrgs)
.where(and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId)))
.limit(1);
if (!userOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"User does not belong to the specified organization"
)
);
}
let usernameToUse;
if (!userOrg.pamUsername) {
if (req.user?.email) {
// Extract username from email (first part before @)
usernameToUse = req.user?.email.split("@")[0];
if (!usernameToUse) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Unable to extract username from email"
)
);
}
} else if (req.user?.username) {
usernameToUse = req.user.username;
// We need to clean out any spaces or special characters from the username to ensure it's valid for SSH certificates
usernameToUse = usernameToUse.replace(/[^a-zA-Z0-9_-]/g, "");
if (!usernameToUse) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Username is not valid for SSH certificate"
)
);
}
} else {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"User does not have a valid email or username for SSH certificate"
)
);
}
// check if we have a existing user in this org with the same
const [existingUserWithSameName] = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.orgId, orgId),
eq(userOrgs.pamUsername, usernameToUse)
)
)
.limit(1);
if (existingUserWithSameName) {
let foundUniqueUsername = false;
for (let attempt = 0; attempt < 20; attempt++) {
const randomNum = Math.floor(Math.random() * 101); // 0 to 100
const candidateUsername = `${usernameToUse}${randomNum}`;
const [existingUser] = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.orgId, orgId),
eq(userOrgs.pamUsername, candidateUsername)
)
)
.limit(1);
if (!existingUser) {
usernameToUse = candidateUsername;
foundUniqueUsername = true;
break;
}
}
if (!foundUniqueUsername) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Unable to generate a unique username for SSH certificate"
)
);
}
}
} else {
usernameToUse = userOrg.pamUsername;
}
// Get and decrypt the org's CA keys
const caKeys = await getOrgCAKeys(
orgId,
@@ -201,35 +301,18 @@ export async function signSshKey(
);
}
let usernameToUse;
if (req.user?.email) {
// Extract username from email (first part before @)
usernameToUse = req.user?.email.split("@")[0];
if (!usernameToUse) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Unable to extract username from email"
)
);
}
} else if (req.user?.username) {
usernameToUse = req.user.username;
// We need to clean out any spaces or special characters from the username to ensure it's valid for SSH certificates
usernameToUse = usernameToUse.replace(/[^a-zA-Z0-9_-]/g, "");
if (!usernameToUse) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Username is not valid for SSH certificate"
)
);
}
} else {
// get the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, resource.siteId))
.limit(1);
if (!newt) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"User does not have a valid email or username for SSH certificate"
HttpCode.INTERNAL_SERVER_ERROR,
"Site associated with resource not found"
)
);
}
@@ -240,17 +323,53 @@ export async function signSshKey(
const validFor = 300n;
const cert = signPublicKey(caKeys.privateKeyPem, publicKey, {
keyId: `${usernameToUse}@${orgId}`,
keyId: `${usernameToUse}@${resource.niceId}`,
validPrincipals: [usernameToUse, resource.niceId],
validAfter: now - 60n, // Start 1 min ago for clock skew
validBefore: now + validFor
});
const [message] = await db
.insert(roundTripMessageTracker)
.values({
wsClientId: newt.newtId,
messageType: `newt/pam/connection`,
sentAt: Math.floor(Date.now() / 1000),
})
.returning();
if (!message) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create message tracker entry"
)
);
}
await sendToClient(newt.newtId, {
type: `newt/pam/connection`,
data: {
messageId: message.messageId,
orgId: orgId,
agentPort: 8080,
agentHost: resource.destination,
caCert: publicKey,
username: usernameToUse,
niceId: resource.niceId,
metadata: {
sudo: true,
homedir: true
}
}
});
const expiresIn = Number(validFor); // seconds
return response<SignSshKeyResponse>(res, {
data: {
certificate: cert.certificate,
messageId: message.messageId,
sshUsername: usernameToUse,
sshHost: resource.destination,
resourceId: resource.siteResourceId,

View File

@@ -50,6 +50,7 @@ import createHttpError from "http-errors";
import { build } from "@server/build";
import { createStore } from "#dynamic/lib/rateLimitStore";
import { logActionAudit } from "#dynamic/middlewares";
import { checkRoundTripMessage } from "./ws";
// Root routes
export const unauthenticated = Router();
@@ -1123,6 +1124,8 @@ authenticated.get(
blueprints.getBlueprint
);
authenticated.get("/ws/round-trip-message/:messageId", checkRoundTripMessage);
// Auth routes
export const authRouter = Router();
unauthenticated.use("/auth", authRouter);

View File

@@ -0,0 +1,85 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, roundTripMessageTracker } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
const checkRoundTripMessageParamsSchema = z
.object({
messageId: z
.string()
.transform(Number)
.pipe(z.number().int().positive())
})
.strict();
// registry.registerPath({
// method: "get",
// path: "/ws/round-trip-message/{messageId}",
// description:
// "Check if a round trip message has been completed by checking the roundTripMessageTracker table",
// tags: [OpenAPITags.WebSocket],
// request: {
// params: checkRoundTripMessageParamsSchema
// },
// responses: {}
// });
export async function checkRoundTripMessage(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = checkRoundTripMessageParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { messageId } = parsedParams.data;
// Get the round trip message from the tracker
const [message] = await db
.select()
.from(roundTripMessageTracker)
.where(eq(roundTripMessageTracker.messageId, messageId))
.limit(1);
if (!message) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Message not found")
);
}
return response(res, {
data: {
messageId: message.messageId,
complete: message.complete,
sentAt: message.sentAt,
receivedAt: message.receivedAt,
error: message.error,
},
success: true,
error: false,
message: "Round trip message status retrieved successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -0,0 +1,49 @@
import { db, roundTripMessageTracker } from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { eq } from "drizzle-orm";
import logger from "@server/logger";
interface RoundTripCompleteMessage {
messageId: number;
complete: boolean;
error?: string;
}
export const handleRoundTripMessage: MessageHandler = async (
context
) => {
const { message, client: c } = context;
logger.info("Handling round trip message");
const data = message.data as RoundTripCompleteMessage;
try {
const { messageId, complete, error } = data;
if (!messageId) {
logger.error("Round trip message missing messageId");
return;
}
// Update the roundTripMessageTracker with completion status
await db
.update(roundTripMessageTracker)
.set({
complete: complete,
receivedAt: Math.floor(Date.now() / 1000),
error: error || null
})
.where(eq(roundTripMessageTracker.messageId, messageId));
logger.info(`Round trip message ${messageId} marked as complete: ${complete}`);
if (error) {
logger.warn(`Round trip message ${messageId} completed with error: ${error}`);
}
} catch (error) {
logger.error("Error processing round trip message:", error);
}
return;
};

View File

@@ -1,2 +1,3 @@
export * from "./ws";
export * from "./types";
export * from "./checkRoundTripMessage";

View File

@@ -18,6 +18,7 @@ import {
handleOlmDisconnecingMessage
} from "../olm";
import { handleHealthcheckStatusMessage } from "../target";
import { handleRoundTripMessage } from "./handleRoundTripMessage";
import { MessageHandler } from "./types";
export const messageHandlers: Record<string, MessageHandler> = {
@@ -35,7 +36,8 @@ export const messageHandlers: Record<string, MessageHandler> = {
"newt/socket/containers": handleDockerContainersMessage,
"newt/ping/request": handleNewtPingRequestMessage,
"newt/blueprint/apply": handleApplyBlueprintMessage,
"newt/healthcheck/status": handleHealthcheckStatusMessage
"newt/healthcheck/status": handleHealthcheckStatusMessage,
"ws/round-trip/complete": handleRoundTripMessage
};
startOlmOfflineChecker(); // this is to handle the offline check for olms