diff --git a/server/routers/idp/validateOidcCallback.ts b/server/routers/idp/validateOidcCallback.ts index 7d756a4a..e3462185 100644 --- a/server/routers/idp/validateOidcCallback.ts +++ b/server/routers/idp/validateOidcCallback.ts @@ -36,6 +36,10 @@ import { build } from "@server/build"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { isSubscribed } from "#dynamic/lib/isSubscribed"; import { tierMatrix } from "@server/lib/billing/tierMatrix"; +import { + assignUserToOrg, + removeUserFromOrg +} from "@server/lib/userOrg"; const ensureTrailingSlash = (url: string): string => { return url; @@ -455,15 +459,32 @@ export async function validateOidcCallback( ); if (!existingUserOrgs.length) { - // delete all auto -provisioned user orgs - await db - .delete(userOrgs) + // delete all auto-provisioned user orgs + const autoProvisionedUserOrgs = await db + .select() + .from(userOrgs) .where( and( eq(userOrgs.userId, existingUser.userId), eq(userOrgs.autoProvisioned, true) ) ); + const orgIdsToRemove = autoProvisionedUserOrgs.map( + (uo) => uo.orgId + ); + if (orgIdsToRemove.length > 0) { + const orgsToRemove = await db + .select() + .from(orgs) + .where(inArray(orgs.orgId, orgIdsToRemove)); + for (const org of orgsToRemove) { + await removeUserFromOrg( + org, + existingUser.userId, + db + ); + } + } await calculateUserClientsForOrgs(existingUser.userId); @@ -485,7 +506,7 @@ export async function validateOidcCallback( } } - const orgUserCounts: { orgId: string; userCount: number }[] = []; + const orgUserCounts: { orgId: string; userCount: number }[] = []; // sync the user with the orgs and roles await db.transaction(async (trx) => { @@ -539,15 +560,14 @@ export async function validateOidcCallback( ); if (orgsToDelete.length > 0) { - await trx.delete(userOrgs).where( - and( - eq(userOrgs.userId, userId!), - inArray( - userOrgs.orgId, - orgsToDelete.map((org) => org.orgId) - ) - ) - ); + const orgIdsToRemove = orgsToDelete.map((org) => org.orgId); + const fullOrgsToRemove = await trx + .select() + .from(orgs) + .where(inArray(orgs.orgId, orgIdsToRemove)); + for (const org of fullOrgsToRemove) { + await removeUserFromOrg(org, userId!, trx); + } } // Update roles for existing auto-provisioned orgs where the role has changed @@ -588,15 +608,24 @@ export async function validateOidcCallback( ); if (orgsToAdd.length > 0) { - await trx.insert(userOrgs).values( - orgsToAdd.map((org) => ({ - userId: userId!, - orgId: org.orgId, - roleId: org.roleId, - autoProvisioned: true, - dateCreated: new Date().toISOString() - })) - ); + for (const org of orgsToAdd) { + const [fullOrg] = await trx + .select() + .from(orgs) + .where(eq(orgs.orgId, org.orgId)); + if (fullOrg) { + await assignUserToOrg( + fullOrg, + { + orgId: org.orgId, + userId: userId!, + roleId: org.roleId, + autoProvisioned: true, + }, + trx + ); + } + } } // Loop through all the orgs and get the total number of users from the userOrgs table