diff --git a/server/routers/policy/setResourcePolicyRules.ts b/server/routers/policy/setResourcePolicyRules.ts index f15c1e51a..dfcfb7cdb 100644 --- a/server/routers/policy/setResourcePolicyRules.ts +++ b/server/routers/policy/setResourcePolicyRules.ts @@ -1,7 +1,7 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; import { db, resourcePolicyRules, resourcePolicies } from "@server/db"; -import { eq } from "drizzle-orm"; +import { and, eq, notInArray } from "drizzle-orm"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; @@ -14,6 +14,7 @@ import { import { OpenAPITags, registry } from "@server/openApi"; const ruleSchema = z.strictObject({ + ruleId: z.int().positive().optional(), action: z.enum(["ACCEPT", "DROP", "PASS"]).openapi({ type: "string", enum: ["ACCEPT", "DROP", "PASS"], @@ -121,17 +122,74 @@ export async function setResourcePolicyRules( .set({ applyRules }) .where(eq(resourcePolicies.resourcePolicyId, resourcePolicyId)); - await trx - .delete(resourcePolicyRules) - .where( - eq(resourcePolicyRules.resourcePolicyId, resourcePolicyId) - ); + const incomingRuleIds = rules + .map((r) => r.ruleId) + .filter((id): id is number => id !== undefined); - if (rules.length > 0) { + // Delete rules that are no longer in the incoming list + if (incomingRuleIds.length > 0) { + await trx + .delete(resourcePolicyRules) + .where( + and( + eq( + resourcePolicyRules.resourcePolicyId, + resourcePolicyId + ), + notInArray( + resourcePolicyRules.ruleId, + incomingRuleIds + ) + ) + ); + } else { + await trx + .delete(resourcePolicyRules) + .where( + eq( + resourcePolicyRules.resourcePolicyId, + resourcePolicyId + ) + ); + } + + // Update existing rules (those with a ruleId) + const existingRules = rules.filter( + (r): r is typeof r & { ruleId: number } => + r.ruleId !== undefined + ); + for (const rule of existingRules) { + await trx + .update(resourcePolicyRules) + .set({ + action: rule.action, + match: rule.match, + value: rule.value, + priority: rule.priority, + enabled: rule.enabled + }) + .where( + and( + eq(resourcePolicyRules.ruleId, rule.ruleId), + eq( + resourcePolicyRules.resourcePolicyId, + resourcePolicyId + ) + ) + ); + } + + // Insert new rules (those without a ruleId) + const newRules = rules.filter((r) => r.ruleId === undefined); + if (newRules.length > 0) { await trx.insert(resourcePolicyRules).values( - rules.map((rule) => ({ + newRules.map((rule) => ({ resourcePolicyId, - ...rule + action: rule.action, + match: rule.match, + value: rule.value, + priority: rule.priority, + enabled: rule.enabled })) ); } diff --git a/src/components/resource-policy/PolicyAccessRulesSection.tsx b/src/components/resource-policy/PolicyAccessRulesSection.tsx index 7a88cab0b..bd735f9b5 100644 --- a/src/components/resource-policy/PolicyAccessRulesSection.tsx +++ b/src/components/resource-policy/PolicyAccessRulesSection.tsx @@ -340,7 +340,8 @@ function PolicyAccessRulesSectionEdit({ ? rules.filter((rule) => !rule.fromPolicy) : rules; const rulesPayload = rulesToValidate.map( - ({ action, match, value, priority, enabled }) => ({ + ({ ruleId, action, match, value, priority, enabled, new: isNew }) => ({ + ...(isNew ? {} : { ruleId }), action, match, value,