mirror of
https://github.com/fosrl/pangolin.git
synced 2026-02-17 23:16:21 +00:00
Add round trip tracking for any message
This commit is contained in:
@@ -330,7 +330,8 @@ export const userOrgs = pgTable("userOrgs", {
|
||||
.notNull()
|
||||
.references(() => roles.roleId),
|
||||
isOwner: boolean("isOwner").notNull().default(false),
|
||||
autoProvisioned: boolean("autoProvisioned").default(false)
|
||||
autoProvisioned: boolean("autoProvisioned").default(false),
|
||||
pamUsername: varchar("pamUsername") // cleaned username for ssh and such
|
||||
});
|
||||
|
||||
export const emailVerificationCodes = pgTable("emailVerificationCodes", {
|
||||
@@ -986,6 +987,16 @@ export const deviceWebAuthCodes = pgTable("deviceWebAuthCodes", {
|
||||
})
|
||||
});
|
||||
|
||||
export const roundTripMessageTracker = pgTable("roundTripMessageTracker", {
|
||||
messageId: serial("messageId").primaryKey(),
|
||||
wsClientId: varchar("clientId"),
|
||||
messageType: varchar("messageType"),
|
||||
sentAt: bigint("sentAt", { mode: "number" }).notNull(),
|
||||
receivedAt: bigint("receivedAt", { mode: "number" }),
|
||||
error: text("error"),
|
||||
complete: boolean("complete").notNull().default(false)
|
||||
});
|
||||
|
||||
export type Org = InferSelectModel<typeof orgs>;
|
||||
export type User = InferSelectModel<typeof users>;
|
||||
export type Site = InferSelectModel<typeof sites>;
|
||||
@@ -1046,3 +1057,4 @@ export type SecurityKey = InferSelectModel<typeof securityKeys>;
|
||||
export type WebauthnChallenge = InferSelectModel<typeof webauthnChallenge>;
|
||||
export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>;
|
||||
export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;
|
||||
export type RoundTripMessageTracker = InferSelectModel<typeof roundTripMessageTracker>;
|
||||
|
||||
@@ -635,7 +635,8 @@ export const userOrgs = sqliteTable("userOrgs", {
|
||||
isOwner: integer("isOwner", { mode: "boolean" }).notNull().default(false),
|
||||
autoProvisioned: integer("autoProvisioned", {
|
||||
mode: "boolean"
|
||||
}).default(false)
|
||||
}).default(false),
|
||||
pamUsername: text("pamUsername") // cleaned username for ssh and such
|
||||
});
|
||||
|
||||
export const emailVerificationCodes = sqliteTable("emailVerificationCodes", {
|
||||
@@ -1077,6 +1078,16 @@ export const deviceWebAuthCodes = sqliteTable("deviceWebAuthCodes", {
|
||||
})
|
||||
});
|
||||
|
||||
export const roundTripMessageTracker = sqliteTable("roundTripMessageTracker", {
|
||||
messageId: integer("messageId").primaryKey({ autoIncrement: true }),
|
||||
wsClientId: text("clientId"),
|
||||
messageType: text("messageType"),
|
||||
sentAt: integer("sentAt").notNull(),
|
||||
receivedAt: integer("receivedAt"),
|
||||
error: text("error"),
|
||||
complete: integer("complete", { mode: "boolean" }).notNull().default(false)
|
||||
});
|
||||
|
||||
export type Org = InferSelectModel<typeof orgs>;
|
||||
export type User = InferSelectModel<typeof users>;
|
||||
export type Site = InferSelectModel<typeof sites>;
|
||||
@@ -1138,3 +1149,6 @@ export type SecurityKey = InferSelectModel<typeof securityKeys>;
|
||||
export type WebauthnChallenge = InferSelectModel<typeof webauthnChallenge>;
|
||||
export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;
|
||||
export type DeviceWebAuthCode = InferSelectModel<typeof deviceWebAuthCodes>;
|
||||
export type RoundTripMessageTracker = InferSelectModel<
|
||||
typeof roundTripMessageTracker
|
||||
>;
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, orgs, siteResources } from "@server/db";
|
||||
import { db, newts, orgs, roundTripMessageTracker, siteResources, sites, userOrgs } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
@@ -24,6 +24,7 @@ import { eq, or, and } from "drizzle-orm";
|
||||
import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource";
|
||||
import { signPublicKey, getOrgCAKeys } from "#private/lib/sshCA";
|
||||
import config from "@server/lib/config";
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string().nonempty()
|
||||
@@ -49,6 +50,7 @@ const bodySchema = z
|
||||
|
||||
export type SignSshKeyResponse = {
|
||||
certificate: string;
|
||||
messageId: number;
|
||||
sshUsername: string;
|
||||
sshHost: string;
|
||||
resourceId: number;
|
||||
@@ -118,6 +120,104 @@ export async function signSshKey(
|
||||
);
|
||||
}
|
||||
|
||||
const [userOrg] = await db
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId)))
|
||||
.limit(1);
|
||||
|
||||
if (!userOrg) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"User does not belong to the specified organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let usernameToUse;
|
||||
if (!userOrg.pamUsername) {
|
||||
if (req.user?.email) {
|
||||
// Extract username from email (first part before @)
|
||||
usernameToUse = req.user?.email.split("@")[0];
|
||||
if (!usernameToUse) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Unable to extract username from email"
|
||||
)
|
||||
);
|
||||
}
|
||||
} else if (req.user?.username) {
|
||||
usernameToUse = req.user.username;
|
||||
// We need to clean out any spaces or special characters from the username to ensure it's valid for SSH certificates
|
||||
usernameToUse = usernameToUse.replace(/[^a-zA-Z0-9_-]/g, "");
|
||||
if (!usernameToUse) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Username is not valid for SSH certificate"
|
||||
)
|
||||
);
|
||||
}
|
||||
} else {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"User does not have a valid email or username for SSH certificate"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// check if we have a existing user in this org with the same
|
||||
const [existingUserWithSameName] = await db
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.orgId, orgId),
|
||||
eq(userOrgs.pamUsername, usernameToUse)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (existingUserWithSameName) {
|
||||
let foundUniqueUsername = false;
|
||||
for (let attempt = 0; attempt < 20; attempt++) {
|
||||
const randomNum = Math.floor(Math.random() * 101); // 0 to 100
|
||||
const candidateUsername = `${usernameToUse}${randomNum}`;
|
||||
|
||||
const [existingUser] = await db
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.orgId, orgId),
|
||||
eq(userOrgs.pamUsername, candidateUsername)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (!existingUser) {
|
||||
usernameToUse = candidateUsername;
|
||||
foundUniqueUsername = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!foundUniqueUsername) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.CONFLICT,
|
||||
"Unable to generate a unique username for SSH certificate"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
usernameToUse = userOrg.pamUsername;
|
||||
}
|
||||
|
||||
// Get and decrypt the org's CA keys
|
||||
const caKeys = await getOrgCAKeys(
|
||||
orgId,
|
||||
@@ -201,35 +301,18 @@ export async function signSshKey(
|
||||
);
|
||||
}
|
||||
|
||||
let usernameToUse;
|
||||
if (req.user?.email) {
|
||||
// Extract username from email (first part before @)
|
||||
usernameToUse = req.user?.email.split("@")[0];
|
||||
if (!usernameToUse) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Unable to extract username from email"
|
||||
)
|
||||
);
|
||||
}
|
||||
} else if (req.user?.username) {
|
||||
usernameToUse = req.user.username;
|
||||
// We need to clean out any spaces or special characters from the username to ensure it's valid for SSH certificates
|
||||
usernameToUse = usernameToUse.replace(/[^a-zA-Z0-9_-]/g, "");
|
||||
if (!usernameToUse) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Username is not valid for SSH certificate"
|
||||
)
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// get the site
|
||||
const [newt] = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.siteId, resource.siteId))
|
||||
.limit(1);
|
||||
|
||||
if (!newt) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"User does not have a valid email or username for SSH certificate"
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Site associated with resource not found"
|
||||
)
|
||||
);
|
||||
}
|
||||
@@ -240,17 +323,53 @@ export async function signSshKey(
|
||||
const validFor = 300n;
|
||||
|
||||
const cert = signPublicKey(caKeys.privateKeyPem, publicKey, {
|
||||
keyId: `${usernameToUse}@${orgId}`,
|
||||
keyId: `${usernameToUse}@${resource.niceId}`,
|
||||
validPrincipals: [usernameToUse, resource.niceId],
|
||||
validAfter: now - 60n, // Start 1 min ago for clock skew
|
||||
validBefore: now + validFor
|
||||
});
|
||||
|
||||
const [message] = await db
|
||||
.insert(roundTripMessageTracker)
|
||||
.values({
|
||||
wsClientId: newt.newtId,
|
||||
messageType: `newt/pam/connection`,
|
||||
sentAt: Math.floor(Date.now() / 1000),
|
||||
})
|
||||
.returning();
|
||||
|
||||
if (!message) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to create message tracker entry"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
await sendToClient(newt.newtId, {
|
||||
type: `newt/pam/connection`,
|
||||
data: {
|
||||
messageId: message.messageId,
|
||||
orgId: orgId,
|
||||
agentPort: 8080,
|
||||
agentHost: resource.destination,
|
||||
caCert: publicKey,
|
||||
username: usernameToUse,
|
||||
niceId: resource.niceId,
|
||||
metadata: {
|
||||
sudo: true,
|
||||
homedir: true
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
const expiresIn = Number(validFor); // seconds
|
||||
|
||||
return response<SignSshKeyResponse>(res, {
|
||||
data: {
|
||||
certificate: cert.certificate,
|
||||
messageId: message.messageId,
|
||||
sshUsername: usernameToUse,
|
||||
sshHost: resource.destination,
|
||||
resourceId: resource.siteResourceId,
|
||||
|
||||
@@ -50,6 +50,7 @@ import createHttpError from "http-errors";
|
||||
import { build } from "@server/build";
|
||||
import { createStore } from "#dynamic/lib/rateLimitStore";
|
||||
import { logActionAudit } from "#dynamic/middlewares";
|
||||
import { checkRoundTripMessage } from "./ws";
|
||||
|
||||
// Root routes
|
||||
export const unauthenticated = Router();
|
||||
@@ -1123,6 +1124,8 @@ authenticated.get(
|
||||
blueprints.getBlueprint
|
||||
);
|
||||
|
||||
authenticated.get("/ws/round-trip-message/:messageId", checkRoundTripMessage);
|
||||
|
||||
// Auth routes
|
||||
export const authRouter = Router();
|
||||
unauthenticated.use("/auth", authRouter);
|
||||
|
||||
85
server/routers/ws/checkRoundTripMessage.ts
Normal file
85
server/routers/ws/checkRoundTripMessage.ts
Normal file
@@ -0,0 +1,85 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, roundTripMessageTracker } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const checkRoundTripMessageParamsSchema = z
|
||||
.object({
|
||||
messageId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(z.number().int().positive())
|
||||
})
|
||||
.strict();
|
||||
|
||||
// registry.registerPath({
|
||||
// method: "get",
|
||||
// path: "/ws/round-trip-message/{messageId}",
|
||||
// description:
|
||||
// "Check if a round trip message has been completed by checking the roundTripMessageTracker table",
|
||||
// tags: [OpenAPITags.WebSocket],
|
||||
// request: {
|
||||
// params: checkRoundTripMessageParamsSchema
|
||||
// },
|
||||
// responses: {}
|
||||
// });
|
||||
|
||||
export async function checkRoundTripMessage(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = checkRoundTripMessageParamsSchema.safeParse(
|
||||
req.params
|
||||
);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { messageId } = parsedParams.data;
|
||||
|
||||
// Get the round trip message from the tracker
|
||||
const [message] = await db
|
||||
.select()
|
||||
.from(roundTripMessageTracker)
|
||||
.where(eq(roundTripMessageTracker.messageId, messageId))
|
||||
.limit(1);
|
||||
|
||||
if (!message) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Message not found")
|
||||
);
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
data: {
|
||||
messageId: message.messageId,
|
||||
complete: message.complete,
|
||||
sentAt: message.sentAt,
|
||||
receivedAt: message.receivedAt,
|
||||
error: message.error,
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Round trip message status retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
49
server/routers/ws/handleRoundTripMessage.ts
Normal file
49
server/routers/ws/handleRoundTripMessage.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
import { db, roundTripMessageTracker } from "@server/db";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { eq } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
|
||||
interface RoundTripCompleteMessage {
|
||||
messageId: number;
|
||||
complete: boolean;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export const handleRoundTripMessage: MessageHandler = async (
|
||||
context
|
||||
) => {
|
||||
const { message, client: c } = context;
|
||||
|
||||
logger.info("Handling round trip message");
|
||||
|
||||
const data = message.data as RoundTripCompleteMessage;
|
||||
|
||||
try {
|
||||
const { messageId, complete, error } = data;
|
||||
|
||||
if (!messageId) {
|
||||
logger.error("Round trip message missing messageId");
|
||||
return;
|
||||
}
|
||||
|
||||
// Update the roundTripMessageTracker with completion status
|
||||
await db
|
||||
.update(roundTripMessageTracker)
|
||||
.set({
|
||||
complete: complete,
|
||||
receivedAt: Math.floor(Date.now() / 1000),
|
||||
error: error || null
|
||||
})
|
||||
.where(eq(roundTripMessageTracker.messageId, messageId));
|
||||
|
||||
logger.info(`Round trip message ${messageId} marked as complete: ${complete}`);
|
||||
|
||||
if (error) {
|
||||
logger.warn(`Round trip message ${messageId} completed with error: ${error}`);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error("Error processing round trip message:", error);
|
||||
}
|
||||
|
||||
return;
|
||||
};
|
||||
@@ -1,2 +1,3 @@
|
||||
export * from "./ws";
|
||||
export * from "./types";
|
||||
export * from "./checkRoundTripMessage";
|
||||
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
handleOlmDisconnecingMessage
|
||||
} from "../olm";
|
||||
import { handleHealthcheckStatusMessage } from "../target";
|
||||
import { handleRoundTripMessage } from "./handleRoundTripMessage";
|
||||
import { MessageHandler } from "./types";
|
||||
|
||||
export const messageHandlers: Record<string, MessageHandler> = {
|
||||
@@ -35,7 +36,8 @@ export const messageHandlers: Record<string, MessageHandler> = {
|
||||
"newt/socket/containers": handleDockerContainersMessage,
|
||||
"newt/ping/request": handleNewtPingRequestMessage,
|
||||
"newt/blueprint/apply": handleApplyBlueprintMessage,
|
||||
"newt/healthcheck/status": handleHealthcheckStatusMessage
|
||||
"newt/healthcheck/status": handleHealthcheckStatusMessage,
|
||||
"ws/round-trip/complete": handleRoundTripMessage
|
||||
};
|
||||
|
||||
startOlmOfflineChecker(); // this is to handle the offline check for olms
|
||||
|
||||
Reference in New Issue
Block a user