From 3069c24f81b484c4e1033e518dce1e293c03f9ba Mon Sep 17 00:00:00 2001 From: Yourz Date: Thu, 15 Jan 2026 08:57:51 +0800 Subject: [PATCH] feat: handling subscription tier button link parameter (#7553) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Discussion here: https://comfy-organization.slack.com/archives/C0A0XANFJRE/p1764899027465379 Implement: Subscription tier query parameter for direct checkout flow Example button link: `/cloud/subscribe?tier=standard` `tier` could be `standard`, `creator` or `pro` `cycle` could be `monthly` or `yearly`. it is optional, and `monthly` by default. ## Changes - **What**: - Add a landing page called `CloudSubscriptionRedirectView.vue` to handling the subscription tier button link parameter - Extract subscription handling logic from `PriceTable.vue` - **Breaking**: - Code change touched `PriceTable.vue` - **Dependencies**: ## Review Focus - link will redirect to login url, when cloud app not login - after login, the cloud app will redirect to CloudSubscriptionRedirect page - wait for several seconds, the cloud app will be redirected to checkout page ## Screenshots (if applicable) ![Kapture 2025-12-16 at 18 43 28](https://github.com/user-attachments/assets/affbc18f-d45c-4953-b06a-fc797eba6804) ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-7553-feat-handling-subscription-tier-button-link-parameter-2cb6d73d365081ee9580e89090248300) by [Unito](https://www.unito.io) --------- Co-authored-by: GitHub Action --- lint-staged.config.mjs | 23 +++ .../auth/useFirebaseAuthActions.ts | 10 +- src/locales/en/main.json | 1 + .../cloud/onboarding/CloudLoginView.vue | 14 +- .../cloud/onboarding/CloudSignupView.vue | 16 +- .../CloudSubscriptionRedirectView.test.ts | 162 ++++++++++++++++++ .../CloudSubscriptionRedirectView.vue | 130 ++++++++++++++ .../cloud/onboarding/UserCheckView.vue | 2 +- .../cloud/onboarding/onboardingCloudRoutes.ts | 7 + .../onboarding/utils/previousFullPath.test.ts | 38 ++++ .../onboarding/utils/previousFullPath.ts | 27 +++ .../subscription/components/PricingTable.vue | 66 +------ .../composables/useSubscription.ts | 9 +- .../utils/subscriptionCheckoutUtil.ts | 87 ++++++++++ src/router.ts | 15 +- 15 files changed, 536 insertions(+), 71 deletions(-) create mode 100644 lint-staged.config.mjs create mode 100644 src/platform/cloud/onboarding/CloudSubscriptionRedirectView.test.ts create mode 100644 src/platform/cloud/onboarding/CloudSubscriptionRedirectView.vue create mode 100644 src/platform/cloud/onboarding/utils/previousFullPath.test.ts create mode 100644 src/platform/cloud/onboarding/utils/previousFullPath.ts create mode 100644 src/platform/cloud/subscription/utils/subscriptionCheckoutUtil.ts diff --git a/lint-staged.config.mjs b/lint-staged.config.mjs new file mode 100644 index 000000000..97d22c529 --- /dev/null +++ b/lint-staged.config.mjs @@ -0,0 +1,23 @@ +import path from 'node:path' + +export default { + './**/*.js': (stagedFiles) => formatAndEslint(stagedFiles), + + './**/*.{ts,tsx,vue,mts}': (stagedFiles) => [ + ...formatAndEslint(stagedFiles), + 'pnpm typecheck' + ] +} + +function formatAndEslint(fileNames) { + // Convert absolute paths to relative paths for better ESLint resolution + const relativePaths = fileNames.map((f) => path.relative(process.cwd(), f)) + const joinedPaths = relativePaths.map((p) => `"${p}"`).join(' ') + return [ + `pnpm exec prettier --cache --write ${joinedPaths}`, + `pnpm exec oxlint --fix ${joinedPaths}`, + `pnpm exec eslint --cache --fix --no-warn-ignored ${joinedPaths}` + ] +} + + diff --git a/src/composables/auth/useFirebaseAuthActions.ts b/src/composables/auth/useFirebaseAuthActions.ts index da6744c2d..319dfa851 100644 --- a/src/composables/auth/useFirebaseAuthActions.ts +++ b/src/composables/auth/useFirebaseAuthActions.ts @@ -104,9 +104,9 @@ export const useFirebaseAuthActions = () => { }, reportError) const accessBillingPortal = wrapWithErrorHandlingAsync< - [targetTier?: BillingPortalTargetTier], + [targetTier?: BillingPortalTargetTier, openInNewTab?: boolean], void - >(async (targetTier) => { + >(async (targetTier, openInNewTab = true) => { const response = await authStore.accessBillingPortal(targetTier) if (!response.billing_portal_url) { throw new Error( @@ -115,7 +115,11 @@ export const useFirebaseAuthActions = () => { }) ) } - window.open(response.billing_portal_url, '_blank') + if (openInNewTab) { + window.open(response.billing_portal_url, '_blank') + } else { + globalThis.location.href = response.billing_portal_url + } }, reportError) const fetchBalance = wrapWithErrorHandlingAsync(async () => { diff --git a/src/locales/en/main.json b/src/locales/en/main.json index 555ff07fb..75c4bb3fb 100644 --- a/src/locales/en/main.json +++ b/src/locales/en/main.json @@ -2166,6 +2166,7 @@ "renderErrorState": "Render Error State" }, "cloudOnboarding": { + "skipToCloudApp": "Skip to the cloud app", "survey": { "title": "Cloud Survey", "placeholder": "Survey questions placeholder", diff --git a/src/platform/cloud/onboarding/CloudLoginView.vue b/src/platform/cloud/onboarding/CloudLoginView.vue index b7d1cd131..a4fea37a4 100644 --- a/src/platform/cloud/onboarding/CloudLoginView.vue +++ b/src/platform/cloud/onboarding/CloudLoginView.vue @@ -84,6 +84,7 @@ import { useRoute, useRouter } from 'vue-router' import Button from '@/components/ui/button/Button.vue' import { useFirebaseAuthActions } from '@/composables/auth/useFirebaseAuthActions' import CloudSignInForm from '@/platform/cloud/onboarding/components/CloudSignInForm.vue' +import { getSafePreviousFullPath } from '@/platform/cloud/onboarding/utils/previousFullPath' import { useToastStore } from '@/platform/updates/common/toastStore' import type { SignInData } from '@/schemas/signInSchema' @@ -91,12 +92,12 @@ const { t } = useI18n() const router = useRouter() const route = useRoute() const authActions = useFirebaseAuthActions() -const isSecureContext = window.isSecureContext +const isSecureContext = globalThis.isSecureContext const authError = ref('') const toastStore = useToastStore() -const navigateToSignup = () => { - void router.push({ name: 'cloud-signup', query: route.query }) +const navigateToSignup = async () => { + await router.push({ name: 'cloud-signup', query: route.query }) } const onSuccess = async () => { @@ -105,6 +106,13 @@ const onSuccess = async () => { summary: 'Login Completed', life: 2000 }) + + const previousFullPath = getSafePreviousFullPath(route.query) + if (previousFullPath) { + await router.replace(previousFullPath) + return + } + await router.push({ name: 'cloud-user-check' }) } diff --git a/src/platform/cloud/onboarding/CloudSignupView.vue b/src/platform/cloud/onboarding/CloudSignupView.vue index 5a8bf5c3e..756851ec7 100644 --- a/src/platform/cloud/onboarding/CloudSignupView.vue +++ b/src/platform/cloud/onboarding/CloudSignupView.vue @@ -100,6 +100,7 @@ import { useRoute, useRouter } from 'vue-router' import SignUpForm from '@/components/dialog/content/signin/SignUpForm.vue' import Button from '@/components/ui/button/Button.vue' import { useFirebaseAuthActions } from '@/composables/auth/useFirebaseAuthActions' +import { getSafePreviousFullPath } from '@/platform/cloud/onboarding/utils/previousFullPath' import { isCloud } from '@/platform/distribution/types' import { useTelemetry } from '@/platform/telemetry' import { useToastStore } from '@/platform/updates/common/toastStore' @@ -110,13 +111,13 @@ const { t } = useI18n() const router = useRouter() const route = useRoute() const authActions = useFirebaseAuthActions() -const isSecureContext = window.isSecureContext +const isSecureContext = globalThis.isSecureContext const authError = ref('') const userIsInChina = ref(false) const toastStore = useToastStore() -const navigateToLogin = () => { - void router.push({ name: 'cloud-login', query: route.query }) +const navigateToLogin = async () => { + await router.push({ name: 'cloud-login', query: route.query }) } const onSuccess = async () => { @@ -125,7 +126,14 @@ const onSuccess = async () => { summary: 'Sign up Completed', life: 2000 }) - // Direct redirect to main app - email verification removed + + const previousFullPath = getSafePreviousFullPath(route.query) + if (previousFullPath) { + await router.replace(previousFullPath) + return + } + + // Default redirect to the normal onboarding flow await router.push({ path: '/', query: route.query }) } diff --git a/src/platform/cloud/onboarding/CloudSubscriptionRedirectView.test.ts b/src/platform/cloud/onboarding/CloudSubscriptionRedirectView.test.ts new file mode 100644 index 000000000..68fec345d --- /dev/null +++ b/src/platform/cloud/onboarding/CloudSubscriptionRedirectView.test.ts @@ -0,0 +1,162 @@ +import { mount } from '@vue/test-utils' +import { beforeEach, describe, expect, test, vi } from 'vitest' +import { createI18n } from 'vue-i18n' + +import CloudSubscriptionRedirectView from './CloudSubscriptionRedirectView.vue' + +const flushPromises = () => new Promise((resolve) => setTimeout(resolve, 0)) + +// Router mocks +let mockQuery: Record = {} +const mockRouterPush = vi.fn() + +vi.mock('vue-router', () => ({ + useRoute: () => ({ + query: mockQuery + }), + useRouter: () => ({ + push: mockRouterPush + }) +})) + +// Firebase / subscription mocks +const authActionMocks = vi.hoisted(() => ({ + reportError: vi.fn(), + accessBillingPortal: vi.fn() +})) + +vi.mock('@/composables/auth/useFirebaseAuthActions', () => ({ + useFirebaseAuthActions: () => authActionMocks +})) + +vi.mock('@/composables/useErrorHandling', () => ({ + useErrorHandling: () => ({ + wrapWithErrorHandlingAsync: + unknown>(fn: T) => + (...args: Parameters) => + fn(...args) + }) +})) + +const subscriptionMocks = vi.hoisted(() => ({ + isActiveSubscription: { value: false }, + isInitialized: { value: true } +})) + +vi.mock('@/platform/cloud/subscription/composables/useSubscription', () => ({ + useSubscription: () => subscriptionMocks +})) + +// Avoid real network / isCloud behavior +const mockPerformSubscriptionCheckout = vi.fn() +vi.mock('@/platform/cloud/subscription/utils/subscriptionCheckoutUtil', () => ({ + performSubscriptionCheckout: (...args: unknown[]) => + mockPerformSubscriptionCheckout(...args) +})) + +const createI18nInstance = () => + createI18n({ + legacy: false, + locale: 'en', + messages: { + en: { + cloudOnboarding: { + skipToCloudApp: 'Skip to the cloud app' + }, + g: { + comfyOrgLogoAlt: 'Comfy org logo' + }, + subscription: { + subscribeTo: 'Subscribe to {plan}', + tiers: { + standard: { name: 'Standard' }, + creator: { name: 'Creator' }, + pro: { name: 'Pro' } + } + } + } + } + }) + +const mountView = async (query: Record) => { + mockQuery = query + + const wrapper = mount(CloudSubscriptionRedirectView, { + global: { + plugins: [createI18nInstance()] + } + }) + + await flushPromises() + + return { wrapper } +} + +describe('CloudSubscriptionRedirectView', () => { + beforeEach(() => { + vi.clearAllMocks() + mockQuery = {} + subscriptionMocks.isActiveSubscription.value = false + subscriptionMocks.isInitialized.value = true + }) + + test('redirects to home when subscriptionType is missing', async () => { + await mountView({}) + + expect(mockRouterPush).toHaveBeenCalledWith('/') + }) + + test('redirects to home when subscriptionType is invalid', async () => { + await mountView({ tier: 'invalid' }) + + expect(mockRouterPush).toHaveBeenCalledWith('/') + }) + + test('shows subscription copy when subscriptionType is valid', async () => { + const { wrapper } = await mountView({ tier: 'creator' }) + + // Should not redirect to home + expect(mockRouterPush).not.toHaveBeenCalledWith('/') + + // Shows copy under logo + expect(wrapper.text()).toContain('Subscribe to Creator') + + // Triggers checkout flow + expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith( + 'creator', + 'monthly', + false + ) + + // Shows loading affordances + expect(wrapper.findComponent({ name: 'ProgressSpinner' }).exists()).toBe( + true + ) + const skipLink = wrapper.get('a[href="/"]') + expect(skipLink.text()).toContain('Skip to the cloud app') + }) + + test('opens billing portal when subscription is already active', async () => { + subscriptionMocks.isActiveSubscription.value = true + + await mountView({ tier: 'creator' }) + + expect(mockRouterPush).not.toHaveBeenCalledWith('/') + expect(authActionMocks.accessBillingPortal).toHaveBeenCalledTimes(1) + expect(mockPerformSubscriptionCheckout).not.toHaveBeenCalled() + }) + + test('uses first value when subscriptionType is an array', async () => { + const { wrapper } = await mountView({ + tier: ['creator', 'pro'] + }) + + expect(mockRouterPush).not.toHaveBeenCalledWith('/') + expect(wrapper.text()).toContain('Subscribe to Creator') + expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith( + 'creator', + 'monthly', + false + ) + }) +}) diff --git a/src/platform/cloud/onboarding/CloudSubscriptionRedirectView.vue b/src/platform/cloud/onboarding/CloudSubscriptionRedirectView.vue new file mode 100644 index 000000000..d4da851ad --- /dev/null +++ b/src/platform/cloud/onboarding/CloudSubscriptionRedirectView.vue @@ -0,0 +1,130 @@ + + + diff --git a/src/platform/cloud/onboarding/UserCheckView.vue b/src/platform/cloud/onboarding/UserCheckView.vue index e2186ff48..e18596f29 100644 --- a/src/platform/cloud/onboarding/UserCheckView.vue +++ b/src/platform/cloud/onboarding/UserCheckView.vue @@ -78,7 +78,7 @@ const { } // User is fully onboarded (active or whitelist check disabled) - window.location.href = '/' + globalThis.location.href = '/' }), null, { resetOnExecute: false } diff --git a/src/platform/cloud/onboarding/onboardingCloudRoutes.ts b/src/platform/cloud/onboarding/onboardingCloudRoutes.ts index 1a613c02e..52ffc7943 100644 --- a/src/platform/cloud/onboarding/onboardingCloudRoutes.ts +++ b/src/platform/cloud/onboarding/onboardingCloudRoutes.ts @@ -65,6 +65,13 @@ export const cloudOnboardingRoutes: RouteRecordRaw[] = [ component: () => import('@/platform/cloud/onboarding/CloudAuthTimeoutView.vue'), props: true + }, + { + path: 'subscribe', + name: 'cloud-subscribe', + component: () => + import('@/platform/cloud/onboarding/CloudSubscriptionRedirectView.vue'), + meta: { requiresAuth: true } } ] } diff --git a/src/platform/cloud/onboarding/utils/previousFullPath.test.ts b/src/platform/cloud/onboarding/utils/previousFullPath.test.ts new file mode 100644 index 000000000..d74c1e609 --- /dev/null +++ b/src/platform/cloud/onboarding/utils/previousFullPath.test.ts @@ -0,0 +1,38 @@ +import { describe, expect, test } from 'vitest' +import type { LocationQuery } from 'vue-router' + +import { getSafePreviousFullPath } from './previousFullPath' + +describe('getSafePreviousFullPath', () => { + test('returns null when missing', () => { + expect(getSafePreviousFullPath({})).toBeNull() + }) + + test('decodes and returns internal relative paths', () => { + const query: LocationQuery = { + previousFullPath: encodeURIComponent('/some/path?x=1') + } + expect(getSafePreviousFullPath(query)).toBe('/some/path?x=1') + }) + + test('rejects protocol-relative urls', () => { + const query: LocationQuery = { + previousFullPath: encodeURIComponent('//evil.com') + } + expect(getSafePreviousFullPath(query)).toBeNull() + }) + + test('rejects absolute external urls', () => { + const query: LocationQuery = { + previousFullPath: encodeURIComponent('https://evil.com/path') + } + expect(getSafePreviousFullPath(query)).toBeNull() + }) + + test('rejects malformed encodings', () => { + const query: LocationQuery = { + previousFullPath: '%E0%A4%A' + } + expect(getSafePreviousFullPath(query)).toBeNull() + }) +}) diff --git a/src/platform/cloud/onboarding/utils/previousFullPath.ts b/src/platform/cloud/onboarding/utils/previousFullPath.ts new file mode 100644 index 000000000..9f6b257e0 --- /dev/null +++ b/src/platform/cloud/onboarding/utils/previousFullPath.ts @@ -0,0 +1,27 @@ +import type { LocationQuery } from 'vue-router' + +const decodeQueryParam = (value: string): string | null => { + try { + return decodeURIComponent(value) + } catch { + return null + } +} + +const isSafeInternalRedirectPath = (path: string): boolean => { + // Must be a relative in-app path. Disallow protocol-relative URLs ("//evil.com"). + return path.startsWith('/') && !path.startsWith('//') +} + +export const getSafePreviousFullPath = ( + query: LocationQuery +): string | null => { + const raw = query.previousFullPath + const value = Array.isArray(raw) ? raw[0] : raw + if (!value) return null + + const decoded = decodeQueryParam(value) + if (!decoded) return null + + return isSafeInternalRedirectPath(decoded) ? decoded : null +} diff --git a/src/platform/cloud/subscription/components/PricingTable.vue b/src/platform/cloud/subscription/components/PricingTable.vue index 7f1a66834..1520a94bf 100644 --- a/src/platform/cloud/subscription/components/PricingTable.vue +++ b/src/platform/cloud/subscription/components/PricingTable.vue @@ -252,7 +252,6 @@ import { useI18n } from 'vue-i18n' import Button from '@/components/ui/button/Button.vue' import { useFirebaseAuthActions } from '@/composables/auth/useFirebaseAuthActions' import { useErrorHandling } from '@/composables/useErrorHandling' -import { getComfyApiBaseUrl } from '@/config/comfyApi' import { t } from '@/i18n' import { useSubscription } from '@/platform/cloud/subscription/composables/useSubscription' import { @@ -263,13 +262,10 @@ import type { TierKey, TierPricing } from '@/platform/cloud/subscription/constants/tierPricing' +import { performSubscriptionCheckout } from '@/platform/cloud/subscription/utils/subscriptionCheckoutUtil' import { isPlanDowngrade } from '@/platform/cloud/subscription/utils/subscriptionTierRank' import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank' import { isCloud } from '@/platform/distribution/types' -import { - FirebaseAuthStoreError, - useFirebaseAuthStore -} from '@/stores/firebaseAuthStore' import type { components } from '@/types/comfyRegistryTypes' type SubscriptionTier = components['schemas']['SubscriptionTier'] @@ -332,7 +328,6 @@ const tiers: PricingTierConfig[] = [ ] const { n } = useI18n() -const { getAuthHeader } = useFirebaseAuthStore() const { isActiveSubscription, subscriptionTier, isYearlySubscription } = useSubscription() const { accessBillingPortal, reportError } = useFirebaseAuthActions() @@ -384,12 +379,13 @@ const getButtonLabel = (tier: PricingTierConfig): string => { : t('subscription.subscribeTo', { plan: planName }) } -const getButtonSeverity = (tier: PricingTierConfig): 'primary' | 'secondary' => - isCurrentPlan(tier.key) - ? 'secondary' - : tier.key === 'creator' - ? 'primary' - : 'secondary' +const getButtonSeverity = ( + tier: PricingTierConfig +): 'primary' | 'secondary' => { + if (isCurrentPlan(tier.key)) return 'secondary' + if (tier.key === 'creator') return 'primary' + return 'secondary' +} const getButtonTextClass = (tier: PricingTierConfig): string => tier.key === 'creator' @@ -405,47 +401,6 @@ const getAnnualTotal = (tier: PricingTierConfig): number => const getCreditsDisplay = (tier: PricingTierConfig): number => tier.pricing.credits * (currentBillingCycle.value === 'yearly' ? 12 : 1) -const initiateCheckout = async (tierKey: CheckoutTierKey) => { - const authHeader = await getAuthHeader() - if (!authHeader) { - throw new FirebaseAuthStoreError(t('toastMessages.userNotAuthenticated')) - } - - const checkoutTier = getCheckoutTier(tierKey, currentBillingCycle.value) - const response = await fetch( - `${getComfyApiBaseUrl()}/customers/cloud-subscription-checkout/${checkoutTier}`, - { - method: 'POST', - headers: { ...authHeader, 'Content-Type': 'application/json' } - } - ) - - if (!response.ok) { - let errorMessage = 'Failed to initiate checkout' - try { - const errorData = await response.json() - errorMessage = errorData.message || errorMessage - } catch { - // If JSON parsing fails, try to get text response or use HTTP status - try { - const errorText = await response.text() - errorMessage = - errorText || `HTTP ${response.status} ${response.statusText}` - } catch { - errorMessage = `HTTP ${response.status} ${response.statusText}` - } - } - - throw new FirebaseAuthStoreError( - t('toastMessages.failedToInitiateSubscription', { - error: errorMessage - }) - ) - } - - return await response.json() -} - const handleSubscribe = wrapWithErrorHandlingAsync( async (tierKey: CheckoutTierKey) => { if (!isCloud || isLoading.value || isCurrentPlan(tierKey)) return @@ -475,10 +430,7 @@ const handleSubscribe = wrapWithErrorHandlingAsync( await accessBillingPortal(checkoutTier) } } else { - const response = await initiateCheckout(tierKey) - if (response.checkout_url) { - window.open(response.checkout_url, '_blank') - } + await performSubscriptionCheckout(tierKey, currentBillingCycle.value) } } finally { isLoading.value = false diff --git a/src/platform/cloud/subscription/composables/useSubscription.ts b/src/platform/cloud/subscription/composables/useSubscription.ts index bb5a3b6c0..e80c3642e 100644 --- a/src/platform/cloud/subscription/composables/useSubscription.ts +++ b/src/platform/cloud/subscription/composables/useSubscription.ts @@ -28,6 +28,7 @@ export type CloudSubscriptionStatusResponse = NonNullable< function useSubscriptionInternal() { const subscriptionStatus = ref(null) const telemetry = useTelemetry() + const isInitialized = ref(false) const isSubscribedOrIsNotCloud = computed(() => { if (!isCloud || !window.__CONFIG__?.subscription_required) return true @@ -200,10 +201,15 @@ function useSubscriptionInternal() { () => isLoggedIn.value, async (loggedIn) => { if (loggedIn) { - await fetchSubscriptionStatus() + try { + await fetchSubscriptionStatus() + } finally { + isInitialized.value = true + } } else { subscriptionStatus.value = null stopCancellationWatcher() + isInitialized.value = true } }, { immediate: true } @@ -244,6 +250,7 @@ function useSubscriptionInternal() { return { // State isActiveSubscription: isSubscribedOrIsNotCloud, + isInitialized, isCancelled, formattedRenewalDate, formattedEndDate, diff --git a/src/platform/cloud/subscription/utils/subscriptionCheckoutUtil.ts b/src/platform/cloud/subscription/utils/subscriptionCheckoutUtil.ts new file mode 100644 index 000000000..bc77e68e9 --- /dev/null +++ b/src/platform/cloud/subscription/utils/subscriptionCheckoutUtil.ts @@ -0,0 +1,87 @@ +import { getComfyApiBaseUrl } from '@/config/comfyApi' +import { t } from '@/i18n' +import { isCloud } from '@/platform/distribution/types' +import { + FirebaseAuthStoreError, + useFirebaseAuthStore +} from '@/stores/firebaseAuthStore' +import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing' +import type { BillingCycle } from './subscriptionTierRank' + +type CheckoutTier = TierKey | `${TierKey}-yearly` + +const getCheckoutTier = ( + tierKey: TierKey, + billingCycle: BillingCycle +): CheckoutTier => (billingCycle === 'yearly' ? `${tierKey}-yearly` : tierKey) + +/** + * Core subscription checkout logic shared between PricingTable and + * SubscriptionRedirectView. Handles: + * - Ensuring the user is authenticated + * - Calling the backend checkout endpoint + * - Normalizing error responses + * - Opening the checkout URL in a new tab when available + * + * Callers are responsible for: + * - Guarding on cloud-only behavior (isCloud) + * - Managing loading state + * - Wrapping with error handling (e.g. useErrorHandling) + */ +export async function performSubscriptionCheckout( + tierKey: TierKey, + currentBillingCycle: BillingCycle, + openInNewTab: boolean = true +): Promise { + if (!isCloud) return + + const { getAuthHeader } = useFirebaseAuthStore() + const authHeader = await getAuthHeader() + + if (!authHeader) { + throw new FirebaseAuthStoreError(t('toastMessages.userNotAuthenticated')) + } + + const checkoutTier = getCheckoutTier(tierKey, currentBillingCycle) + + const response = await fetch( + `${getComfyApiBaseUrl()}/customers/cloud-subscription-checkout/${checkoutTier}`, + { + method: 'POST', + headers: { ...authHeader, 'Content-Type': 'application/json' } + } + ) + + if (!response.ok) { + let errorMessage = 'Failed to initiate checkout' + try { + const errorData = await response.json() + errorMessage = errorData.message || errorMessage + } catch { + // If JSON parsing fails, try to get text response or use HTTP status + try { + const errorText = await response.text() + errorMessage = + errorText || `HTTP ${response.status} ${response.statusText}` + } catch { + errorMessage = `HTTP ${response.status} ${response.statusText}` + } + } + + throw new FirebaseAuthStoreError( + t('toastMessages.failedToInitiateSubscription', { + error: errorMessage + }) + ) + } + + const data = await response.json() + + if (data.checkout_url) { + if (openInNewTab) { + window.open(data.checkout_url, '_blank') + } else { + globalThis.location.href = data.checkout_url + } + } +} diff --git a/src/router.ts b/src/router.ts index 5ca66c953..102ec8f2a 100644 --- a/src/router.ts +++ b/src/router.ts @@ -149,9 +149,17 @@ if (isCloud) { return next() } + const query = + to.fullPath === '/' + ? undefined + : { previousFullPath: encodeURIComponent(to.fullPath) } + // Check if route requires authentication if (to.meta.requiresAuth && !isLoggedIn) { - return next({ name: 'cloud-login' }) + return next({ + name: 'cloud-login', + query + }) } // Handle other protected routes @@ -164,7 +172,10 @@ if (isCloud) { } // For web, redirect to login - return next({ name: 'cloud-login' }) + return next({ + name: 'cloud-login', + query + }) } // User is logged in - check if they need onboarding (when enabled)