From 7d8185e0ee08f03082fdef891ba6371f6032ac04 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 9 Feb 2026 17:05:14 -0800 Subject: [PATCH] Getting swtiching tiers to work --- server/db/pg/schema/privateSchema.ts | 1 + server/db/sqlite/schema/privateSchema.ts | 1 + server/lib/billing/features.ts | 20 ++++++++ server/private/routers/billing/changeTier.ts | 9 ++-- .../hooks/handleSubscriptionCreated.ts | 5 ++ .../hooks/handleSubscriptionUpdated.ts | 50 +++++++++++++------ .../settings/(private)/billing/page.tsx | 15 +++++- 7 files changed, 81 insertions(+), 20 deletions(-) diff --git a/server/db/pg/schema/privateSchema.ts b/server/db/pg/schema/privateSchema.ts index de5bb1ca..9d493ed9 100644 --- a/server/db/pg/schema/privateSchema.ts +++ b/server/db/pg/schema/privateSchema.ts @@ -97,6 +97,7 @@ export const subscriptionItems = pgTable("subscriptionItems", { }), planId: varchar("planId", { length: 255 }).notNull(), priceId: varchar("priceId", { length: 255 }), + featureId: varchar("featureId", { length: 255 }), meterId: varchar("meterId", { length: 255 }), unitAmount: real("unitAmount"), tiers: text("tiers"), diff --git a/server/db/sqlite/schema/privateSchema.ts b/server/db/sqlite/schema/privateSchema.ts index 1fa8654b..2571a65a 100644 --- a/server/db/sqlite/schema/privateSchema.ts +++ b/server/db/sqlite/schema/privateSchema.ts @@ -86,6 +86,7 @@ export const subscriptionItems = sqliteTable("subscriptionItems", { }), planId: text("planId").notNull(), priceId: text("priceId"), + featureId: text("featureId"), meterId: text("meterId"), unitAmount: real("unitAmount"), tiers: text("tiers"), diff --git a/server/lib/billing/features.ts b/server/lib/billing/features.ts index a9b652a9..82ba0676 100644 --- a/server/lib/billing/features.ts +++ b/server/lib/billing/features.ts @@ -116,6 +116,26 @@ export function getScaleFeaturePriceSet(): FeaturePriceSet { } } +export function getFeatureIdByPriceId(priceId: string): FeatureId | undefined { + // Check all feature price sets + const allPriceSets = [ + getHomeLabFeaturePriceSet(), + getStarterFeaturePriceSet(), + getScaleFeaturePriceSet() + ]; + + for (const priceSet of allPriceSets) { + const entry = (Object.entries(priceSet) as [FeatureId, string][]).find( + ([_, price]) => price === priceId + ); + if (entry) { + return entry[0]; + } + } + + return undefined; +} + export async function getLineItems( featurePriceSet: FeaturePriceSet, orgId: string, diff --git a/server/private/routers/billing/changeTier.ts b/server/private/routers/billing/changeTier.ts index a33a9164..5d67b7e8 100644 --- a/server/private/routers/billing/changeTier.ts +++ b/server/private/routers/billing/changeTier.ts @@ -206,7 +206,8 @@ export async function changeTier( // Keep the existing item unchanged if we can't find it return { id: stripeItem.id, - price: stripeItem.price.id + price: stripeItem.price.id, + quantity: stripeItem.quantity }; } @@ -216,14 +217,16 @@ export async function changeTier( if (newPriceId) { return { id: stripeItem.id, - price: newPriceId + price: newPriceId, + quantity: stripeItem.quantity }; } // If no mapping found, keep existing return { id: stripeItem.id, - price: stripeItem.price.id + price: stripeItem.price.id, + quantity: stripeItem.quantity }; } ); diff --git a/server/private/routers/billing/hooks/handleSubscriptionCreated.ts b/server/private/routers/billing/hooks/handleSubscriptionCreated.ts index 16b64145..773ffbae 100644 --- a/server/private/routers/billing/hooks/handleSubscriptionCreated.ts +++ b/server/private/routers/billing/hooks/handleSubscriptionCreated.ts @@ -31,6 +31,7 @@ import { getLicensePriceSet, LicenseId } from "@server/lib/billing/licenses"; import { sendEmail } from "@server/emails"; import EnterpriseEditionKeyGenerated from "@server/emails/templates/EnterpriseEditionKeyGenerated"; import config from "@server/lib/config"; +import { getFeatureIdByPriceId } from "@server/lib/billing/features"; export async function handleSubscriptionCreated( subscription: Stripe.Subscription @@ -91,11 +92,15 @@ export async function handleSubscriptionCreated( name = product.name || null; } + // Get the feature ID from the price ID + const featureId = getFeatureIdByPriceId(item.price.id); + return { stripeSubscriptionItemId: item.id, subscriptionId: subscription.id, planId: item.plan.id, priceId: item.price.id, + featureId: featureId || null, meterId: item.plan.meter, unitAmount: item.price.unit_amount || 0, currentPeriodStart: item.current_period_start, diff --git a/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts b/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts index 4288d3c4..83472ac0 100644 --- a/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts +++ b/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts @@ -23,7 +23,7 @@ import { } from "@server/db"; import { eq, and } from "drizzle-orm"; import logger from "@server/logger"; -import { getFeatureIdByMetricId } from "@server/lib/billing/features"; +import { getFeatureIdByMetricId, getFeatureIdByPriceId } from "@server/lib/billing/features"; import stripe from "#private/lib/stripe"; import { handleSubscriptionLifesycle } from "../subscriptionLifecycle"; import { getSubType } from "./getSubType"; @@ -81,20 +81,40 @@ export async function handleSubscriptionUpdated( // Upsert subscription items if (Array.isArray(fullSubscription.items?.data)) { - const itemsToUpsert = fullSubscription.items.data.map((item) => ({ - stripeSubscriptionItemId: item.id, - subscriptionId: subscription.id, - planId: item.plan.id, - priceId: item.price.id, - meterId: item.plan.meter, - unitAmount: item.price.unit_amount || 0, - currentPeriodStart: item.current_period_start, - currentPeriodEnd: item.current_period_end, - tiers: item.price.tiers - ? JSON.stringify(item.price.tiers) - : null, - interval: item.plan.interval - })); + // First, get existing items to preserve featureId when there's no match + const existingItems = await db + .select() + .from(subscriptionItems) + .where(eq(subscriptionItems.subscriptionId, subscription.id)); + + const itemsToUpsert = fullSubscription.items.data.map((item) => { + // Try to get featureId from price + let featureId: string | null = getFeatureIdByPriceId(item.price.id) || null; + + // If no match, try to preserve existing featureId + if (!featureId) { + const existingItem = existingItems.find( + (ei) => ei.stripeSubscriptionItemId === item.id + ); + featureId = existingItem?.featureId || null; + } + + return { + stripeSubscriptionItemId: item.id, + subscriptionId: subscription.id, + planId: item.plan.id, + priceId: item.price.id, + featureId: featureId, + meterId: item.plan.meter, + unitAmount: item.price.unit_amount || 0, + currentPeriodStart: item.current_period_start, + currentPeriodEnd: item.current_period_end, + tiers: item.price.tiers + ? JSON.stringify(item.price.tiers) + : null, + interval: item.plan.interval + }; + }); if (itemsToUpsert.length > 0) { await db.transaction(async (trx) => { await trx diff --git a/src/app/[orgId]/settings/(private)/billing/page.tsx b/src/app/[orgId]/settings/(private)/billing/page.tsx index d0002cba..0716aa60 100644 --- a/src/app/[orgId]/settings/(private)/billing/page.tsx +++ b/src/app/[orgId]/settings/(private)/billing/page.tsx @@ -453,8 +453,19 @@ export default function BillingPage() { // Calculate current usage cost for display const getUserCount = () => getUsageValue(USERS); const getPricePerUser = () => { - if (currentTier === "tier2") return 5; - if (currentTier === "tier3") return 10; + console.log("Calculating price per user, tierSubscription:", tierSubscription); + if (!tierSubscription?.items) return 0; + + // Find the subscription item for USERS feature + const usersItem = tierSubscription.items.find( + (item) => item.planId === USERS + ); + + // unitAmount is in cents, convert to dollars + if (usersItem?.unitAmount) { + return usersItem.unitAmount / 100; + } + return 0; };