Merge branch 'private-http-ha' into alerting-rules

This commit is contained in:
Owen
2026-04-15 14:59:50 -07:00
63 changed files with 4120 additions and 1449 deletions

View File

@@ -1,8 +1,8 @@
import { logsDb, primaryLogsDb, requestAuditLog, resources, db, primaryDb } from "@server/db";
import { logsDb, primaryLogsDb, requestAuditLog, resources, siteResources, db, primaryDb } from "@server/db";
import { registry } from "@server/openApi";
import { NextFunction } from "express";
import { Request, Response } from "express";
import { eq, gt, lt, and, count, desc, inArray } from "drizzle-orm";
import { eq, gt, lt, and, count, desc, inArray, isNull, or } from "drizzle-orm";
import { OpenAPITags } from "@server/openApi";
import { z } from "zod";
import createHttpError from "http-errors";
@@ -92,7 +92,10 @@ function getWhere(data: Q) {
lt(requestAuditLog.timestamp, data.timeEnd),
eq(requestAuditLog.orgId, data.orgId),
data.resourceId
? eq(requestAuditLog.resourceId, data.resourceId)
? or(
eq(requestAuditLog.resourceId, data.resourceId),
eq(requestAuditLog.siteResourceId, data.resourceId)
)
: undefined,
data.actor ? eq(requestAuditLog.actor, data.actor) : undefined,
data.method ? eq(requestAuditLog.method, data.method) : undefined,
@@ -110,15 +113,16 @@ export function queryRequest(data: Q) {
return primaryLogsDb
.select({
id: requestAuditLog.id,
timestamp: requestAuditLog.timestamp,
orgId: requestAuditLog.orgId,
action: requestAuditLog.action,
reason: requestAuditLog.reason,
actorType: requestAuditLog.actorType,
actor: requestAuditLog.actor,
actorId: requestAuditLog.actorId,
resourceId: requestAuditLog.resourceId,
ip: requestAuditLog.ip,
timestamp: requestAuditLog.timestamp,
orgId: requestAuditLog.orgId,
action: requestAuditLog.action,
reason: requestAuditLog.reason,
actorType: requestAuditLog.actorType,
actor: requestAuditLog.actor,
actorId: requestAuditLog.actorId,
resourceId: requestAuditLog.resourceId,
siteResourceId: requestAuditLog.siteResourceId,
ip: requestAuditLog.ip,
location: requestAuditLog.location,
userAgent: requestAuditLog.userAgent,
metadata: requestAuditLog.metadata,
@@ -137,37 +141,73 @@ export function queryRequest(data: Q) {
}
async function enrichWithResourceDetails(logs: Awaited<ReturnType<typeof queryRequest>>) {
// If logs database is the same as main database, we can do a join
// Otherwise, we need to fetch resource details separately
const resourceIds = logs
.map(log => log.resourceId)
.filter((id): id is number => id !== null && id !== undefined);
if (resourceIds.length === 0) {
const siteResourceIds = logs
.filter(log => log.resourceId == null && log.siteResourceId != null)
.map(log => log.siteResourceId)
.filter((id): id is number => id !== null && id !== undefined);
if (resourceIds.length === 0 && siteResourceIds.length === 0) {
return logs.map(log => ({ ...log, resourceName: null, resourceNiceId: null }));
}
// Fetch resource details from main database
const resourceDetails = await primaryDb
.select({
resourceId: resources.resourceId,
name: resources.name,
niceId: resources.niceId
})
.from(resources)
.where(inArray(resources.resourceId, resourceIds));
const resourceMap = new Map<number, { name: string | null; niceId: string | null }>();
// Create a map for quick lookup
const resourceMap = new Map(
resourceDetails.map(r => [r.resourceId, { name: r.name, niceId: r.niceId }])
);
if (resourceIds.length > 0) {
const resourceDetails = await primaryDb
.select({
resourceId: resources.resourceId,
name: resources.name,
niceId: resources.niceId
})
.from(resources)
.where(inArray(resources.resourceId, resourceIds));
for (const r of resourceDetails) {
resourceMap.set(r.resourceId, { name: r.name, niceId: r.niceId });
}
}
const siteResourceMap = new Map<number, { name: string | null; niceId: string | null }>();
if (siteResourceIds.length > 0) {
const siteResourceDetails = await primaryDb
.select({
siteResourceId: siteResources.siteResourceId,
name: siteResources.name,
niceId: siteResources.niceId
})
.from(siteResources)
.where(inArray(siteResources.siteResourceId, siteResourceIds));
for (const r of siteResourceDetails) {
siteResourceMap.set(r.siteResourceId, { name: r.name, niceId: r.niceId });
}
}
// Enrich logs with resource details
return logs.map(log => ({
...log,
resourceName: log.resourceId ? resourceMap.get(log.resourceId)?.name ?? null : null,
resourceNiceId: log.resourceId ? resourceMap.get(log.resourceId)?.niceId ?? null : null
}));
return logs.map(log => {
if (log.resourceId != null) {
const details = resourceMap.get(log.resourceId);
return {
...log,
resourceName: details?.name ?? null,
resourceNiceId: details?.niceId ?? null
};
} else if (log.siteResourceId != null) {
const details = siteResourceMap.get(log.siteResourceId);
return {
...log,
resourceId: log.siteResourceId,
resourceName: details?.name ?? null,
resourceNiceId: details?.niceId ?? null
};
}
return { ...log, resourceName: null, resourceNiceId: null };
});
}
export function countRequestQuery(data: Q) {
@@ -211,7 +251,8 @@ async function queryUniqueFilterAttributes(
uniqueLocations,
uniqueHosts,
uniquePaths,
uniqueResources
uniqueResources,
uniqueSiteResources
] = await Promise.all([
primaryLogsDb
.selectDistinct({ actor: requestAuditLog.actor })
@@ -239,6 +280,13 @@ async function queryUniqueFilterAttributes(
})
.from(requestAuditLog)
.where(baseConditions)
.limit(DISTINCT_LIMIT + 1),
primaryLogsDb
.selectDistinct({
id: requestAuditLog.siteResourceId
})
.from(requestAuditLog)
.where(and(baseConditions, isNull(requestAuditLog.resourceId)))
.limit(DISTINCT_LIMIT + 1)
]);
@@ -259,6 +307,10 @@ async function queryUniqueFilterAttributes(
.map(row => row.id)
.filter((id): id is number => id !== null);
const siteResourceIds = uniqueSiteResources
.map(row => row.id)
.filter((id): id is number => id !== null);
let resourcesWithNames: Array<{ id: number; name: string | null }> = [];
if (resourceIds.length > 0) {
@@ -270,10 +322,31 @@ async function queryUniqueFilterAttributes(
.from(resources)
.where(inArray(resources.resourceId, resourceIds));
resourcesWithNames = resourceDetails.map(r => ({
id: r.resourceId,
name: r.name
}));
resourcesWithNames = [
...resourcesWithNames,
...resourceDetails.map(r => ({
id: r.resourceId,
name: r.name
}))
];
}
if (siteResourceIds.length > 0) {
const siteResourceDetails = await primaryDb
.select({
siteResourceId: siteResources.siteResourceId,
name: siteResources.name
})
.from(siteResources)
.where(inArray(siteResources.siteResourceId, siteResourceIds));
resourcesWithNames = [
...resourcesWithNames,
...siteResourceDetails.map(r => ({
id: r.siteResourceId,
name: r.name
}))
];
}
return {

View File

@@ -28,6 +28,7 @@ export type QueryRequestAuditLogResponse = {
actor: string | null;
actorId: string | null;
resourceId: number | null;
siteResourceId: number | null;
resourceNiceId: string | null;
resourceName: string | null;
ip: string | null;

View File

@@ -18,6 +18,7 @@ Reasons:
105 - Valid Password
106 - Valid email
107 - Valid SSO
108 - Connected Client
201 - Resource Not Found
202 - Resource Blocked
@@ -38,6 +39,7 @@ const auditLogBuffer: Array<{
metadata: any;
action: boolean;
resourceId?: number;
siteResourceId?: number;
reason: number;
location?: string;
originalRequestURL: string;
@@ -186,6 +188,7 @@ export async function logRequestAudit(
action: boolean;
reason: number;
resourceId?: number;
siteResourceId?: number;
orgId?: string;
location?: string;
user?: { username: string; userId: string };
@@ -262,6 +265,7 @@ export async function logRequestAudit(
metadata: sanitizeString(metadata),
action: data.action,
resourceId: data.resourceId,
siteResourceId: data.siteResourceId,
reason: data.reason,
location: sanitizeString(data.location),
originalRequestURL: sanitizeString(body.originalRequestURL) ?? "",

View File

@@ -4,8 +4,10 @@ import {
clientSitesAssociationsCache,
db,
ExitNode,
networks,
resources,
Site,
siteNetworks,
siteResources,
targetHealthCheck,
targets
@@ -137,11 +139,14 @@ export async function buildClientConfigurationForNewtClient(
// Filter out any null values from peers that didn't have an olm
const validPeers = peers.filter((peer) => peer !== null);
// Get all enabled site resources for this site
// Get all enabled site resources for this site by joining through siteNetworks and networks
const allSiteResources = await db
.select()
.from(siteResources)
.where(eq(siteResources.siteId, siteId));
.innerJoin(networks, eq(siteResources.networkId, networks.networkId))
.innerJoin(siteNetworks, eq(networks.networkId, siteNetworks.networkId))
.where(eq(siteNetworks.siteId, siteId))
.then((rows) => rows.map((r) => r.siteResources));
const targetsToSend: SubnetProxyTargetV2[] = [];
@@ -168,7 +173,7 @@ export async function buildClientConfigurationForNewtClient(
)
);
const resourceTargets = generateSubnetProxyTargetV2(
const resourceTargets = await generateSubnetProxyTargetV2(
resource,
resourceClients
);

View File

@@ -10,7 +10,7 @@ import { convertTargetsIfNessicary } from "../client/targets";
import { canCompress } from "@server/lib/clientVersionChecks";
import config from "@server/lib/config";
export const handleGetConfigMessage: MessageHandler = async (context) => {
export const handleNewtGetConfigMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
const newt = client as Newt;
@@ -56,7 +56,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 5) {
logger.warn(
`Site last hole punch is too old; skipping this register. The site is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?`
`Site last hole punch is too old; skipping this register. The site is failing to hole punch and identify its network address with the server. Can the site reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?`
);
return;
}
@@ -113,7 +113,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
exitNode
);
const targetsToSend = await convertTargetsIfNessicary(newt.newtId, targets);
const targetsToSend = await convertTargetsIfNessicary(newt.newtId, targets); // for backward compatibility with old newt versions that don't support the new target format
return {
message: {

View File

@@ -0,0 +1,9 @@
import { MessageHandler } from "@server/routers/ws";
export async function flushRequestLogToDb(): Promise<void> {
return;
}
export const handleRequestLogMessage: MessageHandler = async (context) => {
return;
};

View File

@@ -2,11 +2,12 @@ export * from "./createNewt";
export * from "./getNewtToken";
export * from "./handleNewtRegisterMessage";
export * from "./handleReceiveBandwidthMessage";
export * from "./handleGetConfigMessage";
export * from "./handleNewtGetConfigMessage";
export * from "./handleSocketMessages";
export * from "./handleNewtPingRequestMessage";
export * from "./handleApplyBlueprintMessage";
export * from "./handleNewtPingMessage";
export * from "./handleNewtDisconnectingMessage";
export * from "./handleConnectionLogMessage";
export * from "./handleRequestLogMessage";
export * from "./registerNewt";

View File

@@ -4,6 +4,8 @@ import {
clientSitesAssociationsCache,
db,
exitNodes,
networks,
siteNetworks,
siteResources,
sites
} from "@server/db";
@@ -59,9 +61,17 @@ export async function buildSiteConfigurationForOlmClient(
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.innerJoin(
networks,
eq(siteResources.networkId, networks.networkId)
)
.innerJoin(
siteNetworks,
eq(networks.networkId, siteNetworks.networkId)
)
.where(
and(
eq(siteResources.siteId, site.siteId),
eq(siteNetworks.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
@@ -69,6 +79,7 @@ export async function buildSiteConfigurationForOlmClient(
)
);
if (jitMode) {
// Add site configuration to the array
siteConfigurations.push({

View File

@@ -17,7 +17,6 @@ import { getUserDeviceName } from "@server/db/names";
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
import { OlmErrorCodes, sendOlmError } from "./error";
import { handleFingerprintInsertion } from "./fingerprintingUtils";
import { Alias } from "@server/lib/ip";
import { build } from "@server/build";
import { canCompress } from "@server/lib/clientVersionChecks";
import config from "@server/lib/config";

View File

@@ -4,10 +4,12 @@ import {
db,
exitNodes,
Site,
siteResources
siteNetworks,
siteResources,
sites
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, Olm, sites } from "@server/db";
import { clients, Olm } from "@server/db";
import { and, eq, or } from "drizzle-orm";
import logger from "@server/logger";
import { initPeerAddHandshake } from "./peers";
@@ -44,20 +46,31 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
const { siteId, resourceId, chainId } = message.data;
let site: Site | null = null;
const sendCancel = async () => {
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: { chainId }
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
};
let sitesToProcess: Site[] = [];
if (siteId) {
// get the site
const [siteRes] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (siteRes) {
site = siteRes;
sitesToProcess = [siteRes];
}
}
if (resourceId && !site) {
} else if (resourceId) {
const resources = await db
.select()
.from(siteResources)
@@ -72,27 +85,17 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
);
if (!resources || resources.length === 0) {
logger.error(`handleOlmServerPeerAddMessage: Resource not found`);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
logger.error(
`handleOlmServerInitAddPeerHandshake: Resource not found`
);
await sendCancel();
return;
}
if (resources.length > 1) {
// error but this should not happen because the nice id cant contain a dot and the alias has to have a dot and both have to be unique within the org so there should never be multiple matches
logger.error(
`handleOlmServerPeerAddMessage: Multiple resources found matching the criteria`
`handleOlmServerInitAddPeerHandshake: Multiple resources found matching the criteria`
);
return;
}
@@ -117,125 +120,120 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
if (currentResourceAssociationCaches.length === 0) {
logger.error(
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}`
`handleOlmServerInitAddPeerHandshake: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
await sendCancel();
return;
}
const siteIdFromResource = resource.siteId;
// get the site
const [siteRes] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteIdFromResource));
if (!siteRes) {
if (!resource.networkId) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site} not found`
`handleOlmServerInitAddPeerHandshake: Resource ${resource.siteResourceId} has no network`
);
await sendCancel();
return;
}
site = siteRes;
// Get all sites associated with this resource's network via siteNetworks
const siteRows = await db
.select({ siteId: siteNetworks.siteId })
.from(siteNetworks)
.where(eq(siteNetworks.networkId, resource.networkId));
if (!siteRows || siteRows.length === 0) {
logger.error(
`handleOlmServerInitAddPeerHandshake: No sites found for resource ${resource.siteResourceId}`
);
await sendCancel();
return;
}
// Fetch full site objects for all network members
const foundSites = await Promise.all(
siteRows.map(async ({ siteId: sid }) => {
const [s] = await db
.select()
.from(sites)
.where(eq(sites.siteId, sid))
.limit(1);
return s ?? null;
})
);
sitesToProcess = foundSites.filter((s): s is Site => s !== null);
}
if (!site) {
logger.error(`handleOlmServerPeerAddMessage: Site not found`);
if (sitesToProcess.length === 0) {
logger.error(
`handleOlmServerInitAddPeerHandshake: No sites to process`
);
await sendCancel();
return;
}
// check if the client can access this site using the cache
const currentSiteAssociationCaches = await db
.select()
.from(clientSitesAssociationsCache)
.where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
);
let handshakeInitiated = false;
if (currentSiteAssociationCaches.length === 0) {
logger.error(
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to site ${site.siteId}`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
for (const site of sitesToProcess) {
// Check if the client can access this site using the cache
const currentSiteAssociationCaches = await db
.select()
.from(clientSitesAssociationsCache)
.where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
);
if (currentSiteAssociationCaches.length === 0) {
logger.warn(
`handleOlmServerInitAddPeerHandshake: Client ${client.clientId} does not have access to site ${site.siteId}, skipping`
);
continue;
}
if (!site.exitNodeId) {
logger.error(
`handleOlmServerInitAddPeerHandshake: Site ${site.siteId} has no exit node, skipping`
);
continue;
}
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId));
if (!exitNode) {
logger.error(
`handleOlmServerInitAddPeerHandshake: Exit node not found for site ${site.siteId}, skipping`
);
continue;
}
// Trigger the peer add handshake — if the peer was already added this will be a no-op
await initPeerAddHandshake(
client.clientId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
siteId: site.siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
if (!site.exitNodeId) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
// get the exit node from the side
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId));
if (!exitNode) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
chainId
);
return;
handshakeInitiated = true;
}
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
// if it has already been added this will be a no-op
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId,
{
siteId: site.siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
}
},
olm.olmId,
chainId
);
if (!handshakeInitiated) {
logger.error(
`handleOlmServerInitAddPeerHandshake: No accessible sites with valid exit nodes found, cancelling chain`
);
await sendCancel();
}
return;
};
};

View File

@@ -1,43 +1,25 @@
import {
Client,
clientSiteResourcesAssociationsCache,
db,
ExitNode,
Org,
orgs,
roleClients,
roles,
networks,
siteNetworks,
siteResources,
Transaction,
userClients,
userOrgs,
users
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import {
clients,
clientSitesAssociationsCache,
exitNodes,
Olm,
olms,
sites
} from "@server/db";
import { and, eq, inArray, isNotNull, isNull } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger";
import { listExitNodes } from "#dynamic/lib/exitNodes";
import {
generateAliasConfig,
getNextAvailableClientSubnet
} from "@server/lib/ip";
import { generateRemoteSubnets } from "@server/lib/ip";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { validateSessionToken } from "@server/auth/sessions/app";
import config from "@server/lib/config";
import {
addPeer as newtAddPeer,
deletePeer as newtDeletePeer
} from "@server/routers/newt/peers";
export const handleOlmServerPeerAddMessage: MessageHandler = async (
@@ -153,13 +135,21 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.where(
.innerJoin(
networks,
eq(siteResources.networkId, networks.networkId)
)
.innerJoin(
siteNetworks,
and(
eq(siteResources.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
eq(networks.networkId, siteNetworks.networkId),
eq(siteNetworks.siteId, site.siteId)
)
)
.where(
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
);

View File

@@ -145,7 +145,7 @@ export async function getUserResources(
niceId: string;
destination: string;
mode: string;
protocol: string | null;
scheme: string | null;
enabled: boolean;
alias: string | null;
aliasAddress: string | null;
@@ -158,7 +158,7 @@ export async function getUserResources(
niceId: siteResources.niceId,
destination: siteResources.destination,
mode: siteResources.mode,
protocol: siteResources.protocol,
scheme: siteResources.scheme,
enabled: siteResources.enabled,
alias: siteResources.alias,
aliasAddress: siteResources.aliasAddress
@@ -242,7 +242,7 @@ export async function getUserResources(
name: siteResource.name,
destination: siteResource.destination,
mode: siteResource.mode,
protocol: siteResource.protocol,
protocol: siteResource.scheme,
enabled: siteResource.enabled,
alias: siteResource.alias,
aliasAddress: siteResource.aliasAddress,
@@ -291,7 +291,7 @@ export type GetUserResourcesResponse = {
enabled: boolean;
alias: string | null;
aliasAddress: string | null;
type: 'site';
type: "site";
}>;
};
};

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, Site, siteResources } from "@server/db";
import { db, Site, siteNetworks, siteResources } from "@server/db";
import { newts, newtSessions, sites } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
@@ -71,18 +71,23 @@ export async function deleteSite(
await deletePeer(site.exitNodeId!, site.pubKey);
}
} else if (site.type == "newt") {
// delete all of the site resources on this site
const siteResourcesOnSite = trx
.delete(siteResources)
.where(eq(siteResources.siteId, siteId))
.returning();
const networks = await trx
.select({ networkId: siteNetworks.networkId })
.from(siteNetworks)
.where(eq(siteNetworks.siteId, siteId));
// loop through them
for (const removedSiteResource of await siteResourcesOnSite) {
await rebuildClientAssociationsFromSiteResource(
removedSiteResource,
trx
);
for (const network of await networks) {
const [siteResource] = await trx
.select()
.from(siteResources)
.where(eq(siteResources.networkId, network.networkId));
if (siteResource) {
await rebuildClientAssociationsFromSiteResource(
siteResource,
trx
);
}
}
// get the newt on the site by querying the newt table for siteId

View File

@@ -5,6 +5,8 @@ import {
orgs,
roles,
roleSiteResources,
siteNetworks,
networks,
SiteResource,
siteResources,
sites,
@@ -17,17 +19,18 @@ import {
portRangeStringSchema
} from "@server/lib/ip";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations";
import response from "@server/lib/response";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import HttpCode from "@server/types/HttpCode";
import { and, eq } from "drizzle-orm";
import { and, eq, inArray } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import { validateAndConstructDomain } from "@server/lib/domainUtils";
const createSiteResourceParamsSchema = z.strictObject({
orgId: z.string()
@@ -36,11 +39,12 @@ const createSiteResourceParamsSchema = z.strictObject({
const createSiteResourceSchema = z
.strictObject({
name: z.string().min(1).max(255),
mode: z.enum(["host", "cidr", "port"]),
siteId: z.int(),
// protocol: z.enum(["tcp", "udp"]).optional(),
mode: z.enum(["host", "cidr", "http"]),
ssl: z.boolean().optional(), // only used for http mode
scheme: z.enum(["http", "https"]).optional(),
siteIds: z.array(z.int()),
// proxyPort: z.int().positive().optional(),
// destinationPort: z.int().positive().optional(),
destinationPort: z.int().positive().optional(),
destination: z.string().min(1),
enabled: z.boolean().default(true),
alias: z
@@ -57,20 +61,24 @@ const createSiteResourceSchema = z
udpPortRangeString: portRangeStringSchema,
disableIcmp: z.boolean().optional(),
authDaemonPort: z.int().positive().optional(),
authDaemonMode: z.enum(["site", "remote"]).optional()
authDaemonMode: z.enum(["site", "remote"]).optional(),
domainId: z.string().optional(), // only used for http mode, we need this to verify the alias is unique within the org
subdomain: z.string().optional() // only used for http mode, we need this to verify the alias is unique within the org
})
.strict()
.refine(
(data) => {
if (data.mode === "host") {
// Check if it's a valid IP address using zod (v4 or v6)
const isValidIP = z
// .union([z.ipv4(), z.ipv6()])
.union([z.ipv4()]) // for now lets just do ipv4 until we verify ipv6 works everywhere
.safeParse(data.destination).success;
if (data.mode == "host") {
// Check if it's a valid IP address using zod (v4 or v6)
const isValidIP = z
// .union([z.ipv4(), z.ipv6()])
.union([z.ipv4()]) // for now lets just do ipv4 until we verify ipv6 works everywhere
.safeParse(data.destination).success;
if (isValidIP) {
return true;
if (isValidIP) {
return true;
}
}
// Check if it's a valid domain (hostname pattern, TLD not required)
@@ -105,6 +113,21 @@ const createSiteResourceSchema = z
{
message: "Destination must be a valid CIDR notation for cidr mode"
}
)
.refine(
(data) => {
if (data.mode !== "http") return true;
return (
data.scheme !== undefined &&
data.destinationPort !== undefined &&
data.destinationPort >= 1 &&
data.destinationPort <= 65535
);
},
{
message:
"HTTP mode requires scheme (http or https) and a valid destination port"
}
);
export type CreateSiteResourceBody = z.infer<typeof createSiteResourceSchema>;
@@ -159,13 +182,14 @@ export async function createSiteResource(
const { orgId } = parsedParams.data;
const {
name,
siteId,
siteIds,
mode,
// protocol,
scheme,
// proxyPort,
// destinationPort,
destinationPort,
destination,
enabled,
ssl,
alias,
userIds,
roleIds,
@@ -174,18 +198,36 @@ export async function createSiteResource(
udpPortRangeString,
disableIcmp,
authDaemonPort,
authDaemonMode
authDaemonMode,
domainId,
subdomain
} = parsedBody.data;
if (mode == "http") {
const hasHttpFeature = await isLicensedOrSubscribed(
orgId,
tierMatrix[TierFeature.HTTPPrivateResources]
);
if (!hasHttpFeature) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"HTTP private resources are not included in your current plan. Please upgrade."
)
);
}
}
// Verify the site exists and belongs to the org
const [site] = await db
const sitesToAssign = await db
.select()
.from(sites)
.where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId)))
.limit(1);
.where(and(inArray(sites.siteId, siteIds), eq(sites.orgId, orgId)));
if (!site) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
if (sitesToAssign.length !== siteIds.length) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Some site not found")
);
}
const [org] = await db
@@ -226,29 +268,50 @@ export async function createSiteResource(
);
}
// // check if resource with same protocol and proxy port already exists (only for port mode)
// if (mode === "port" && protocol && proxyPort) {
// const [existingResource] = await db
// .select()
// .from(siteResources)
// .where(
// and(
// eq(siteResources.siteId, siteId),
// eq(siteResources.orgId, orgId),
// eq(siteResources.protocol, protocol),
// eq(siteResources.proxyPort, proxyPort)
// )
// )
// .limit(1);
// if (existingResource && existingResource.siteResourceId) {
// return next(
// createHttpError(
// HttpCode.CONFLICT,
// "A resource with the same protocol and proxy port already exists"
// )
// );
// }
// }
if (domainId && alias) {
// throw an error because we can only have one or the other
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Alias and domain cannot both be set. Please choose one or the other."
)
);
}
let fullDomain: string | null = null;
let finalSubdomain: string | null = null;
if (domainId) {
// Validate domain and construct full domain
const domainResult = await validateAndConstructDomain(
domainId,
orgId,
subdomain
);
if (!domainResult.success) {
return next(
createHttpError(HttpCode.BAD_REQUEST, domainResult.error)
);
}
fullDomain = domainResult.fullDomain;
finalSubdomain = domainResult.subdomain;
// make sure the full domain is unique
const existingResource = await db
.select()
.from(siteResources)
.where(eq(siteResources.fullDomain, fullDomain));
if (existingResource.length > 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Resource with that domain already exists"
)
);
}
}
// make sure the alias is unique within the org if provided
if (alias) {
@@ -280,27 +343,49 @@ export async function createSiteResource(
const niceId = await getUniqueSiteResourceName(orgId);
let aliasAddress: string | null = null;
if (mode == "host") {
// we can only have an alias on a host
if (mode === "host" || mode === "http") {
aliasAddress = await getNextAvailableAliasAddress(orgId);
}
let newSiteResource: SiteResource | undefined;
await db.transaction(async (trx) => {
const [network] = await trx
.insert(networks)
.values({
scope: "resource",
orgId: orgId
})
.returning();
if (!network) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
`Failed to create network`
)
);
}
// Create the site resource
const insertValues: typeof siteResources.$inferInsert = {
siteId,
niceId,
orgId,
name,
mode: mode as "host" | "cidr",
mode,
ssl,
networkId: network.networkId,
destination,
scheme,
destinationPort,
enabled,
alias,
alias: alias ? alias.trim() : null,
aliasAddress,
tcpPortRangeString,
udpPortRangeString,
disableIcmp
disableIcmp,
domainId,
subdomain: finalSubdomain,
fullDomain
};
if (isLicensedSshPam) {
if (authDaemonPort !== undefined)
@@ -317,6 +402,13 @@ export async function createSiteResource(
//////////////////// update the associations ////////////////////
for (const siteId of siteIds) {
await trx.insert(siteNetworks).values({
siteId: siteId,
networkId: network.networkId
});
}
const [adminRole] = await trx
.select()
.from(roles)
@@ -359,16 +451,21 @@ export async function createSiteResource(
);
}
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
for (const siteToAssign of sitesToAssign) {
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, siteToAssign.siteId))
.limit(1);
if (!newt) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Newt not found")
);
if (!newt) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Newt not found for site ${siteToAssign.siteId}`
)
);
}
}
await rebuildClientAssociationsFromSiteResource(
@@ -387,7 +484,7 @@ export async function createSiteResource(
}
logger.info(
`Created site resource ${newSiteResource.siteResourceId} for site ${siteId}`
`Created site resource ${newSiteResource.siteResourceId} for org ${orgId}`
);
return response(res, {

View File

@@ -70,17 +70,18 @@ export async function deleteSiteResource(
.where(and(eq(siteResources.siteResourceId, siteResourceId)))
.returning();
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, removedSiteResource.siteId))
.limit(1);
// not sure why this is here...
// const [newt] = await trx
// .select()
// .from(newts)
// .where(eq(newts.siteId, removedSiteResource.siteId))
// .limit(1);
if (!newt) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Newt not found")
);
}
// if (!newt) {
// return next(
// createHttpError(HttpCode.NOT_FOUND, "Newt not found")
// );
// }
await rebuildClientAssociationsFromSiteResource(
removedSiteResource,

View File

@@ -17,38 +17,34 @@ const getSiteResourceParamsSchema = z.strictObject({
.transform((val) => (val ? Number(val) : undefined))
.pipe(z.int().positive().optional())
.optional(),
siteId: z.string().transform(Number).pipe(z.int().positive()),
niceId: z.string().optional(),
orgId: z.string()
});
async function query(
siteResourceId?: number,
siteId?: number,
niceId?: string,
orgId?: string
) {
if (siteResourceId && siteId && orgId) {
if (siteResourceId && orgId) {
const [siteResource] = await db
.select()
.from(siteResources)
.where(
and(
eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
)
)
.limit(1);
return siteResource;
} else if (niceId && siteId && orgId) {
} else if (niceId && orgId) {
const [siteResource] = await db
.select()
.from(siteResources)
.where(
and(
eq(siteResources.niceId, niceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
)
)
@@ -84,7 +80,6 @@ registry.registerPath({
request: {
params: z.object({
niceId: z.string(),
siteId: z.number(),
orgId: z.string()
})
},
@@ -107,10 +102,10 @@ export async function getSiteResource(
);
}
const { siteResourceId, siteId, niceId, orgId } = parsedParams.data;
const { siteResourceId, niceId, orgId } = parsedParams.data;
// Get the site resource
const siteResource = await query(siteResourceId, siteId, niceId, orgId);
const siteResource = await query(siteResourceId, niceId, orgId);
if (!siteResource) {
return next(

View File

@@ -1,4 +1,4 @@
import { db, SiteResource, siteResources, sites } from "@server/db";
import { db, DB_TYPE, SiteResource, siteNetworks, siteResources, sites } from "@server/db";
import response from "@server/lib/response";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
@@ -41,12 +41,12 @@ const listAllSiteResourcesByOrgQuerySchema = z.object({
}),
query: z.string().optional(),
mode: z
.enum(["host", "cidr"])
.enum(["host", "cidr", "http"])
.optional()
.catch(undefined)
.openapi({
type: "string",
enum: ["host", "cidr"],
enum: ["host", "cidr", "http"],
description: "Filter site resources by mode"
}),
sort_by: z
@@ -73,22 +73,58 @@ const listAllSiteResourcesByOrgQuerySchema = z.object({
export type ListAllSiteResourcesByOrgResponse = PaginatedResponse<{
siteResources: (SiteResource & {
siteName: string;
siteNiceId: string;
siteAddress: string | null;
siteOnlines: boolean[];
siteIds: number[];
siteNames: string[];
siteNiceIds: string[];
siteAddresses: (string | null)[];
})[];
}>;
/**
* Returns an aggregation expression compatible with both SQLite and PostgreSQL.
* - SQLite: json_group_array(col) → returns a JSON array string, parsed after fetch
* - PostgreSQL: array_agg(col) → returns a native array
*/
function aggCol<T>(column: any) {
if (DB_TYPE === "sqlite") {
return sql<T>`json_group_array(${column})`;
}
return sql<T>`array_agg(${column})`;
}
/**
* For SQLite the aggregated columns come back as JSON strings; parse them into
* proper arrays. For PostgreSQL the driver already returns native arrays, so
* the row is returned unchanged.
*/
function transformSiteResourceRow(row: any) {
if (DB_TYPE !== "sqlite") {
return row;
}
return {
...row,
siteNames: JSON.parse(row.siteNames) as string[],
siteNiceIds: JSON.parse(row.siteNiceIds) as string[],
siteIds: JSON.parse(row.siteIds) as number[],
siteAddresses: JSON.parse(row.siteAddresses) as (string | null)[],
// SQLite stores booleans as 0/1 integers
siteOnlines: (JSON.parse(row.siteOnlines) as (0 | 1)[]).map(
(v) => v === 1
) as boolean[]
};
}
function querySiteResourcesBase() {
return db
.select({
siteResourceId: siteResources.siteResourceId,
siteId: siteResources.siteId,
orgId: siteResources.orgId,
niceId: siteResources.niceId,
name: siteResources.name,
mode: siteResources.mode,
protocol: siteResources.protocol,
ssl: siteResources.ssl,
scheme: siteResources.scheme,
proxyPort: siteResources.proxyPort,
destinationPort: siteResources.destinationPort,
destination: siteResources.destination,
@@ -100,12 +136,24 @@ function querySiteResourcesBase() {
disableIcmp: siteResources.disableIcmp,
authDaemonMode: siteResources.authDaemonMode,
authDaemonPort: siteResources.authDaemonPort,
siteName: sites.name,
siteNiceId: sites.niceId,
siteAddress: sites.address
subdomain: siteResources.subdomain,
domainId: siteResources.domainId,
fullDomain: siteResources.fullDomain,
networkId: siteResources.networkId,
defaultNetworkId: siteResources.defaultNetworkId,
siteNames: aggCol<string[]>(sites.name),
siteNiceIds: aggCol<string[]>(sites.niceId),
siteIds: aggCol<number[]>(sites.siteId),
siteAddresses: aggCol<(string | null)[]>(sites.address),
siteOnlines: aggCol<boolean[]>(sites.online)
})
.from(siteResources)
.innerJoin(sites, eq(siteResources.siteId, sites.siteId));
.innerJoin(
siteNetworks,
eq(siteResources.networkId, siteNetworks.networkId)
)
.innerJoin(sites, eq(siteNetworks.siteId, sites.siteId))
.groupBy(siteResources.siteResourceId);
}
registry.registerPath({
@@ -193,10 +241,12 @@ export async function listAllSiteResourcesByOrg(
const baseQuery = querySiteResourcesBase().where(and(...conditions));
const countQuery = db.$count(
querySiteResourcesBase().where(and(...conditions)).as("filtered_site_resources")
querySiteResourcesBase()
.where(and(...conditions))
.as("filtered_site_resources")
);
const [siteResourcesList, totalCount] = await Promise.all([
const [siteResourcesRaw, totalCount] = await Promise.all([
baseQuery
.limit(pageSize)
.offset(pageSize * (page - 1))
@@ -210,6 +260,8 @@ export async function listAllSiteResourcesByOrg(
countQuery
]);
const siteResourcesList = siteResourcesRaw.map(transformSiteResourceRow);
return response<ListAllSiteResourcesByOrgResponse>(res, {
data: {
siteResources: siteResourcesList,
@@ -233,4 +285,4 @@ export async function listAllSiteResourcesByOrg(
)
);
}
}
}

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { db, networks, siteNetworks } from "@server/db";
import { siteResources, sites, SiteResource } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
@@ -108,13 +108,21 @@ export async function listSiteResources(
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
}
// Get site resources
// Get site resources by joining networks to siteResources via siteNetworks
const siteResourcesList = await db
.select()
.from(siteResources)
.from(siteNetworks)
.innerJoin(
networks,
eq(siteNetworks.networkId, networks.networkId)
)
.innerJoin(
siteResources,
eq(siteResources.networkId, networks.networkId)
)
.where(
and(
eq(siteResources.siteId, siteId),
eq(siteNetworks.siteId, siteId),
eq(siteResources.orgId, orgId)
)
)
@@ -128,6 +136,7 @@ export async function listSiteResources(
.limit(limit)
.offset(offset);
return response(res, {
data: { siteResources: siteResourcesList },
success: true,

View File

@@ -1,4 +1,3 @@
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import {
clientSiteResources,
clientSiteResourcesAssociationsCache,
@@ -7,13 +6,21 @@ import {
orgs,
roles,
roleSiteResources,
siteNetworks,
SiteResource,
siteResources,
sites,
networks,
Transaction,
userSiteResources
} from "@server/db";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
import { validateAndConstructDomain } from "@server/lib/domainUtils";
import response from "@server/lib/response";
import { eq, and, ne, inArray } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import { updatePeerData, updateTargets } from "@server/routers/client/targets";
import {
generateAliasConfig,
generateRemoteSubnets,
@@ -22,12 +29,8 @@ import {
portRangeStringSchema
} from "@server/lib/ip";
import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations";
import response from "@server/lib/response";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import { updatePeerData, updateTargets } from "@server/routers/client/targets";
import HttpCode from "@server/types/HttpCode";
import { and, eq, ne } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
@@ -40,7 +43,8 @@ const updateSiteResourceParamsSchema = z.strictObject({
const updateSiteResourceSchema = z
.strictObject({
name: z.string().min(1).max(255).optional(),
siteId: z.int(),
siteIds: z.array(z.int()),
// niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(),
niceId: z
.string()
.min(1)
@@ -51,10 +55,11 @@ const updateSiteResourceSchema = z
)
.optional(),
// mode: z.enum(["host", "cidr", "port"]).optional(),
mode: z.enum(["host", "cidr"]).optional(),
// protocol: z.enum(["tcp", "udp"]).nullish(),
mode: z.enum(["host", "cidr", "http"]).optional(),
ssl: z.boolean().optional(),
scheme: z.enum(["http", "https"]).nullish(),
// proxyPort: z.int().positive().nullish(),
// destinationPort: z.int().positive().nullish(),
destinationPort: z.int().positive().nullish(),
destination: z.string().min(1).optional(),
enabled: z.boolean().optional(),
alias: z
@@ -71,7 +76,9 @@ const updateSiteResourceSchema = z
udpPortRangeString: portRangeStringSchema,
disableIcmp: z.boolean().optional(),
authDaemonPort: z.int().positive().nullish(),
authDaemonMode: z.enum(["site", "remote"]).optional()
authDaemonMode: z.enum(["site", "remote"]).optional(),
domainId: z.string().optional(),
subdomain: z.string().optional()
})
.strict()
.refine(
@@ -118,6 +125,23 @@ const updateSiteResourceSchema = z
{
message: "Destination must be a valid CIDR notation for cidr mode"
}
)
.refine(
(data) => {
if (data.mode !== "http") return true;
return (
data.scheme !== undefined &&
data.scheme !== null &&
data.destinationPort !== undefined &&
data.destinationPort !== null &&
data.destinationPort >= 1 &&
data.destinationPort <= 65535
);
},
{
message:
"HTTP mode requires scheme (http or https) and a valid destination port"
}
);
export type UpdateSiteResourceBody = z.infer<typeof updateSiteResourceSchema>;
@@ -172,11 +196,14 @@ export async function updateSiteResource(
const { siteResourceId } = parsedParams.data;
const {
name,
siteId, // because it can change
siteIds, // because it can change
niceId,
mode,
scheme,
destination,
destinationPort,
alias,
ssl,
enabled,
userIds,
roleIds,
@@ -185,19 +212,11 @@ export async function updateSiteResource(
udpPortRangeString,
disableIcmp,
authDaemonPort,
authDaemonMode
authDaemonMode,
domainId,
subdomain
} = parsedBody.data;
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
}
// Check if site resource exists
const [existingSiteResource] = await db
.select()
@@ -211,6 +230,21 @@ export async function updateSiteResource(
);
}
if (mode == "http") {
const hasHttpFeature = await isLicensedOrSubscribed(
existingSiteResource.orgId,
tierMatrix[TierFeature.HTTPPrivateResources]
);
if (!hasHttpFeature) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"HTTP private resources are not included in your current plan. Please upgrade."
)
);
}
}
const isLicensedSshPam = await isLicensedOrSubscribed(
existingSiteResource.orgId,
tierMatrix.sshPam
@@ -237,6 +271,23 @@ export async function updateSiteResource(
);
}
// Verify the site exists and belongs to the org
const sitesToAssign = await db
.select()
.from(sites)
.where(
and(
inArray(sites.siteId, siteIds),
eq(sites.orgId, existingSiteResource.orgId)
)
);
if (sitesToAssign.length !== siteIds.length) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Some site not found")
);
}
// Only check if destination is an IP address
const isIp = z
.union([z.ipv4(), z.ipv6()])
@@ -254,22 +305,60 @@ export async function updateSiteResource(
);
}
let existingSite = site;
let siteChanged = false;
if (existingSiteResource.siteId !== siteId) {
siteChanged = true;
// get the existing site
[existingSite] = await db
.select()
.from(sites)
.where(eq(sites.siteId, existingSiteResource.siteId))
.limit(1);
let sitesChanged = false;
const existingSiteIds = existingSiteResource.networkId
? await db
.select()
.from(siteNetworks)
.where(
eq(siteNetworks.networkId, existingSiteResource.networkId)
)
: [];
if (!existingSite) {
const existingSiteIdSet = new Set(existingSiteIds.map((s) => s.siteId));
const newSiteIdSet = new Set(siteIds);
if (
existingSiteIdSet.size !== newSiteIdSet.size ||
![...existingSiteIdSet].every((id) => newSiteIdSet.has(id))
) {
sitesChanged = true;
}
let fullDomain: string | null = null;
let finalSubdomain: string | null = null;
if (domainId) {
// Validate domain and construct full domain
const domainResult = await validateAndConstructDomain(
domainId,
org.orgId,
subdomain
);
if (!domainResult.success) {
return next(
createHttpError(HttpCode.BAD_REQUEST, domainResult.error)
);
}
fullDomain = domainResult.fullDomain;
finalSubdomain = domainResult.subdomain;
// make sure the full domain is unique
const [existingDomain] = await db
.select()
.from(siteResources)
.where(eq(siteResources.fullDomain, fullDomain));
if (
existingDomain &&
existingDomain.siteResourceId !==
existingSiteResource.siteResourceId
) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Existing site not found"
HttpCode.CONFLICT,
"Resource with that domain already exists"
)
);
}
@@ -302,7 +391,7 @@ export async function updateSiteResource(
let updatedSiteResource: SiteResource | undefined;
await db.transaction(async (trx) => {
// if the site is changed we need to delete and recreate the resource to avoid complications with the rebuild function otherwise we can just update in place
if (siteChanged) {
if (sitesChanged) {
// delete the existing site resource
await trx
.delete(siteResources)
@@ -343,15 +432,20 @@ export async function updateSiteResource(
.update(siteResources)
.set({
name,
siteId,
niceId,
mode,
scheme,
ssl,
destination,
destinationPort,
enabled,
alias: alias && alias.trim() ? alias : null,
alias: alias ? alias.trim() : null,
tcpPortRangeString,
udpPortRangeString,
disableIcmp,
domainId,
subdomain: finalSubdomain,
fullDomain,
...sshPamSet
})
.where(
@@ -372,6 +466,23 @@ export async function updateSiteResource(
//////////////////// update the associations ////////////////////
// delete the site - site resources associations
await trx
.delete(siteNetworks)
.where(
eq(
siteNetworks.networkId,
updatedSiteResource.networkId!
)
);
for (const siteId of siteIds) {
await trx.insert(siteNetworks).values({
siteId: siteId,
networkId: updatedSiteResource.networkId!
});
}
const [adminRole] = await trx
.select()
.from(roles)
@@ -447,14 +558,20 @@ export async function updateSiteResource(
.update(siteResources)
.set({
name: name,
siteId: siteId,
niceId: niceId,
mode: mode,
scheme,
ssl,
destination: destination,
destinationPort: destinationPort,
enabled: enabled,
alias: alias && alias.trim() ? alias : null,
alias: alias ? alias.trim() : null,
tcpPortRangeString: tcpPortRangeString,
udpPortRangeString: udpPortRangeString,
disableIcmp: disableIcmp,
domainId,
subdomain: finalSubdomain,
fullDomain,
...sshPamSet
})
.where(
@@ -464,6 +581,23 @@ export async function updateSiteResource(
//////////////////// update the associations ////////////////////
// delete the site - site resources associations
await trx
.delete(siteNetworks)
.where(
eq(
siteNetworks.networkId,
updatedSiteResource.networkId!
)
);
for (const siteId of siteIds) {
await trx.insert(siteNetworks).values({
siteId: siteId,
networkId: updatedSiteResource.networkId!
});
}
await trx
.delete(clientSiteResources)
.where(
@@ -533,14 +667,15 @@ export async function updateSiteResource(
);
}
logger.info(
`Updated site resource ${siteResourceId} for site ${siteId}`
);
logger.info(`Updated site resource ${siteResourceId}`);
await handleMessagingForUpdatedSiteResource(
existingSiteResource,
updatedSiteResource,
{ siteId: site.siteId, orgId: site.orgId },
siteIds.map((siteId) => ({
siteId,
orgId: existingSiteResource.orgId
})),
trx
);
}
@@ -567,7 +702,7 @@ export async function updateSiteResource(
export async function handleMessagingForUpdatedSiteResource(
existingSiteResource: SiteResource | undefined,
updatedSiteResource: SiteResource,
site: { siteId: number; orgId: string },
sites: { siteId: number; orgId: string }[],
trx: Transaction
) {
logger.debug(
@@ -589,9 +724,14 @@ export async function handleMessagingForUpdatedSiteResource(
const destinationChanged =
existingSiteResource &&
existingSiteResource.destination !== updatedSiteResource.destination;
const destinationPortChanged =
existingSiteResource &&
existingSiteResource.destinationPort !==
updatedSiteResource.destinationPort;
const aliasChanged =
existingSiteResource &&
existingSiteResource.alias !== updatedSiteResource.alias;
(existingSiteResource.alias !== updatedSiteResource.alias ||
existingSiteResource.fullDomain !== updatedSiteResource.fullDomain); // because the full domain gets sent down to the stuff as an alias
const portRangesChanged =
existingSiteResource &&
(existingSiteResource.tcpPortRangeString !==
@@ -603,106 +743,122 @@ export async function handleMessagingForUpdatedSiteResource(
// if the existingSiteResource is undefined (new resource) we don't need to do anything here, the rebuild above handled it all
if (destinationChanged || aliasChanged || portRangesChanged) {
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (!newt) {
throw new Error(
"Newt not found for site during site resource update"
);
}
// Only update targets on newt if destination changed
if (destinationChanged || portRangesChanged) {
const oldTargets = generateSubnetProxyTargetV2(
existingSiteResource,
mergedAllClients
);
const newTargets = generateSubnetProxyTargetV2(
updatedSiteResource,
mergedAllClients
);
await updateTargets(
newt.newtId,
{
oldTargets: oldTargets ? oldTargets : [],
newTargets: newTargets ? newTargets : []
},
newt.version
);
}
const olmJobs: Promise<void>[] = [];
for (const client of mergedAllClients) {
// does this client have access to another resource on this site that has the same destination still? if so we dont want to remove it from their olm yet
// todo: optimize this query if needed
const oldDestinationStillInUseSites = await trx
if (
destinationChanged ||
aliasChanged ||
portRangesChanged ||
destinationPortChanged
) {
for (const site of sites) {
const [newt] = await trx
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
clientSiteResourcesAssociationsCache.siteResourceId,
siteResources.siteResourceId
)
)
.where(
and(
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
),
eq(siteResources.siteId, site.siteId),
eq(
siteResources.destination,
existingSiteResource.destination
),
ne(
siteResources.siteResourceId,
existingSiteResource.siteResourceId
)
)
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (!newt) {
throw new Error(
"Newt not found for site during site resource update"
);
}
// Only update targets on newt if destination changed
if (
destinationChanged ||
portRangesChanged ||
destinationPortChanged
) {
const oldTargets = await generateSubnetProxyTargetV2(
existingSiteResource,
mergedAllClients
);
const newTargets = await generateSubnetProxyTargetV2(
updatedSiteResource,
mergedAllClients
);
const oldDestinationStillInUseByASite =
oldDestinationStillInUseSites.length > 0;
await updateTargets(
newt.newtId,
{
oldTargets: oldTargets ? oldTargets : [],
newTargets: newTargets ? newTargets : []
},
newt.version
);
}
// we also need to update the remote subnets on the olms for each client that has access to this site
olmJobs.push(
updatePeerData(
client.clientId,
updatedSiteResource.siteId,
destinationChanged
? {
oldRemoteSubnets: !oldDestinationStillInUseByASite
? generateRemoteSubnets([
existingSiteResource
])
: [],
newRemoteSubnets: generateRemoteSubnets([
updatedSiteResource
])
}
: undefined,
aliasChanged
? {
oldAliases: generateAliasConfig([
existingSiteResource
]),
newAliases: generateAliasConfig([
updatedSiteResource
])
}
: undefined
)
);
const olmJobs: Promise<void>[] = [];
for (const client of mergedAllClients) {
// does this client have access to another resource on this site that has the same destination still? if so we dont want to remove it from their olm yet
// todo: optimize this query if needed
const oldDestinationStillInUseSites = await trx
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
clientSiteResourcesAssociationsCache.siteResourceId,
siteResources.siteResourceId
)
)
.innerJoin(
siteNetworks,
eq(siteNetworks.networkId, siteResources.networkId)
)
.where(
and(
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
),
eq(siteNetworks.siteId, site.siteId),
eq(
siteResources.destination,
existingSiteResource.destination
),
ne(
siteResources.siteResourceId,
existingSiteResource.siteResourceId
)
)
);
const oldDestinationStillInUseByASite =
oldDestinationStillInUseSites.length > 0;
// we also need to update the remote subnets on the olms for each client that has access to this site
olmJobs.push(
updatePeerData(
client.clientId,
site.siteId,
destinationChanged
? {
oldRemoteSubnets:
!oldDestinationStillInUseByASite
? generateRemoteSubnets([
existingSiteResource
])
: [],
newRemoteSubnets: generateRemoteSubnets([
updatedSiteResource
])
}
: undefined,
aliasChanged
? {
oldAliases: generateAliasConfig([
existingSiteResource
]),
newAliases: generateAliasConfig([
updatedSiteResource
])
}
: undefined
)
);
}
await Promise.all(olmJobs);
}
await Promise.all(olmJobs);
}
}

View File

@@ -2,7 +2,7 @@ import { build } from "@server/build";
import {
handleNewtRegisterMessage,
handleReceiveBandwidthMessage,
handleGetConfigMessage,
handleNewtGetConfigMessage,
handleDockerStatusMessage,
handleDockerContainersMessage,
handleNewtPingRequestMessage,
@@ -37,7 +37,7 @@ export const messageHandlers: Record<string, MessageHandler> = {
"newt/disconnecting": handleNewtDisconnectingMessage,
"newt/ping": handleNewtPingMessage,
"newt/wg/register": handleNewtRegisterMessage,
"newt/wg/get-config": handleGetConfigMessage,
"newt/wg/get-config": handleNewtGetConfigMessage,
"newt/receive-bandwidth": handleReceiveBandwidthMessage,
"newt/socket/status": handleDockerStatusMessage,
"newt/socket/containers": handleDockerContainersMessage,