diff --git a/browser_tests/fixtures/selectors.ts b/browser_tests/fixtures/selectors.ts index 82eec114ed..08e7d76195 100644 --- a/browser_tests/fixtures/selectors.ts +++ b/browser_tests/fixtures/selectors.ts @@ -137,7 +137,8 @@ export const TestIds = { colorPickerCurrentColor: 'color-picker-current-color', colorBlue: 'blue', colorRed: 'red', - convertSubgraph: 'convert-to-subgraph-button' + convertSubgraph: 'convert-to-subgraph-button', + bypass: 'bypass-button' }, menu: { moreMenuContent: 'more-menu-content' diff --git a/browser_tests/tests/selectionToolbox.spec.ts b/browser_tests/tests/selectionToolbox.spec.ts index 03ac01ac1e..4ac3376762 100644 --- a/browser_tests/tests/selectionToolbox.spec.ts +++ b/browser_tests/tests/selectionToolbox.spec.ts @@ -129,23 +129,18 @@ test.describe('Selection Toolbox', { tag: ['@screenshot', '@ui'] }, () => { }) => { // A group + a KSampler node await comfyPage.workflow.loadWorkflow('groups/single_group') + const bypass = comfyPage.page.getByTestId(TestIds.selectionToolbox.bypass) // Select group + node should show bypass button await comfyPage.canvas.focus() - await comfyPage.page.keyboard.press('Control+A') - await expect( - comfyPage.page.locator( - '.selection-toolbox *[data-testid="bypass-button"]' - ) - ).toBeVisible() - - // Deselect node (Only group is selected) should hide bypass button await comfyPage.nodeOps.selectNodes(['KSampler']) - await expect( - comfyPage.page.locator( - '.selection-toolbox *[data-testid="bypass-button"]' - ) - ).toBeHidden() + await expect(bypass).toBeVisible() + await comfyPage.keyboard.delete() + + // (Only empty group is selected) should hide bypass button + await comfyPage.keyboard.selectAll() + await expect(comfyPage.selectionToolbox).toBeVisible() + await expect(bypass).toBeHidden() }) test.describe('Color Picker', () => { diff --git a/browser_tests/tests/vueNodes/groups/groups.spec.ts b/browser_tests/tests/vueNodes/groups/groups.spec.ts index e5b8d586e1..9640678673 100644 --- a/browser_tests/tests/vueNodes/groups/groups.spec.ts +++ b/browser_tests/tests/vueNodes/groups/groups.spec.ts @@ -3,6 +3,8 @@ import { comfyPageFixture as test } from '@e2e/fixtures/ComfyPage' import type { ComfyPage } from '@e2e/fixtures/ComfyPage' +import { TestIds } from '@e2e/fixtures/selectors' +import { getGroupTitlePosition } from '@e2e/fixtures/utils/groupHelpers' const CREATE_GROUP_HOTKEY = 'Control+g' @@ -217,4 +219,40 @@ test.describe('Vue Node Groups', { tag: ['@screenshot', '@vue-nodes'] }, () => { ) }).toPass({ timeout: 5000 }) }) + + test('Bypassing a group bypasses contents', async ({ comfyPage }) => { + await comfyPage.settings.setSetting('Comfy.Canvas.SelectionToolbox', true) + await comfyPage.keyboard.selectAll() + await comfyPage.page.keyboard.press('.') + await comfyPage.page.keyboard.press(CREATE_GROUP_HOTKEY) + + const toggleBypass = () => + comfyPage.page.getByTestId(TestIds.selectionToolbox.bypass).click() + const bypassCount = () => + comfyPage.page.evaluate( + () => graph!.nodes.filter((node) => node.mode === 4).length + ) + expect(await bypassCount()).toBe(0) + const groupCount = () => comfyPage.page.evaluate(() => graph!.groups.length) + await expect.poll(groupCount, 'create group').toBe(1) + + const ksampler = await comfyPage.vueNodes.getFixtureByTitle('KSampler') + await ksampler.select() + await toggleBypass() + await expect.poll(bypassCount, 'setup bypass of single node').toBe(1) + + const groupPos = await getGroupTitlePosition(comfyPage, 'Group') + await comfyPage.page.mouse.click(groupPos.x, groupPos.y) + await toggleBypass() + await expect.poll(bypassCount, 'all nodes are set to bypassed').toBe(7) + await toggleBypass() + await expect.poll(bypassCount, 'all nodes are unbypassed').toBe(0) + + await comfyPage.page.keyboard.down('Shift') + await ksampler.select() + await comfyPage.page.keyboard.up('Shift') + + await toggleBypass() + await expect.poll(bypassCount, "won't toggle double selected node").toBe(7) + }) }) diff --git a/src/components/graph/SelectionToolbox.vue b/src/components/graph/SelectionToolbox.vue index ec807268bd..eb0299606c 100644 --- a/src/components/graph/SelectionToolbox.vue +++ b/src/components/graph/SelectionToolbox.vue @@ -101,6 +101,7 @@ const extensionToolboxCommands = computed(() => { const { hasAnySelection, + hasGroupedNodesSelection, hasMultipleSelection, isSingleNode, isSingleSubgraph, @@ -118,7 +119,10 @@ const showSubgraphButtons = computed(() => isSingleSubgraph.value) const showBypass = computed( () => - isSingleNode.value || isSingleSubgraph.value || hasMultipleSelection.value + isSingleNode.value || + isSingleSubgraph.value || + hasMultipleSelection.value || + hasGroupedNodesSelection.value ) const showLoad3DViewer = computed(() => hasAny3DNodeSelected.value) const showMaskEditor = computed(() => isSingleImageNode.value) diff --git a/src/composables/billing/types.ts b/src/composables/billing/types.ts index 1c2f9477b6..6bc9f180e3 100644 --- a/src/composables/billing/types.ts +++ b/src/composables/billing/types.ts @@ -2,6 +2,9 @@ import type { ComputedRef, Ref } from 'vue' import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing' import type { + BillingStatus, + BillingSubscriptionStatus, + CreateTopupResponse, Plan, PreviewSubscribeResponse, SubscribeResponse, @@ -16,7 +19,9 @@ export interface SubscriptionInfo { tier: SubscriptionTier | null duration: SubscriptionDuration | null planSlug: string | null + /** ISO 8601 */ renewalDate: string | null + /** ISO 8601 */ endDate: string | null isCancelled: boolean hasFunds: boolean @@ -44,6 +49,9 @@ export interface BillingActions { ) => Promise manageSubscription: () => Promise cancelSubscription: () => Promise + resubscribe: () => Promise + /** `amountCents` must be a whole-dollar multiple of 100. */ + topup: (amountCents: number) => Promise fetchPlans: () => Promise /** * Ensures billing is initialized and subscription is active. @@ -65,16 +73,12 @@ export interface BillingState { currentPlanSlug: ComputedRef isLoading: Ref error: Ref - /** - * Convenience computed for checking if subscription is active. - * Equivalent to `subscription.value?.isActive ?? false` - */ isActiveSubscription: ComputedRef - /** - * Whether the current billing context has a FREE tier subscription. - * Workspace-aware: reflects the active workspace's tier, not the user's personal tier. - */ isFreeTier: ComputedRef + billingStatus: ComputedRef + subscriptionStatus: ComputedRef + tier: ComputedRef + renewalDate: ComputedRef } export interface BillingContext extends BillingState, BillingActions { diff --git a/src/composables/billing/useBillingContext.test.ts b/src/composables/billing/useBillingContext.test.ts index 7ab80bb2a5..7a63702ee3 100644 --- a/src/composables/billing/useBillingContext.test.ts +++ b/src/composables/billing/useBillingContext.test.ts @@ -5,13 +5,17 @@ import type { Plan } from '@/platform/workspace/api/workspaceApi' import { useBillingContext } from './useBillingContext' -const { mockTeamWorkspacesEnabled, mockIsPersonal, mockPlans } = vi.hoisted( - () => ({ - mockTeamWorkspacesEnabled: { value: false }, - mockIsPersonal: { value: true }, - mockPlans: { value: [] as Plan[] } - }) -) +const { + mockTeamWorkspacesEnabled, + mockIsPersonal, + mockPlans, + mockPurchaseCredits +} = vi.hoisted(() => ({ + mockTeamWorkspacesEnabled: { value: false }, + mockIsPersonal: { value: true }, + mockPlans: { value: [] as Plan[] }, + mockPurchaseCredits: vi.fn() +})) vi.mock('@vueuse/core', async (importOriginal) => { const original = await importOriginal() @@ -50,8 +54,9 @@ vi.mock('@/platform/cloud/subscription/composables/useSubscription', () => ({ isActiveSubscription: { value: true }, subscriptionTier: { value: 'PRO' }, subscriptionDuration: { value: 'MONTHLY' }, - formattedRenewalDate: { value: 'Jan 1, 2025' }, - formattedEndDate: { value: '' }, + subscriptionStatus: { + value: { renewal_date: '2025-01-01T00:00:00Z', end_date: null } + }, isCancelled: { value: false }, fetchStatus: vi.fn().mockResolvedValue(undefined), manageSubscription: vi.fn().mockResolvedValue(undefined), @@ -70,6 +75,12 @@ vi.mock( }) ) +vi.mock('@/composables/auth/useAuthActions', () => ({ + useAuthActions: () => ({ + purchaseCredits: mockPurchaseCredits + }) +})) + vi.mock('@/stores/authStore', () => ({ useAuthStore: () => ({ balance: { amount_micros: 5000000 }, @@ -129,7 +140,7 @@ describe('useBillingContext', () => { tier: 'PRO', duration: 'MONTHLY', planSlug: null, - renewalDate: 'Jan 1, 2025', + renewalDate: '2025-01-01T00:00:00Z', endDate: null, isCancelled: false, hasFunds: true @@ -173,6 +184,13 @@ describe('useBillingContext', () => { await expect(manageSubscription()).resolves.toBeUndefined() }) + it('converts topup cents to whole dollars for the legacy credit endpoint', async () => { + const { topup } = useBillingContext() + await topup(500) + + expect(mockPurchaseCredits).toHaveBeenCalledWith(5) + }) + it('provides isActiveSubscription convenience computed', () => { const { isActiveSubscription } = useBillingContext() expect(isActiveSubscription.value).toBe(true) diff --git a/src/composables/billing/useBillingContext.ts b/src/composables/billing/useBillingContext.ts index ce1a36cfae..639e45130a 100644 --- a/src/composables/billing/useBillingContext.ts +++ b/src/composables/billing/useBillingContext.ts @@ -122,6 +122,15 @@ function useBillingContextInternal(): BillingContext { const isFreeTier = computed(() => subscription.value?.tier === 'FREE') + const billingStatus = computed(() => + toValue(activeContext.value.billingStatus) + ) + const subscriptionStatus = computed(() => + toValue(activeContext.value.subscriptionStatus) + ) + const tier = computed(() => toValue(activeContext.value.tier)) + const renewalDate = computed(() => toValue(activeContext.value.renewalDate)) + function getMaxSeats(tierKey: TierKey): number { if (type.value === 'legacy') return 1 @@ -218,6 +227,14 @@ function useBillingContextInternal(): BillingContext { return activeContext.value.cancelSubscription() } + async function resubscribe() { + return activeContext.value.resubscribe() + } + + async function topup(amountCents: number) { + return activeContext.value.topup(amountCents) + } + async function fetchPlans() { return activeContext.value.fetchPlans() } @@ -241,6 +258,10 @@ function useBillingContextInternal(): BillingContext { error, isActiveSubscription, isFreeTier, + billingStatus, + subscriptionStatus, + tier, + renewalDate, getMaxSeats, initialize, @@ -250,6 +271,8 @@ function useBillingContextInternal(): BillingContext { previewSubscribe, manageSubscription, cancelSubscription, + resubscribe, + topup, fetchPlans, requireActiveSubscription, showSubscriptionDialog diff --git a/src/composables/billing/useLegacyBilling.ts b/src/composables/billing/useLegacyBilling.ts index c86bbc55bc..1e52599605 100644 --- a/src/composables/billing/useLegacyBilling.ts +++ b/src/composables/billing/useLegacyBilling.ts @@ -1,7 +1,10 @@ import { computed, ref } from 'vue' +import { useAuthActions } from '@/composables/auth/useAuthActions' import { useSubscription } from '@/platform/cloud/subscription/composables/useSubscription' import type { + BillingStatus, + BillingSubscriptionStatus, PreviewSubscribeResponse, SubscribeResponse } from '@/platform/workspace/api/workspaceApi' @@ -24,8 +27,7 @@ export function useLegacyBilling(): BillingState & BillingActions { isActiveSubscription: legacyIsActiveSubscription, subscriptionTier, subscriptionDuration, - formattedRenewalDate, - formattedEndDate, + subscriptionStatus: legacySubscriptionStatus, isCancelled, fetchStatus: legacyFetchStatus, manageSubscription: legacyManageSubscription, @@ -34,6 +36,7 @@ export function useLegacyBilling(): BillingState & BillingActions { } = useSubscription() const authStore = useAuthStore() + const authActions = useAuthActions() const isInitialized = ref(false) const isLoading = ref(false) @@ -52,8 +55,8 @@ export function useLegacyBilling(): BillingState & BillingActions { tier: subscriptionTier.value, duration: subscriptionDuration.value, planSlug: null, // Legacy doesn't use plan slugs - renewalDate: formattedRenewalDate.value || null, - endDate: formattedEndDate.value || null, + renewalDate: legacySubscriptionStatus.value?.renewal_date ?? null, + endDate: legacySubscriptionStatus.value?.end_date ?? null, isCancelled: isCancelled.value, hasFunds: (authStore.balance?.amount_micros ?? 0) > 0 } @@ -75,6 +78,18 @@ export function useLegacyBilling(): BillingState & BillingActions { } }) + // Legacy has no coarse billing_status concept (workspace-only). + const billingStatus = computed(() => null) + const subscriptionStatus = computed(() => { + if (isCancelled.value) return 'canceled' + if (legacyIsActiveSubscription.value) return 'active' + return null + }) + const tier = computed(() => subscriptionTier.value) + const renewalDate = computed( + () => legacySubscriptionStatus.value?.renewal_date ?? null + ) + // Legacy billing doesn't have workspace-style plans const plans = computed(() => []) const currentPlanSlug = computed(() => null) @@ -152,6 +167,16 @@ export function useLegacyBilling(): BillingState & BillingActions { await legacyManageSubscription() } + async function resubscribe(): Promise { + // Legacy has no resubscribe endpoint; resubscribing is a fresh checkout. + await legacySubscribe() + } + + async function topup(amountCents: number): Promise { + // Facade standardizes on cents; legacy /customers/credit takes dollars. + await authActions.purchaseCredits(amountCents / 100) + } + async function fetchPlans(): Promise { // Legacy billing doesn't have workspace-style plans // Plans are hardcoded in the UI for legacy subscriptions @@ -179,6 +204,10 @@ export function useLegacyBilling(): BillingState & BillingActions { error, isActiveSubscription, isFreeTier, + billingStatus, + subscriptionStatus, + tier, + renewalDate, // Actions initialize, @@ -188,6 +217,8 @@ export function useLegacyBilling(): BillingState & BillingActions { previewSubscribe, manageSubscription, cancelSubscription, + resubscribe, + topup, fetchPlans, requireActiveSubscription, showSubscriptionDialog diff --git a/src/composables/canvas/useSelectedLiteGraphItems.ts b/src/composables/canvas/useSelectedLiteGraphItems.ts index 48246afd68..fb5b5c1745 100644 --- a/src/composables/canvas/useSelectedLiteGraphItems.ts +++ b/src/composables/canvas/useSelectedLiteGraphItems.ts @@ -1,8 +1,10 @@ +import { uniq } from 'es-toolkit' + import type { LGraphNode, Positionable } from '@/lib/litegraph/src/litegraph' import { LGraphEventMode, Reroute } from '@/lib/litegraph/src/litegraph' import { useCanvasStore } from '@/renderer/core/canvas/canvasStore' import { collectFromNodes } from '@/utils/graphTraversalUtil' -import { isLGraphNode } from '@/utils/litegraphUtil' +import { isLGraphGroup, isLGraphNode } from '@/utils/litegraphUtil' /** * Composable for handling selected LiteGraph items filtering and operations. @@ -71,7 +73,13 @@ export function useSelectedLiteGraphItems() { * the prior null-tolerance for callers wired to early-firing commands. */ const getSelectedNodesShallow = (): LGraphNode[] => - Array.from(canvasStore.canvas?.selectedItems ?? []).filter(isLGraphNode) + uniq( + [...(canvasStore.canvas?.selectedItems ?? [])].flatMap((item) => { + if (isLGraphNode(item)) return [item] + if (isLGraphGroup(item)) return [...item.children].filter(isLGraphNode) + return [] + }) + ) /** * Get only the selected nodes (LGraphNode instances) from the canvas. diff --git a/src/composables/graph/useSelectionState.ts b/src/composables/graph/useSelectionState.ts index 20654acc8f..03a6ab694f 100644 --- a/src/composables/graph/useSelectionState.ts +++ b/src/composables/graph/useSelectionState.ts @@ -7,7 +7,12 @@ import { useSettingStore } from '@/platform/settings/settingStore' import { useCanvasStore } from '@/renderer/core/canvas/canvasStore' import { useNodeDefStore } from '@/stores/nodeDefStore' import { useRightSidePanelStore } from '@/stores/workspace/rightSidePanelStore' -import { isImageNode, isLGraphNode, isLoad3dNode } from '@/utils/litegraphUtil' +import { + isImageNode, + isLGraphGroup, + isLGraphNode, + isLoad3dNode +} from '@/utils/litegraphUtil' import { filterOutputNodes } from '@/utils/nodeFilterUtil' export interface NodeSelectionState { @@ -41,6 +46,11 @@ export function useSelectionState() { const hasAnySelection = computed(() => selectedItems.value.length > 0) const hasSingleSelection = computed(() => selectedItems.value.length === 1) const hasMultipleSelection = computed(() => selectedItems.value.length > 1) + const hasGroupedNodesSelection = computed(() => + selectedItems.value.some( + (item) => isLGraphGroup(item) && [...item.children].some(isLGraphNode) + ) + ) const isSingleNode = computed( () => hasSingleSelection.value && isLGraphNode(selectedItems.value[0]) @@ -112,6 +122,7 @@ export function useSelectionState() { openNodeInfo, hasAny3DNodeSelected, hasAnySelection, + hasGroupedNodesSelection, hasSingleSelection, hasMultipleSelection, isSingleNode, diff --git a/src/platform/cloud/oauth/OAuthConsentView.test.ts b/src/platform/cloud/oauth/OAuthConsentView.test.ts index a3a12c8e8e..40195e10e5 100644 --- a/src/platform/cloud/oauth/OAuthConsentView.test.ts +++ b/src/platform/cloud/oauth/OAuthConsentView.test.ts @@ -147,7 +147,8 @@ describe('OAuthConsentView', () => { oauthRequestId: '550e8400-e29b-41d4-a716-446655440000', csrfToken: 'csrf-token', decision: 'allow', - workspaceId: 'personal-workspace' + workspaceId: 'personal-workspace', + expectedRedirectUri: 'http://127.0.0.1:50632/cb' }) }) diff --git a/src/platform/cloud/oauth/OAuthConsentView.vue b/src/platform/cloud/oauth/OAuthConsentView.vue index dba885f8fe..4b33a5613f 100644 --- a/src/platform/cloud/oauth/OAuthConsentView.vue +++ b/src/platform/cloud/oauth/OAuthConsentView.vue @@ -283,7 +283,8 @@ async function submit(decision: 'allow' | 'deny') { oauthRequestId: challenge.value.oauth_request_id, csrfToken: challenge.value.csrf_token, decision, - workspaceId + workspaceId, + expectedRedirectUri: challenge.value.redirect_uri }) clearOAuthRequestId() } catch (error) { diff --git a/src/platform/cloud/oauth/oauthApi.test.ts b/src/platform/cloud/oauth/oauthApi.test.ts index 94c3022f7d..c750d37964 100644 --- a/src/platform/cloud/oauth/oauthApi.test.ts +++ b/src/platform/cloud/oauth/oauthApi.test.ts @@ -220,6 +220,111 @@ describe('submitOAuthConsentDecision', () => { ).rejects.toThrow('redirect_url') }) + it('navigates to a reverse-DNS custom-scheme redirect_url (native clients)', async () => { + // RFC 8252 native-app callback — the comfy-ios client returns the + // authorization code via org.comfy.ios://oauth-callback. The backend + // has already validated the URL byte-identically against the client's + // registered redirect_uris. + vi.spyOn(globalThis, 'fetch').mockResolvedValue( + okResponse({ + redirect_url: 'org.comfy.ios://oauth-callback?code=xyz&state=s' + }) + ) + const originalLocation = globalThis.location + const hrefSetter = vi.fn() + Object.defineProperty(globalThis, 'location', { + configurable: true, + value: new Proxy(originalLocation, { + set(_target, prop, value) { + if (prop === 'href') { + hrefSetter(value) + return true + } + return Reflect.set(originalLocation, prop, value) + }, + get(_target, prop) { + return Reflect.get(originalLocation, prop) + } + }) + }) + + try { + await submitOAuthConsentDecision({ + oauthRequestId: validChallenge.oauth_request_id, + csrfToken: validChallenge.csrf_token, + decision: 'allow', + workspaceId: 'personal-workspace', + expectedRedirectUri: 'org.comfy.ios://oauth-callback' + }) + + expect(hrefSetter).toHaveBeenCalledWith( + 'org.comfy.ios://oauth-callback?code=xyz&state=s' + ) + expect(hrefSetter).toHaveBeenCalledTimes(1) + } finally { + Object.defineProperty(globalThis, 'location', { + configurable: true, + value: originalLocation + }) + } + }) + + it.for([ + [ + 'org.comfy.ios://oauth-callback?code=xyz', + undefined, + 'unsafe scheme', + 'custom scheme with no expectedRedirectUri is unbindable, falls back to the http(s)-only rule' + ], + [ + 'com.evil.app://oauth-callback?code=xyz', + 'org.comfy.ios://oauth-callback', + 'does not match', + 'bound challenge, different scheme: wrong-client redirect' + ], + [ + 'org.comfy.ios://oauth-callback/../steal?code=xyz', + 'org.comfy.ios://oauth-callback', + 'does not match', + 'bound challenge, same scheme but different path' + ], + [ + 'javascript:alert(1)', + 'javascript:alert(1)', + 'unsafe scheme', + 'executable schemes are rejected even if the challenge claims them' + ], + [ + 'data:text/html,', + 'data:text/html,x', + 'unsafe scheme', + 'data: scheme rejected even if the challenge claims it' + ], + [ + 'blob:https://cloud.comfy.org/abc', + undefined, + 'unsafe scheme', + 'blob: scheme is unsafe' + ] + ] as const)( + 'rejects redirect_url %s (registration %s, expects %s): %s', + async ([redirectUrl, expectedRedirectUri, expectedError]) => { + vi.spyOn(globalThis, 'fetch').mockResolvedValue( + okResponse({ redirect_url: redirectUrl }) + ) + + await expect( + submitOAuthConsentDecision({ + oauthRequestId: validChallenge.oauth_request_id, + csrfToken: validChallenge.csrf_token, + decision: 'allow', + workspaceId: 'personal-workspace', + expectedRedirectUri + }) + ).rejects.toThrow(expectedError) + } + ) + it('rejects an unsafe redirect_url scheme', async () => { // Defense in depth: even though the cloud backend is trusted, never // hand the browser off to a non-http(s) URL. diff --git a/src/platform/cloud/oauth/oauthApi.ts b/src/platform/cloud/oauth/oauthApi.ts index 950a0a76ea..c92578e975 100644 --- a/src/platform/cloud/oauth/oauthApi.ts +++ b/src/platform/cloud/oauth/oauthApi.ts @@ -40,12 +40,33 @@ export type OAuthConsentDecisionParams = { csrfToken: string decision: 'allow' | 'deny' workspaceId: string + /** + * The challenge's registered `redirect_uri`. When present, the + * post-consent navigation must match it (scheme, authority, path) — + * the server only appends `code`/`state` query params to the + * registered URI, so any other destination is rejected. When absent + * (challenges from backends that don't surface it yet), only http(s) + * redirects are navigable. + */ + expectedRedirectUri?: string } export type OAuthConsentDecision = ( params: OAuthConsentDecisionParams ) => Promise +// Schemes that execute in our origin if navigated. Never navigable, +// regardless of what the backend returns. Everything else is governed +// by binding to the challenge's registered redirect_uri — no per-client +// scheme knowledge lives in the frontend. +const EXECUTABLE_SCHEMES: ReadonlySet = new Set([ + 'javascript:', + 'data:', + 'blob:', + 'vbscript:', + 'about:' +]) + export class OAuthApiError extends Error { constructor( message: string, @@ -118,7 +139,8 @@ export async function submitOAuthConsentDecision({ oauthRequestId, csrfToken, decision, - workspaceId + workspaceId, + expectedRedirectUri }: OAuthConsentDecisionParams): Promise { const response = await fetch('/oauth/authorize', { method: 'POST', @@ -144,13 +166,56 @@ export async function submitOAuthConsentDecision({ throw new Error('OAuth consent response did not include redirect_url') } - // Defense in depth: even though the cloud backend is trusted, never hand - // the browser off to a non-http(s) scheme. javascript:/data: URLs would - // execute in our origin. - const target = new URL(redirectUrl, globalThis.location.origin) - if (target.protocol !== 'http:' && target.protocol !== 'https:') { + // Defense in depth at this sink. Two risks: schemes that execute in our + // origin (always rejected, below), and the OS routing the authorization + // code + state to whichever installed app claims an arbitrary custom + // scheme. For the latter we hold the navigation to the redirect the + // backend registered for THIS auth request (the challenge's + // redirect_uri): the server only ever appends code/state query params + // to the registered URI, so scheme, authority, and path must match + // exactly. No per-client scheme list lives in the frontend — new native + // clients need only their backend registration. + const parseTarget = () => { + try { + return new URL(redirectUrl, globalThis.location.origin) + } catch (err) { + throw new Error('OAuth consent redirect_url is not a valid URL', { + cause: err + }) + } + } + const target = parseTarget() + if (EXECUTABLE_SCHEMES.has(target.protocol)) { + throw new Error('OAuth consent redirect_url has an unsafe scheme') + } + if (expectedRedirectUri) { + const parseExpected = () => { + try { + return new URL(expectedRedirectUri) + } catch (err) { + throw new Error( + 'OAuth consent challenge redirect_uri is not a valid URL', + { cause: err } + ) + } + } + const expected = parseExpected() + const matchesRegistration = + target.protocol === expected.protocol && + target.host === expected.host && + target.pathname === expected.pathname + if (!matchesRegistration) { + throw new Error( + 'OAuth consent redirect_url does not match the registered redirect_uri' + ) + } + } else if (target.protocol !== 'http:' && target.protocol !== 'https:') { + // Challenges that don't surface redirect_uri can't be bound; hold the + // pre-existing http(s)-only line for them. throw new Error('OAuth consent redirect_url has an unsafe scheme') } - globalThis.location.href = redirectUrl + // Navigate the parsed URL, not the raw string, so the value validated + // above is byte-for-byte the value the browser receives. + globalThis.location.href = target.href } diff --git a/src/platform/workspace/api/workspaceApi.ts b/src/platform/workspace/api/workspaceApi.ts index 9a881e61d6..05a930e9b8 100644 --- a/src/platform/workspace/api/workspaceApi.ts +++ b/src/platform/workspace/api/workspaceApi.ts @@ -196,9 +196,13 @@ export interface PreviewSubscribeResponse { new_plan: PreviewPlanInfo } -type BillingSubscriptionStatus = 'active' | 'scheduled' | 'ended' | 'canceled' +export type BillingSubscriptionStatus = + | 'active' + | 'scheduled' + | 'ended' + | 'canceled' -type BillingStatus = +export type BillingStatus = | 'awaiting_payment_method' | 'pending_payment' | 'paid' @@ -233,7 +237,7 @@ interface CreateTopupRequest { type TopupStatus = 'pending' | 'completed' | 'failed' -interface CreateTopupResponse { +export interface CreateTopupResponse { billing_op_id: string topup_id: string status: TopupStatus diff --git a/src/platform/workspace/components/SubscriptionPanelContentWorkspace.vue b/src/platform/workspace/components/SubscriptionPanelContentWorkspace.vue index 7e47bc01f3..12a295beea 100644 --- a/src/platform/workspace/components/SubscriptionPanelContentWorkspace.vue +++ b/src/platform/workspace/components/SubscriptionPanelContentWorkspace.vue @@ -371,7 +371,6 @@ import { useBillingContext } from '@/composables/billing/useBillingContext' import { useBillingOperationStore } from '@/platform/workspace/stores/billingOperationStore' import { useSubscriptionActions } from '@/platform/cloud/subscription/composables/useSubscriptionActions' import { useSubscriptionCredits } from '@/platform/cloud/subscription/composables/useSubscriptionCredits' -import { workspaceApi } from '@/platform/workspace/api/workspaceApi' import { useDialogService } from '@/services/dialogService' import { DEFAULT_TIER_KEY, @@ -404,7 +403,8 @@ const { manageSubscription, fetchStatus, fetchBalance, - getMaxSeats + getMaxSeats, + resubscribe } = useBillingContext() const { showCancelSubscriptionDialog } = useDialogService() @@ -415,13 +415,12 @@ const isResubscribing = ref(false) async function handleResubscribe() { isResubscribing.value = true try { - await workspaceApi.resubscribe() + await resubscribe() toast.add({ severity: 'success', summary: t('subscription.resubscribeSuccess'), life: 5000 }) - await Promise.all([fetchStatus(), fetchBalance()]) } catch (error) { const message = error instanceof Error ? error.message : 'Failed to resubscribe' diff --git a/src/platform/workspace/components/TopUpCreditsDialogContentWorkspace.vue b/src/platform/workspace/components/TopUpCreditsDialogContentWorkspace.vue index c484bc5610..55f716fa9a 100644 --- a/src/platform/workspace/components/TopUpCreditsDialogContentWorkspace.vue +++ b/src/platform/workspace/components/TopUpCreditsDialogContentWorkspace.vue @@ -161,7 +161,6 @@ import { useBillingContext } from '@/composables/billing/useBillingContext' import { useExternalLink } from '@/composables/useExternalLink' import { useTelemetry } from '@/platform/telemetry' import { clearTopupTracking } from '@/platform/telemetry/topupTracker' -import { workspaceApi } from '@/platform/workspace/api/workspaceApi' import { useSettingsDialog } from '@/platform/settings/composables/useSettingsDialog' import { useBillingOperationStore } from '@/platform/workspace/stores/billingOperationStore' import { useDialogStore } from '@/stores/dialogStore' @@ -177,7 +176,7 @@ const settingsDialog = useSettingsDialog() const telemetry = useTelemetry() const toast = useToast() const { buildDocsUrl, docsPaths } = useExternalLink() -const { fetchBalance } = useBillingContext() +const { fetchBalance, topup } = useBillingContext() const billingOperationStore = useBillingOperationStore() const isPolling = computed(() => billingOperationStore.hasPendingOperations) @@ -257,7 +256,8 @@ async function handleBuy() { telemetry?.trackApiCreditTopupButtonPurchaseClicked(payAmount.value) const amountCents = payAmount.value * 100 - const response = await workspaceApi.createTopup(amountCents) + const response = await topup(amountCents) + if (!response) return if (response.status === 'completed') { toast.add({ diff --git a/src/platform/workspace/composables/useSubscriptionCheckout.test.ts b/src/platform/workspace/composables/useSubscriptionCheckout.test.ts index b709022ead..8bc969be54 100644 --- a/src/platform/workspace/composables/useSubscriptionCheckout.test.ts +++ b/src/platform/workspace/composables/useSubscriptionCheckout.test.ts @@ -91,10 +91,12 @@ vi.mock('@/composables/billing/useBillingContext', () => ({ previewSubscribe: mockPreviewSubscribe, plans: computed(() => mockPlans.value), fetchStatus: mockFetchStatus, - fetchBalance: mockFetchBalance + fetchBalance: mockFetchBalance, + resubscribe: mockResubscribe }) })) +// Shields the test from the real workspaceApi → @/scripts/api → app.ts import chain vi.mock('@/platform/workspace/api/workspaceApi', () => ({ workspaceApi: { resubscribe: mockResubscribe } })) diff --git a/src/platform/workspace/composables/useSubscriptionCheckout.ts b/src/platform/workspace/composables/useSubscriptionCheckout.ts index 3e88f74c9c..cc101bff70 100644 --- a/src/platform/workspace/composables/useSubscriptionCheckout.ts +++ b/src/platform/workspace/composables/useSubscriptionCheckout.ts @@ -12,7 +12,6 @@ import type { Plan, PreviewSubscribeResponse } from '@/platform/workspace/api/workspaceApi' -import { workspaceApi } from '@/platform/workspace/api/workspaceApi' import { useBillingOperationStore } from '@/platform/workspace/stores/billingOperationStore' type CheckoutStep = 'pricing' | 'preview' | 'success' @@ -36,8 +35,14 @@ export function useSubscriptionCheckout(emit: { }) { const { t } = useI18n() const toast = useToast() - const { subscribe, previewSubscribe, plans, fetchStatus, fetchBalance } = - useBillingContext() + const { + subscribe, + previewSubscribe, + plans, + fetchStatus, + fetchBalance, + resubscribe + } = useBillingContext() const telemetry = useTelemetry() const billingOperationStore = useBillingOperationStore() @@ -184,13 +189,12 @@ export function useSubscriptionCheckout(emit: { async function handleResubscribe() { isResubscribing.value = true try { - await workspaceApi.resubscribe() + await resubscribe() toast.add({ severity: 'success', summary: t('subscription.resubscribeSuccess'), life: 5000 }) - await Promise.all([fetchStatus(), fetchBalance()]) emit('close', true) } catch (error) { const message = diff --git a/src/platform/workspace/composables/useWorkspaceBilling.test.ts b/src/platform/workspace/composables/useWorkspaceBilling.test.ts index d1fbfb041d..106be2f3ac 100644 --- a/src/platform/workspace/composables/useWorkspaceBilling.test.ts +++ b/src/platform/workspace/composables/useWorkspaceBilling.test.ts @@ -11,7 +11,9 @@ const mockWorkspaceApi = vi.hoisted(() => ({ subscribe: vi.fn(), previewSubscribe: vi.fn(), getPaymentPortalUrl: vi.fn(), - cancelSubscription: vi.fn() + cancelSubscription: vi.fn(), + resubscribe: vi.fn(), + createTopup: vi.fn() })) const mockBillingPlans = vi.hoisted(() => ({ @@ -622,6 +624,90 @@ describe('useWorkspaceBilling', () => { }) }) + describe('resubscribe', () => { + it('refreshes status and balance after a successful resubscribe', async () => { + mockWorkspaceApi.resubscribe.mockResolvedValue(undefined) + mockWorkspaceApi.getBillingStatus.mockResolvedValue(activeStatus) + mockWorkspaceApi.getBillingBalance.mockResolvedValue(positiveBalance) + + const billing = setupBilling() + await billing.resubscribe() + + expect(mockWorkspaceApi.resubscribe).toHaveBeenCalledTimes(1) + expect(mockWorkspaceApi.getBillingStatus).toHaveBeenCalledTimes(1) + expect(mockWorkspaceApi.getBillingBalance).toHaveBeenCalledTimes(1) + expect(billing.subscription.value?.tier).toBe('CREATOR') + expect(billing.balance.value?.amountMicros).toBe(5_000_000) + expect(billing.error.value).toBeNull() + expect(billing.isLoading.value).toBe(false) + }) + + it('sets error, rethrows, and skips the refresh when the API call fails', async () => { + mockWorkspaceApi.resubscribe.mockRejectedValue( + new Error('reactivation failed') + ) + + const billing = setupBilling() + + await expect(billing.resubscribe()).rejects.toThrow('reactivation failed') + expect(billing.error.value).toBe('reactivation failed') + expect(billing.isLoading.value).toBe(false) + expect(mockWorkspaceApi.getBillingStatus).not.toHaveBeenCalled() + expect(mockWorkspaceApi.getBillingBalance).not.toHaveBeenCalled() + }) + + it('falls back to a generic error message for non-Error rejections', async () => { + mockWorkspaceApi.resubscribe.mockRejectedValue('boom') + + const billing = setupBilling() + + await expect(billing.resubscribe()).rejects.toBe('boom') + expect(billing.error.value).toBe('Failed to resubscribe') + }) + }) + + describe('topup', () => { + const topupResponse = { + billing_op_id: 'op-topup', + topup_id: 'topup-1', + status: 'completed' as const, + amount_cents: 500 + } + + it('returns the createTopup response without refreshing status or balance', async () => { + mockWorkspaceApi.createTopup.mockResolvedValue(topupResponse) + + const billing = setupBilling() + const result = await billing.topup(500) + + expect(mockWorkspaceApi.createTopup).toHaveBeenCalledWith(500) + expect(result).toBe(topupResponse) + expect(mockWorkspaceApi.getBillingStatus).not.toHaveBeenCalled() + expect(mockWorkspaceApi.getBillingBalance).not.toHaveBeenCalled() + expect(billing.error.value).toBeNull() + expect(billing.isLoading.value).toBe(false) + }) + + it('sets error and rethrows when the API call fails', async () => { + mockWorkspaceApi.createTopup.mockRejectedValue(new Error('card declined')) + + const billing = setupBilling() + + await expect(billing.topup(500)).rejects.toThrow('card declined') + expect(billing.error.value).toBe('card declined') + expect(billing.isLoading.value).toBe(false) + }) + + it('falls back to a generic error message for non-Error rejections', async () => { + mockWorkspaceApi.createTopup.mockRejectedValue('boom') + + const billing = setupBilling() + + await expect(billing.topup(500)).rejects.toBe('boom') + expect(billing.error.value).toBe('Failed to top up credits') + }) + }) + describe('plans / currentPlanSlug / fetchPlans', () => { it('prefers the plan slug from status over the billingPlans fallback', async () => { mockBillingPlans.currentPlanSlug.value = 'plans-fallback' diff --git a/src/platform/workspace/composables/useWorkspaceBilling.ts b/src/platform/workspace/composables/useWorkspaceBilling.ts index cceceee8c6..82583884cc 100644 --- a/src/platform/workspace/composables/useWorkspaceBilling.ts +++ b/src/platform/workspace/composables/useWorkspaceBilling.ts @@ -5,6 +5,7 @@ import { useSubscriptionDialog } from '@/platform/cloud/subscription/composables import type { BillingBalanceResponse, BillingStatusResponse, + CreateTopupResponse, PreviewSubscribeResponse, SubscribeResponse } from '@/platform/workspace/api/workspaceApi' @@ -70,6 +71,13 @@ export function useWorkspaceBilling(): BillingState & BillingActions { } }) + const billingStatus = computed(() => statusData.value?.billing_status ?? null) + const subscriptionStatus = computed( + () => statusData.value?.subscription_status ?? null + ) + const tier = computed(() => statusData.value?.subscription_tier ?? null) + const renewalDate = computed(() => statusData.value?.renewal_date ?? null) + const plans = computed(() => billingPlans.plans.value) const currentPlanSlug = computed( () => statusData.value?.plan_slug ?? billingPlans.currentPlanSlug.value @@ -262,6 +270,34 @@ export function useWorkspaceBilling(): BillingState & BillingActions { } } + async function resubscribe(): Promise { + isLoading.value = true + error.value = null + try { + await workspaceApi.resubscribe() + await Promise.all([fetchStatus(), fetchBalance()]) + } catch (err) { + error.value = err instanceof Error ? err.message : 'Failed to resubscribe' + throw err + } finally { + isLoading.value = false + } + } + + async function topup(amountCents: number): Promise { + isLoading.value = true + error.value = null + try { + return await workspaceApi.createTopup(amountCents) + } catch (err) { + error.value = + err instanceof Error ? err.message : 'Failed to top up credits' + throw err + } finally { + isLoading.value = false + } + } + async function fetchPlans(): Promise { isLoading.value = true error.value = null @@ -303,6 +339,10 @@ export function useWorkspaceBilling(): BillingState & BillingActions { error, isActiveSubscription, isFreeTier, + billingStatus, + subscriptionStatus, + tier, + renewalDate, // Actions initialize, @@ -312,6 +352,8 @@ export function useWorkspaceBilling(): BillingState & BillingActions { previewSubscribe, manageSubscription, cancelSubscription, + resubscribe, + topup, fetchPlans, requireActiveSubscription, showSubscriptionDialog