mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-01-26 19:09:52 +00:00
feat: pass target tier to billing portal for subscription updates (#7692)
## Summary Pass target tier to billing portal API for deep linking to Stripe's subscription update confirmation screen when user has an active subscription. ## Changes - **What**: When a user with an active subscription clicks a tier in PricingTable, pass the target tier (including billing cycle) to `accessBillingPortal` which sends it as `target_tier` in the request body. This enables the backend to create a Stripe billing portal deep link directly to the subscription update confirmation screen. - **Dependencies**: Requires comfy-api PR for `POST /customers/billing` `target_tier` support ## Review Focus - PricingTable now differentiates between new subscriptions (checkout flow) and existing subscriptions (billing portal with deep link) - Type derivation uses `Parameters<typeof authStore.accessBillingPortal>[0]` to avoid duplicating the tier union (matches codebase pattern) - Registry types manually updated to include `target_tier` field (will be regenerated when API is deployed) ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-7692-feat-pass-target-tier-to-billing-portal-for-subscription-updates-2d06d73d365081b38fe4c81e95dce58c) by [Unito](https://www.unito.io) --------- Co-authored-by: Christian Byrne <cbyrne@comfy.org> Co-authored-by: GitHub Action <action@github.com>
This commit is contained in:
@@ -11910,6 +11910,8 @@ export interface operations {
|
|||||||
"application/json": {
|
"application/json": {
|
||||||
/** @description Optional URL to redirect the customer after they're done with the billing portal */
|
/** @description Optional URL to redirect the customer after they're done with the billing portal */
|
||||||
return_url?: string;
|
return_url?: string;
|
||||||
|
/** @description Optional target subscription tier. When provided, creates a deep link directly to the subscription update confirmation screen with this tier pre-selected. */
|
||||||
|
target_tier?: "standard" | "creator" | "pro" | "standard-yearly" | "creator-yearly" | "pro-yearly";
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import { useTelemetry } from '@/platform/telemetry'
|
|||||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||||
import { useDialogService } from '@/services/dialogService'
|
import { useDialogService } from '@/services/dialogService'
|
||||||
import { useFirebaseAuthStore } from '@/stores/firebaseAuthStore'
|
import { useFirebaseAuthStore } from '@/stores/firebaseAuthStore'
|
||||||
|
import type { BillingPortalTargetTier } from '@/stores/firebaseAuthStore'
|
||||||
import { usdToMicros } from '@/utils/formatUtil'
|
import { usdToMicros } from '@/utils/formatUtil'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -102,8 +103,11 @@ export const useFirebaseAuthActions = () => {
|
|||||||
window.open(response.checkout_url, '_blank')
|
window.open(response.checkout_url, '_blank')
|
||||||
}, reportError)
|
}, reportError)
|
||||||
|
|
||||||
const accessBillingPortal = wrapWithErrorHandlingAsync(async () => {
|
const accessBillingPortal = wrapWithErrorHandlingAsync<
|
||||||
const response = await authStore.accessBillingPortal()
|
[targetTier?: BillingPortalTargetTier],
|
||||||
|
void
|
||||||
|
>(async (targetTier) => {
|
||||||
|
const response = await authStore.accessBillingPortal(targetTier)
|
||||||
if (!response.billing_portal_url) {
|
if (!response.billing_portal_url) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
t('toastMessages.failedToAccessBillingPortal', {
|
t('toastMessages.failedToAccessBillingPortal', {
|
||||||
|
|||||||
@@ -333,7 +333,7 @@ const { n } = useI18n()
|
|||||||
const { getAuthHeader } = useFirebaseAuthStore()
|
const { getAuthHeader } = useFirebaseAuthStore()
|
||||||
const { isActiveSubscription, subscriptionTier, isYearlySubscription } =
|
const { isActiveSubscription, subscriptionTier, isYearlySubscription } =
|
||||||
useSubscription()
|
useSubscription()
|
||||||
const { reportError } = useFirebaseAuthActions()
|
const { accessBillingPortal, reportError } = useFirebaseAuthActions()
|
||||||
const { wrapWithErrorHandlingAsync } = useErrorHandling()
|
const { wrapWithErrorHandlingAsync } = useErrorHandling()
|
||||||
|
|
||||||
const isLoading = ref(false)
|
const isLoading = ref(false)
|
||||||
@@ -443,9 +443,15 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
|
|||||||
loadingTier.value = tierKey
|
loadingTier.value = tierKey
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await initiateCheckout(tierKey)
|
if (isActiveSubscription.value) {
|
||||||
if (response.checkout_url) {
|
// Pass the target tier to create a deep link to subscription update confirmation
|
||||||
window.open(response.checkout_url, '_blank')
|
const checkoutTier = getCheckoutTier(tierKey, currentBillingCycle.value)
|
||||||
|
await accessBillingPortal(checkoutTier)
|
||||||
|
} else {
|
||||||
|
const response = await initiateCheckout(tierKey)
|
||||||
|
if (response.checkout_url) {
|
||||||
|
window.open(response.checkout_url, '_blank')
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
isLoading.value = false
|
isLoading.value = false
|
||||||
|
|||||||
@@ -42,6 +42,11 @@ type AccessBillingPortalResponse =
|
|||||||
operations['AccessBillingPortal']['responses']['200']['content']['application/json']
|
operations['AccessBillingPortal']['responses']['200']['content']['application/json']
|
||||||
type AccessBillingPortalReqBody =
|
type AccessBillingPortalReqBody =
|
||||||
operations['AccessBillingPortal']['requestBody']
|
operations['AccessBillingPortal']['requestBody']
|
||||||
|
export type BillingPortalTargetTier = NonNullable<
|
||||||
|
NonNullable<
|
||||||
|
NonNullable<AccessBillingPortalReqBody>['content']
|
||||||
|
>['application/json']
|
||||||
|
>['target_tier']
|
||||||
|
|
||||||
export class FirebaseAuthStoreError extends Error {
|
export class FirebaseAuthStoreError extends Error {
|
||||||
constructor(message: string) {
|
constructor(message: string) {
|
||||||
@@ -409,13 +414,15 @@ export const useFirebaseAuthStore = defineStore('firebaseAuth', () => {
|
|||||||
executeAuthAction((_) => addCredits(requestBodyContent))
|
executeAuthAction((_) => addCredits(requestBodyContent))
|
||||||
|
|
||||||
const accessBillingPortal = async (
|
const accessBillingPortal = async (
|
||||||
requestBody?: AccessBillingPortalReqBody
|
targetTier?: BillingPortalTargetTier
|
||||||
): Promise<AccessBillingPortalResponse> => {
|
): Promise<AccessBillingPortalResponse> => {
|
||||||
const authHeader = await getAuthHeader()
|
const authHeader = await getAuthHeader()
|
||||||
if (!authHeader) {
|
if (!authHeader) {
|
||||||
throw new FirebaseAuthStoreError(t('toastMessages.userNotAuthenticated'))
|
throw new FirebaseAuthStoreError(t('toastMessages.userNotAuthenticated'))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const requestBody = targetTier ? { target_tier: targetTier } : undefined
|
||||||
|
|
||||||
const response = await fetch(buildApiUrl('/customers/billing'), {
|
const response = await fetch(buildApiUrl('/customers/billing'), {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
|
|||||||
@@ -0,0 +1,237 @@
|
|||||||
|
import { createTestingPinia } from '@pinia/testing'
|
||||||
|
import { flushPromises, mount } from '@vue/test-utils'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
import { computed, ref } from 'vue'
|
||||||
|
import { createI18n } from 'vue-i18n'
|
||||||
|
|
||||||
|
import PricingTable from '@/platform/cloud/subscription/components/PricingTable.vue'
|
||||||
|
|
||||||
|
const mockIsActiveSubscription = ref(false)
|
||||||
|
const mockSubscriptionTier = ref<
|
||||||
|
'STANDARD' | 'CREATOR' | 'PRO' | 'FOUNDERS_EDITION' | null
|
||||||
|
>(null)
|
||||||
|
const mockIsYearlySubscription = ref(false)
|
||||||
|
const mockAccessBillingPortal = vi.fn()
|
||||||
|
const mockReportError = vi.fn()
|
||||||
|
const mockGetAuthHeader = vi.fn(() =>
|
||||||
|
Promise.resolve({ Authorization: 'Bearer test-token' })
|
||||||
|
)
|
||||||
|
|
||||||
|
vi.mock('@/platform/cloud/subscription/composables/useSubscription', () => ({
|
||||||
|
useSubscription: () => ({
|
||||||
|
isActiveSubscription: computed(() => mockIsActiveSubscription.value),
|
||||||
|
subscriptionTier: computed(() => mockSubscriptionTier.value),
|
||||||
|
isYearlySubscription: computed(() => mockIsYearlySubscription.value)
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/composables/auth/useFirebaseAuthActions', () => ({
|
||||||
|
useFirebaseAuthActions: () => ({
|
||||||
|
accessBillingPortal: mockAccessBillingPortal,
|
||||||
|
reportError: mockReportError
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/composables/useErrorHandling', () => ({
|
||||||
|
useErrorHandling: () => ({
|
||||||
|
wrapWithErrorHandlingAsync: vi.fn(
|
||||||
|
(fn, errorHandler) =>
|
||||||
|
async (...args: unknown[]) => {
|
||||||
|
try {
|
||||||
|
return await fn(...args)
|
||||||
|
} catch (error) {
|
||||||
|
if (errorHandler) {
|
||||||
|
errorHandler(error)
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/stores/firebaseAuthStore', () => ({
|
||||||
|
useFirebaseAuthStore: () => ({
|
||||||
|
getAuthHeader: mockGetAuthHeader
|
||||||
|
}),
|
||||||
|
FirebaseAuthStoreError: class extends Error {}
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@/platform/distribution/types', () => ({
|
||||||
|
isCloud: true
|
||||||
|
}))
|
||||||
|
|
||||||
|
global.fetch = vi.fn()
|
||||||
|
|
||||||
|
const i18n = createI18n({
|
||||||
|
legacy: false,
|
||||||
|
locale: 'en',
|
||||||
|
messages: {
|
||||||
|
en: {
|
||||||
|
subscription: {
|
||||||
|
yearly: 'Yearly',
|
||||||
|
monthly: 'Monthly',
|
||||||
|
mostPopular: 'Most Popular',
|
||||||
|
usdPerMonth: '/ month',
|
||||||
|
billedYearly: 'Billed yearly ({total})',
|
||||||
|
billedMonthly: 'Billed monthly',
|
||||||
|
currentPlan: 'Current Plan',
|
||||||
|
subscribeTo: 'Subscribe to {plan}',
|
||||||
|
changeTo: 'Change to {plan}',
|
||||||
|
maxDuration: {
|
||||||
|
standard: '30 min',
|
||||||
|
creator: '30 min',
|
||||||
|
pro: '1 hr'
|
||||||
|
},
|
||||||
|
tiers: {
|
||||||
|
standard: { name: 'Standard' },
|
||||||
|
creator: { name: 'Creator' },
|
||||||
|
pro: { name: 'Pro' }
|
||||||
|
},
|
||||||
|
benefits: {
|
||||||
|
monthlyCredits: '{credits} monthly credits',
|
||||||
|
maxDuration: '{duration} max duration',
|
||||||
|
gpu: 'RTX 6000 Pro GPU',
|
||||||
|
addCredits: 'Add more credits anytime',
|
||||||
|
customLoRAs: 'Import custom LoRAs'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
function createWrapper() {
|
||||||
|
return mount(PricingTable, {
|
||||||
|
global: {
|
||||||
|
plugins: [createTestingPinia({ createSpy: vi.fn }), i18n],
|
||||||
|
stubs: {
|
||||||
|
SelectButton: {
|
||||||
|
template: '<div><slot /></div>',
|
||||||
|
props: ['modelValue', 'options'],
|
||||||
|
emits: ['update:modelValue']
|
||||||
|
},
|
||||||
|
Popover: { template: '<div><slot /></div>' },
|
||||||
|
Button: {
|
||||||
|
template:
|
||||||
|
'<button @click="$emit(\'click\')" :disabled="disabled" :data-tier="dataTier">{{ label }}</button>',
|
||||||
|
props: ['loading', 'label', 'severity', 'disabled', 'dataTier', 'pt'],
|
||||||
|
emits: ['click']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('PricingTable', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
mockIsActiveSubscription.value = false
|
||||||
|
mockSubscriptionTier.value = null
|
||||||
|
mockIsYearlySubscription.value = false
|
||||||
|
vi.mocked(global.fetch).mockResolvedValue({
|
||||||
|
ok: true,
|
||||||
|
json: async () => ({ checkout_url: 'https://checkout.stripe.com/test' })
|
||||||
|
} as Response)
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('billing portal deep linking', () => {
|
||||||
|
it('should call accessBillingPortal with yearly tier suffix when billing cycle is yearly (default)', async () => {
|
||||||
|
mockIsActiveSubscription.value = true
|
||||||
|
mockSubscriptionTier.value = 'STANDARD'
|
||||||
|
|
||||||
|
const wrapper = createWrapper()
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
const creatorButton = wrapper
|
||||||
|
.findAll('button')
|
||||||
|
.find((btn) => btn.text().includes('Creator'))
|
||||||
|
|
||||||
|
expect(creatorButton).toBeDefined()
|
||||||
|
await creatorButton?.trigger('click')
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(mockAccessBillingPortal).toHaveBeenCalledWith('creator-yearly')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should call accessBillingPortal with different tiers correctly', async () => {
|
||||||
|
mockIsActiveSubscription.value = true
|
||||||
|
mockSubscriptionTier.value = 'STANDARD'
|
||||||
|
|
||||||
|
const wrapper = createWrapper()
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
const proButton = wrapper
|
||||||
|
.findAll('button')
|
||||||
|
.find((btn) => btn.text().includes('Pro'))
|
||||||
|
|
||||||
|
await proButton?.trigger('click')
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(mockAccessBillingPortal).toHaveBeenCalledWith('pro-yearly')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not call accessBillingPortal when clicking current plan', async () => {
|
||||||
|
mockIsActiveSubscription.value = true
|
||||||
|
mockSubscriptionTier.value = 'CREATOR'
|
||||||
|
|
||||||
|
const wrapper = createWrapper()
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
const currentPlanButton = wrapper
|
||||||
|
.findAll('button')
|
||||||
|
.find((btn) => btn.text().includes('Current Plan'))
|
||||||
|
|
||||||
|
await currentPlanButton?.trigger('click')
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(mockAccessBillingPortal).not.toHaveBeenCalled()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should initiate checkout instead of billing portal for new subscribers', async () => {
|
||||||
|
mockIsActiveSubscription.value = false
|
||||||
|
|
||||||
|
const windowOpenSpy = vi
|
||||||
|
.spyOn(window, 'open')
|
||||||
|
.mockImplementation(() => null)
|
||||||
|
|
||||||
|
const wrapper = createWrapper()
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
const subscribeButton = wrapper
|
||||||
|
.findAll('button')
|
||||||
|
.find((btn) => btn.text().includes('Subscribe'))
|
||||||
|
|
||||||
|
await subscribeButton?.trigger('click')
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(mockAccessBillingPortal).not.toHaveBeenCalled()
|
||||||
|
expect(global.fetch).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining('/customers/cloud-subscription-checkout/'),
|
||||||
|
expect.any(Object)
|
||||||
|
)
|
||||||
|
expect(windowOpenSpy).toHaveBeenCalledWith(
|
||||||
|
'https://checkout.stripe.com/test',
|
||||||
|
'_blank'
|
||||||
|
)
|
||||||
|
|
||||||
|
windowOpenSpy.mockRestore()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should pass correct tier for each subscription level', async () => {
|
||||||
|
mockIsActiveSubscription.value = true
|
||||||
|
mockSubscriptionTier.value = 'PRO'
|
||||||
|
|
||||||
|
const wrapper = createWrapper()
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
const standardButton = wrapper
|
||||||
|
.findAll('button')
|
||||||
|
.find((btn) => btn.text().includes('Standard'))
|
||||||
|
|
||||||
|
await standardButton?.trigger('click')
|
||||||
|
await flushPromises()
|
||||||
|
|
||||||
|
expect(mockAccessBillingPortal).toHaveBeenCalledWith('standard-yearly')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -30,7 +30,9 @@ const mockAddCreditsResponse = {
|
|||||||
|
|
||||||
const mockAccessBillingPortalResponse = {
|
const mockAccessBillingPortalResponse = {
|
||||||
ok: true,
|
ok: true,
|
||||||
statusText: 'OK'
|
statusText: 'OK',
|
||||||
|
json: () =>
|
||||||
|
Promise.resolve({ billing_portal_url: 'https://billing.stripe.com/test' })
|
||||||
}
|
}
|
||||||
|
|
||||||
vi.mock('vuefire', () => ({
|
vi.mock('vuefire', () => ({
|
||||||
@@ -129,7 +131,7 @@ describe('useFirebaseAuthStore', () => {
|
|||||||
if (url.endsWith('/customers/credit')) {
|
if (url.endsWith('/customers/credit')) {
|
||||||
return Promise.resolve(mockAddCreditsResponse)
|
return Promise.resolve(mockAddCreditsResponse)
|
||||||
}
|
}
|
||||||
if (url.endsWith('/customers/billing-portal')) {
|
if (url.endsWith('/customers/billing')) {
|
||||||
return Promise.resolve(mockAccessBillingPortalResponse)
|
return Promise.resolve(mockAccessBillingPortalResponse)
|
||||||
}
|
}
|
||||||
return Promise.reject(new Error('Unexpected API call'))
|
return Promise.reject(new Error('Unexpected API call'))
|
||||||
@@ -542,4 +544,75 @@ describe('useFirebaseAuthStore', () => {
|
|||||||
expect(store.loading).toBe(false)
|
expect(store.loading).toBe(false)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('accessBillingPortal', () => {
|
||||||
|
it('should call billing endpoint without body when no targetTier provided', async () => {
|
||||||
|
const result = await store.accessBillingPortal()
|
||||||
|
|
||||||
|
expect(mockFetch).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining('/customers/billing'),
|
||||||
|
expect.objectContaining({
|
||||||
|
method: 'POST',
|
||||||
|
headers: expect.objectContaining({
|
||||||
|
Authorization: 'Bearer mock-id-token',
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
const callArgs = mockFetch.mock.calls.find((call) =>
|
||||||
|
(call[0] as string).endsWith('/customers/billing')
|
||||||
|
)
|
||||||
|
expect(callArgs?.[1]).not.toHaveProperty('body')
|
||||||
|
expect(result).toEqual({
|
||||||
|
billing_portal_url: 'https://billing.stripe.com/test'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include target_tier in request body when targetTier provided', async () => {
|
||||||
|
await store.accessBillingPortal('creator')
|
||||||
|
|
||||||
|
const callArgs = mockFetch.mock.calls.find((call) =>
|
||||||
|
(call[0] as string).endsWith('/customers/billing')
|
||||||
|
)
|
||||||
|
expect(callArgs?.[1]).toHaveProperty('body')
|
||||||
|
expect(JSON.parse(callArgs?.[1]?.body as string)).toEqual({
|
||||||
|
target_tier: 'creator'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle different checkout tier formats', async () => {
|
||||||
|
const tiers = [
|
||||||
|
'standard',
|
||||||
|
'creator',
|
||||||
|
'pro',
|
||||||
|
'standard-yearly',
|
||||||
|
'creator-yearly',
|
||||||
|
'pro-yearly'
|
||||||
|
] as const
|
||||||
|
|
||||||
|
for (const tier of tiers) {
|
||||||
|
mockFetch.mockClear()
|
||||||
|
await store.accessBillingPortal(tier)
|
||||||
|
|
||||||
|
const callArgs = mockFetch.mock.calls.find((call) =>
|
||||||
|
(call[0] as string).endsWith('/customers/billing')
|
||||||
|
)
|
||||||
|
expect(JSON.parse(callArgs?.[1]?.body as string)).toEqual({
|
||||||
|
target_tier: tier
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should throw error when API returns error response', async () => {
|
||||||
|
mockFetch.mockImplementationOnce(() =>
|
||||||
|
Promise.resolve({
|
||||||
|
ok: false,
|
||||||
|
json: () => Promise.resolve({ message: 'Billing portal unavailable' })
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
await expect(store.accessBillingPortal()).rejects.toThrow()
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user