[backport cloud/1.35] feat: pass target tier to billing portal for subscription updates (#7726)

Backport of #7692 to `cloud/1.35`

Automatically created by backport workflow.

┆Issue is synchronized with this [Notion
page](https://www.notion.so/PR-7726-backport-cloud-1-35-feat-pass-target-tier-to-billing-portal-for-subscription-updates-2d16d73d36508173acadf20aa6d97017)
by [Unito](https://www.unito.io)

Co-authored-by: Hunter <huntcsg@users.noreply.github.com>
Co-authored-by: Christian Byrne <cbyrne@comfy.org>
Co-authored-by: GitHub Action <action@github.com>
This commit is contained in:
Comfy Org PR Bot
2025-12-23 03:59:29 +09:00
committed by GitHub
parent 73cc7c6b04
commit 5069a4a272
6 changed files with 338 additions and 9 deletions

View File

@@ -11910,6 +11910,8 @@ export interface operations {
"application/json": {
/** @description Optional URL to redirect the customer after they're done with the billing portal */
return_url?: string;
/** @description Optional target subscription tier. When provided, creates a deep link directly to the subscription update confirmation screen with this tier pre-selected. */
target_tier?: "standard" | "creator" | "pro" | "standard-yearly" | "creator-yearly" | "pro-yearly";
};
};
};

View File

@@ -11,6 +11,7 @@ import { useTelemetry } from '@/platform/telemetry'
import { useToastStore } from '@/platform/updates/common/toastStore'
import { useDialogService } from '@/services/dialogService'
import { useFirebaseAuthStore } from '@/stores/firebaseAuthStore'
import type { BillingPortalTargetTier } from '@/stores/firebaseAuthStore'
import { usdToMicros } from '@/utils/formatUtil'
/**
@@ -102,8 +103,11 @@ export const useFirebaseAuthActions = () => {
window.open(response.checkout_url, '_blank')
}, reportError)
const accessBillingPortal = wrapWithErrorHandlingAsync(async () => {
const response = await authStore.accessBillingPortal()
const accessBillingPortal = wrapWithErrorHandlingAsync<
[targetTier?: BillingPortalTargetTier],
void
>(async (targetTier) => {
const response = await authStore.accessBillingPortal(targetTier)
if (!response.billing_portal_url) {
throw new Error(
t('toastMessages.failedToAccessBillingPortal', {

View File

@@ -333,7 +333,7 @@ const { n } = useI18n()
const { getAuthHeader } = useFirebaseAuthStore()
const { isActiveSubscription, subscriptionTier, isYearlySubscription } =
useSubscription()
const { reportError } = useFirebaseAuthActions()
const { accessBillingPortal, reportError } = useFirebaseAuthActions()
const { wrapWithErrorHandlingAsync } = useErrorHandling()
const isLoading = ref(false)
@@ -443,9 +443,15 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
loadingTier.value = tierKey
try {
const response = await initiateCheckout(tierKey)
if (response.checkout_url) {
window.open(response.checkout_url, '_blank')
if (isActiveSubscription.value) {
// Pass the target tier to create a deep link to subscription update confirmation
const checkoutTier = getCheckoutTier(tierKey, currentBillingCycle.value)
await accessBillingPortal(checkoutTier)
} else {
const response = await initiateCheckout(tierKey)
if (response.checkout_url) {
window.open(response.checkout_url, '_blank')
}
}
} finally {
isLoading.value = false

View File

@@ -42,6 +42,11 @@ type AccessBillingPortalResponse =
operations['AccessBillingPortal']['responses']['200']['content']['application/json']
type AccessBillingPortalReqBody =
operations['AccessBillingPortal']['requestBody']
export type BillingPortalTargetTier = NonNullable<
NonNullable<
NonNullable<AccessBillingPortalReqBody>['content']
>['application/json']
>['target_tier']
export class FirebaseAuthStoreError extends Error {
constructor(message: string) {
@@ -409,13 +414,15 @@ export const useFirebaseAuthStore = defineStore('firebaseAuth', () => {
executeAuthAction((_) => addCredits(requestBodyContent))
const accessBillingPortal = async (
requestBody?: AccessBillingPortalReqBody
targetTier?: BillingPortalTargetTier
): Promise<AccessBillingPortalResponse> => {
const authHeader = await getAuthHeader()
if (!authHeader) {
throw new FirebaseAuthStoreError(t('toastMessages.userNotAuthenticated'))
}
const requestBody = targetTier ? { target_tier: targetTier } : undefined
const response = await fetch(buildApiUrl('/customers/billing'), {
method: 'POST',
headers: {

View File

@@ -0,0 +1,237 @@
import { createTestingPinia } from '@pinia/testing'
import { flushPromises, mount } from '@vue/test-utils'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { computed, ref } from 'vue'
import { createI18n } from 'vue-i18n'
import PricingTable from '@/platform/cloud/subscription/components/PricingTable.vue'
const mockIsActiveSubscription = ref(false)
const mockSubscriptionTier = ref<
'STANDARD' | 'CREATOR' | 'PRO' | 'FOUNDERS_EDITION' | null
>(null)
const mockIsYearlySubscription = ref(false)
const mockAccessBillingPortal = vi.fn()
const mockReportError = vi.fn()
const mockGetAuthHeader = vi.fn(() =>
Promise.resolve({ Authorization: 'Bearer test-token' })
)
vi.mock('@/platform/cloud/subscription/composables/useSubscription', () => ({
useSubscription: () => ({
isActiveSubscription: computed(() => mockIsActiveSubscription.value),
subscriptionTier: computed(() => mockSubscriptionTier.value),
isYearlySubscription: computed(() => mockIsYearlySubscription.value)
})
}))
vi.mock('@/composables/auth/useFirebaseAuthActions', () => ({
useFirebaseAuthActions: () => ({
accessBillingPortal: mockAccessBillingPortal,
reportError: mockReportError
})
}))
vi.mock('@/composables/useErrorHandling', () => ({
useErrorHandling: () => ({
wrapWithErrorHandlingAsync: vi.fn(
(fn, errorHandler) =>
async (...args: unknown[]) => {
try {
return await fn(...args)
} catch (error) {
if (errorHandler) {
errorHandler(error)
}
throw error
}
}
)
})
}))
vi.mock('@/stores/firebaseAuthStore', () => ({
useFirebaseAuthStore: () => ({
getAuthHeader: mockGetAuthHeader
}),
FirebaseAuthStoreError: class extends Error {}
}))
vi.mock('@/platform/distribution/types', () => ({
isCloud: true
}))
global.fetch = vi.fn()
const i18n = createI18n({
legacy: false,
locale: 'en',
messages: {
en: {
subscription: {
yearly: 'Yearly',
monthly: 'Monthly',
mostPopular: 'Most Popular',
usdPerMonth: '/ month',
billedYearly: 'Billed yearly ({total})',
billedMonthly: 'Billed monthly',
currentPlan: 'Current Plan',
subscribeTo: 'Subscribe to {plan}',
changeTo: 'Change to {plan}',
maxDuration: {
standard: '30 min',
creator: '30 min',
pro: '1 hr'
},
tiers: {
standard: { name: 'Standard' },
creator: { name: 'Creator' },
pro: { name: 'Pro' }
},
benefits: {
monthlyCredits: '{credits} monthly credits',
maxDuration: '{duration} max duration',
gpu: 'RTX 6000 Pro GPU',
addCredits: 'Add more credits anytime',
customLoRAs: 'Import custom LoRAs'
}
}
}
}
})
function createWrapper() {
return mount(PricingTable, {
global: {
plugins: [createTestingPinia({ createSpy: vi.fn }), i18n],
stubs: {
SelectButton: {
template: '<div><slot /></div>',
props: ['modelValue', 'options'],
emits: ['update:modelValue']
},
Popover: { template: '<div><slot /></div>' },
Button: {
template:
'<button @click="$emit(\'click\')" :disabled="disabled" :data-tier="dataTier">{{ label }}</button>',
props: ['loading', 'label', 'severity', 'disabled', 'dataTier', 'pt'],
emits: ['click']
}
}
}
})
}
describe('PricingTable', () => {
beforeEach(() => {
vi.clearAllMocks()
mockIsActiveSubscription.value = false
mockSubscriptionTier.value = null
mockIsYearlySubscription.value = false
vi.mocked(global.fetch).mockResolvedValue({
ok: true,
json: async () => ({ checkout_url: 'https://checkout.stripe.com/test' })
} as Response)
})
describe('billing portal deep linking', () => {
it('should call accessBillingPortal with yearly tier suffix when billing cycle is yearly (default)', async () => {
mockIsActiveSubscription.value = true
mockSubscriptionTier.value = 'STANDARD'
const wrapper = createWrapper()
await flushPromises()
const creatorButton = wrapper
.findAll('button')
.find((btn) => btn.text().includes('Creator'))
expect(creatorButton).toBeDefined()
await creatorButton?.trigger('click')
await flushPromises()
expect(mockAccessBillingPortal).toHaveBeenCalledWith('creator-yearly')
})
it('should call accessBillingPortal with different tiers correctly', async () => {
mockIsActiveSubscription.value = true
mockSubscriptionTier.value = 'STANDARD'
const wrapper = createWrapper()
await flushPromises()
const proButton = wrapper
.findAll('button')
.find((btn) => btn.text().includes('Pro'))
await proButton?.trigger('click')
await flushPromises()
expect(mockAccessBillingPortal).toHaveBeenCalledWith('pro-yearly')
})
it('should not call accessBillingPortal when clicking current plan', async () => {
mockIsActiveSubscription.value = true
mockSubscriptionTier.value = 'CREATOR'
const wrapper = createWrapper()
await flushPromises()
const currentPlanButton = wrapper
.findAll('button')
.find((btn) => btn.text().includes('Current Plan'))
await currentPlanButton?.trigger('click')
await flushPromises()
expect(mockAccessBillingPortal).not.toHaveBeenCalled()
})
it('should initiate checkout instead of billing portal for new subscribers', async () => {
mockIsActiveSubscription.value = false
const windowOpenSpy = vi
.spyOn(window, 'open')
.mockImplementation(() => null)
const wrapper = createWrapper()
await flushPromises()
const subscribeButton = wrapper
.findAll('button')
.find((btn) => btn.text().includes('Subscribe'))
await subscribeButton?.trigger('click')
await flushPromises()
expect(mockAccessBillingPortal).not.toHaveBeenCalled()
expect(global.fetch).toHaveBeenCalledWith(
expect.stringContaining('/customers/cloud-subscription-checkout/'),
expect.any(Object)
)
expect(windowOpenSpy).toHaveBeenCalledWith(
'https://checkout.stripe.com/test',
'_blank'
)
windowOpenSpy.mockRestore()
})
it('should pass correct tier for each subscription level', async () => {
mockIsActiveSubscription.value = true
mockSubscriptionTier.value = 'PRO'
const wrapper = createWrapper()
await flushPromises()
const standardButton = wrapper
.findAll('button')
.find((btn) => btn.text().includes('Standard'))
await standardButton?.trigger('click')
await flushPromises()
expect(mockAccessBillingPortal).toHaveBeenCalledWith('standard-yearly')
})
})
})

View File

@@ -30,7 +30,9 @@ const mockAddCreditsResponse = {
const mockAccessBillingPortalResponse = {
ok: true,
statusText: 'OK'
statusText: 'OK',
json: () =>
Promise.resolve({ billing_portal_url: 'https://billing.stripe.com/test' })
}
vi.mock('vuefire', () => ({
@@ -129,7 +131,7 @@ describe('useFirebaseAuthStore', () => {
if (url.endsWith('/customers/credit')) {
return Promise.resolve(mockAddCreditsResponse)
}
if (url.endsWith('/customers/billing-portal')) {
if (url.endsWith('/customers/billing')) {
return Promise.resolve(mockAccessBillingPortalResponse)
}
return Promise.reject(new Error('Unexpected API call'))
@@ -542,4 +544,75 @@ describe('useFirebaseAuthStore', () => {
expect(store.loading).toBe(false)
})
})
describe('accessBillingPortal', () => {
it('should call billing endpoint without body when no targetTier provided', async () => {
const result = await store.accessBillingPortal()
expect(mockFetch).toHaveBeenCalledWith(
expect.stringContaining('/customers/billing'),
expect.objectContaining({
method: 'POST',
headers: expect.objectContaining({
Authorization: 'Bearer mock-id-token',
'Content-Type': 'application/json'
})
})
)
const callArgs = mockFetch.mock.calls.find((call) =>
(call[0] as string).endsWith('/customers/billing')
)
expect(callArgs?.[1]).not.toHaveProperty('body')
expect(result).toEqual({
billing_portal_url: 'https://billing.stripe.com/test'
})
})
it('should include target_tier in request body when targetTier provided', async () => {
await store.accessBillingPortal('creator')
const callArgs = mockFetch.mock.calls.find((call) =>
(call[0] as string).endsWith('/customers/billing')
)
expect(callArgs?.[1]).toHaveProperty('body')
expect(JSON.parse(callArgs?.[1]?.body as string)).toEqual({
target_tier: 'creator'
})
})
it('should handle different checkout tier formats', async () => {
const tiers = [
'standard',
'creator',
'pro',
'standard-yearly',
'creator-yearly',
'pro-yearly'
] as const
for (const tier of tiers) {
mockFetch.mockClear()
await store.accessBillingPortal(tier)
const callArgs = mockFetch.mock.calls.find((call) =>
(call[0] as string).endsWith('/customers/billing')
)
expect(JSON.parse(callArgs?.[1]?.body as string)).toEqual({
target_tier: tier
})
}
})
it('should throw error when API returns error response', async () => {
mockFetch.mockImplementationOnce(() =>
Promise.resolve({
ok: false,
json: () => Promise.resolve({ message: 'Billing portal unavailable' })
})
)
await expect(store.accessBillingPortal()).rejects.toThrow()
})
})
})