diff --git a/packages/registry-types/src/comfyRegistryTypes.ts b/packages/registry-types/src/comfyRegistryTypes.ts index f814e64e9..f2d69fd29 100644 --- a/packages/registry-types/src/comfyRegistryTypes.ts +++ b/packages/registry-types/src/comfyRegistryTypes.ts @@ -11910,6 +11910,8 @@ export interface operations { "application/json": { /** @description Optional URL to redirect the customer after they're done with the billing portal */ return_url?: string; + /** @description Optional target subscription tier. When provided, creates a deep link directly to the subscription update confirmation screen with this tier pre-selected. */ + target_tier?: "standard" | "creator" | "pro" | "standard-yearly" | "creator-yearly" | "pro-yearly"; }; }; }; diff --git a/src/composables/auth/useFirebaseAuthActions.ts b/src/composables/auth/useFirebaseAuthActions.ts index eed7eb021..da6744c2d 100644 --- a/src/composables/auth/useFirebaseAuthActions.ts +++ b/src/composables/auth/useFirebaseAuthActions.ts @@ -11,6 +11,7 @@ import { useTelemetry } from '@/platform/telemetry' import { useToastStore } from '@/platform/updates/common/toastStore' import { useDialogService } from '@/services/dialogService' import { useFirebaseAuthStore } from '@/stores/firebaseAuthStore' +import type { BillingPortalTargetTier } from '@/stores/firebaseAuthStore' import { usdToMicros } from '@/utils/formatUtil' /** @@ -102,8 +103,11 @@ export const useFirebaseAuthActions = () => { window.open(response.checkout_url, '_blank') }, reportError) - const accessBillingPortal = wrapWithErrorHandlingAsync(async () => { - const response = await authStore.accessBillingPortal() + const accessBillingPortal = wrapWithErrorHandlingAsync< + [targetTier?: BillingPortalTargetTier], + void + >(async (targetTier) => { + const response = await authStore.accessBillingPortal(targetTier) if (!response.billing_portal_url) { throw new Error( t('toastMessages.failedToAccessBillingPortal', { diff --git a/src/platform/cloud/subscription/components/PricingTable.vue b/src/platform/cloud/subscription/components/PricingTable.vue index ea9266304..82ec83c97 100644 --- a/src/platform/cloud/subscription/components/PricingTable.vue +++ b/src/platform/cloud/subscription/components/PricingTable.vue @@ -333,7 +333,7 @@ const { n } = useI18n() const { getAuthHeader } = useFirebaseAuthStore() const { isActiveSubscription, subscriptionTier, isYearlySubscription } = useSubscription() -const { reportError } = useFirebaseAuthActions() +const { accessBillingPortal, reportError } = useFirebaseAuthActions() const { wrapWithErrorHandlingAsync } = useErrorHandling() const isLoading = ref(false) @@ -443,9 +443,15 @@ const handleSubscribe = wrapWithErrorHandlingAsync( loadingTier.value = tierKey try { - const response = await initiateCheckout(tierKey) - if (response.checkout_url) { - window.open(response.checkout_url, '_blank') + if (isActiveSubscription.value) { + // Pass the target tier to create a deep link to subscription update confirmation + const checkoutTier = getCheckoutTier(tierKey, currentBillingCycle.value) + await accessBillingPortal(checkoutTier) + } else { + const response = await initiateCheckout(tierKey) + if (response.checkout_url) { + window.open(response.checkout_url, '_blank') + } } } finally { isLoading.value = false diff --git a/src/stores/firebaseAuthStore.ts b/src/stores/firebaseAuthStore.ts index 7eaf498d6..9f3039887 100644 --- a/src/stores/firebaseAuthStore.ts +++ b/src/stores/firebaseAuthStore.ts @@ -42,6 +42,11 @@ type AccessBillingPortalResponse = operations['AccessBillingPortal']['responses']['200']['content']['application/json'] type AccessBillingPortalReqBody = operations['AccessBillingPortal']['requestBody'] +export type BillingPortalTargetTier = NonNullable< + NonNullable< + NonNullable['content'] + >['application/json'] +>['target_tier'] export class FirebaseAuthStoreError extends Error { constructor(message: string) { @@ -409,13 +414,15 @@ export const useFirebaseAuthStore = defineStore('firebaseAuth', () => { executeAuthAction((_) => addCredits(requestBodyContent)) const accessBillingPortal = async ( - requestBody?: AccessBillingPortalReqBody + targetTier?: BillingPortalTargetTier ): Promise => { const authHeader = await getAuthHeader() if (!authHeader) { throw new FirebaseAuthStoreError(t('toastMessages.userNotAuthenticated')) } + const requestBody = targetTier ? { target_tier: targetTier } : undefined + const response = await fetch(buildApiUrl('/customers/billing'), { method: 'POST', headers: { diff --git a/tests-ui/tests/platform/cloud/subscription/components/PricingTable.test.ts b/tests-ui/tests/platform/cloud/subscription/components/PricingTable.test.ts new file mode 100644 index 000000000..d379b3d5c --- /dev/null +++ b/tests-ui/tests/platform/cloud/subscription/components/PricingTable.test.ts @@ -0,0 +1,237 @@ +import { createTestingPinia } from '@pinia/testing' +import { flushPromises, mount } from '@vue/test-utils' +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { computed, ref } from 'vue' +import { createI18n } from 'vue-i18n' + +import PricingTable from '@/platform/cloud/subscription/components/PricingTable.vue' + +const mockIsActiveSubscription = ref(false) +const mockSubscriptionTier = ref< + 'STANDARD' | 'CREATOR' | 'PRO' | 'FOUNDERS_EDITION' | null +>(null) +const mockIsYearlySubscription = ref(false) +const mockAccessBillingPortal = vi.fn() +const mockReportError = vi.fn() +const mockGetAuthHeader = vi.fn(() => + Promise.resolve({ Authorization: 'Bearer test-token' }) +) + +vi.mock('@/platform/cloud/subscription/composables/useSubscription', () => ({ + useSubscription: () => ({ + isActiveSubscription: computed(() => mockIsActiveSubscription.value), + subscriptionTier: computed(() => mockSubscriptionTier.value), + isYearlySubscription: computed(() => mockIsYearlySubscription.value) + }) +})) + +vi.mock('@/composables/auth/useFirebaseAuthActions', () => ({ + useFirebaseAuthActions: () => ({ + accessBillingPortal: mockAccessBillingPortal, + reportError: mockReportError + }) +})) + +vi.mock('@/composables/useErrorHandling', () => ({ + useErrorHandling: () => ({ + wrapWithErrorHandlingAsync: vi.fn( + (fn, errorHandler) => + async (...args: unknown[]) => { + try { + return await fn(...args) + } catch (error) { + if (errorHandler) { + errorHandler(error) + } + throw error + } + } + ) + }) +})) + +vi.mock('@/stores/firebaseAuthStore', () => ({ + useFirebaseAuthStore: () => ({ + getAuthHeader: mockGetAuthHeader + }), + FirebaseAuthStoreError: class extends Error {} +})) + +vi.mock('@/platform/distribution/types', () => ({ + isCloud: true +})) + +global.fetch = vi.fn() + +const i18n = createI18n({ + legacy: false, + locale: 'en', + messages: { + en: { + subscription: { + yearly: 'Yearly', + monthly: 'Monthly', + mostPopular: 'Most Popular', + usdPerMonth: '/ month', + billedYearly: 'Billed yearly ({total})', + billedMonthly: 'Billed monthly', + currentPlan: 'Current Plan', + subscribeTo: 'Subscribe to {plan}', + changeTo: 'Change to {plan}', + maxDuration: { + standard: '30 min', + creator: '30 min', + pro: '1 hr' + }, + tiers: { + standard: { name: 'Standard' }, + creator: { name: 'Creator' }, + pro: { name: 'Pro' } + }, + benefits: { + monthlyCredits: '{credits} monthly credits', + maxDuration: '{duration} max duration', + gpu: 'RTX 6000 Pro GPU', + addCredits: 'Add more credits anytime', + customLoRAs: 'Import custom LoRAs' + } + } + } + } +}) + +function createWrapper() { + return mount(PricingTable, { + global: { + plugins: [createTestingPinia({ createSpy: vi.fn }), i18n], + stubs: { + SelectButton: { + template: '
', + props: ['modelValue', 'options'], + emits: ['update:modelValue'] + }, + Popover: { template: '
' }, + Button: { + template: + '', + props: ['loading', 'label', 'severity', 'disabled', 'dataTier', 'pt'], + emits: ['click'] + } + } + } + }) +} + +describe('PricingTable', () => { + beforeEach(() => { + vi.clearAllMocks() + mockIsActiveSubscription.value = false + mockSubscriptionTier.value = null + mockIsYearlySubscription.value = false + vi.mocked(global.fetch).mockResolvedValue({ + ok: true, + json: async () => ({ checkout_url: 'https://checkout.stripe.com/test' }) + } as Response) + }) + + describe('billing portal deep linking', () => { + it('should call accessBillingPortal with yearly tier suffix when billing cycle is yearly (default)', async () => { + mockIsActiveSubscription.value = true + mockSubscriptionTier.value = 'STANDARD' + + const wrapper = createWrapper() + await flushPromises() + + const creatorButton = wrapper + .findAll('button') + .find((btn) => btn.text().includes('Creator')) + + expect(creatorButton).toBeDefined() + await creatorButton?.trigger('click') + await flushPromises() + + expect(mockAccessBillingPortal).toHaveBeenCalledWith('creator-yearly') + }) + + it('should call accessBillingPortal with different tiers correctly', async () => { + mockIsActiveSubscription.value = true + mockSubscriptionTier.value = 'STANDARD' + + const wrapper = createWrapper() + await flushPromises() + + const proButton = wrapper + .findAll('button') + .find((btn) => btn.text().includes('Pro')) + + await proButton?.trigger('click') + await flushPromises() + + expect(mockAccessBillingPortal).toHaveBeenCalledWith('pro-yearly') + }) + + it('should not call accessBillingPortal when clicking current plan', async () => { + mockIsActiveSubscription.value = true + mockSubscriptionTier.value = 'CREATOR' + + const wrapper = createWrapper() + await flushPromises() + + const currentPlanButton = wrapper + .findAll('button') + .find((btn) => btn.text().includes('Current Plan')) + + await currentPlanButton?.trigger('click') + await flushPromises() + + expect(mockAccessBillingPortal).not.toHaveBeenCalled() + }) + + it('should initiate checkout instead of billing portal for new subscribers', async () => { + mockIsActiveSubscription.value = false + + const windowOpenSpy = vi + .spyOn(window, 'open') + .mockImplementation(() => null) + + const wrapper = createWrapper() + await flushPromises() + + const subscribeButton = wrapper + .findAll('button') + .find((btn) => btn.text().includes('Subscribe')) + + await subscribeButton?.trigger('click') + await flushPromises() + + expect(mockAccessBillingPortal).not.toHaveBeenCalled() + expect(global.fetch).toHaveBeenCalledWith( + expect.stringContaining('/customers/cloud-subscription-checkout/'), + expect.any(Object) + ) + expect(windowOpenSpy).toHaveBeenCalledWith( + 'https://checkout.stripe.com/test', + '_blank' + ) + + windowOpenSpy.mockRestore() + }) + + it('should pass correct tier for each subscription level', async () => { + mockIsActiveSubscription.value = true + mockSubscriptionTier.value = 'PRO' + + const wrapper = createWrapper() + await flushPromises() + + const standardButton = wrapper + .findAll('button') + .find((btn) => btn.text().includes('Standard')) + + await standardButton?.trigger('click') + await flushPromises() + + expect(mockAccessBillingPortal).toHaveBeenCalledWith('standard-yearly') + }) + }) +}) diff --git a/tests-ui/tests/store/firebaseAuthStore.test.ts b/tests-ui/tests/store/firebaseAuthStore.test.ts index fa67700bc..4e29fdaea 100644 --- a/tests-ui/tests/store/firebaseAuthStore.test.ts +++ b/tests-ui/tests/store/firebaseAuthStore.test.ts @@ -30,7 +30,9 @@ const mockAddCreditsResponse = { const mockAccessBillingPortalResponse = { ok: true, - statusText: 'OK' + statusText: 'OK', + json: () => + Promise.resolve({ billing_portal_url: 'https://billing.stripe.com/test' }) } vi.mock('vuefire', () => ({ @@ -129,7 +131,7 @@ describe('useFirebaseAuthStore', () => { if (url.endsWith('/customers/credit')) { return Promise.resolve(mockAddCreditsResponse) } - if (url.endsWith('/customers/billing-portal')) { + if (url.endsWith('/customers/billing')) { return Promise.resolve(mockAccessBillingPortalResponse) } return Promise.reject(new Error('Unexpected API call')) @@ -542,4 +544,75 @@ describe('useFirebaseAuthStore', () => { expect(store.loading).toBe(false) }) }) + + describe('accessBillingPortal', () => { + it('should call billing endpoint without body when no targetTier provided', async () => { + const result = await store.accessBillingPortal() + + expect(mockFetch).toHaveBeenCalledWith( + expect.stringContaining('/customers/billing'), + expect.objectContaining({ + method: 'POST', + headers: expect.objectContaining({ + Authorization: 'Bearer mock-id-token', + 'Content-Type': 'application/json' + }) + }) + ) + + const callArgs = mockFetch.mock.calls.find((call) => + (call[0] as string).endsWith('/customers/billing') + ) + expect(callArgs?.[1]).not.toHaveProperty('body') + expect(result).toEqual({ + billing_portal_url: 'https://billing.stripe.com/test' + }) + }) + + it('should include target_tier in request body when targetTier provided', async () => { + await store.accessBillingPortal('creator') + + const callArgs = mockFetch.mock.calls.find((call) => + (call[0] as string).endsWith('/customers/billing') + ) + expect(callArgs?.[1]).toHaveProperty('body') + expect(JSON.parse(callArgs?.[1]?.body as string)).toEqual({ + target_tier: 'creator' + }) + }) + + it('should handle different checkout tier formats', async () => { + const tiers = [ + 'standard', + 'creator', + 'pro', + 'standard-yearly', + 'creator-yearly', + 'pro-yearly' + ] as const + + for (const tier of tiers) { + mockFetch.mockClear() + await store.accessBillingPortal(tier) + + const callArgs = mockFetch.mock.calls.find((call) => + (call[0] as string).endsWith('/customers/billing') + ) + expect(JSON.parse(callArgs?.[1]?.body as string)).toEqual({ + target_tier: tier + }) + } + }) + + it('should throw error when API returns error response', async () => { + mockFetch.mockImplementationOnce(() => + Promise.resolve({ + ok: false, + json: () => Promise.resolve({ message: 'Billing portal unavailable' }) + }) + ) + + await expect(store.accessBillingPortal()).rejects.toThrow() + }) + }) })