Compare commits

..

5 Commits

Author SHA1 Message Date
Alexis Rolland
90b2b3ba1d Refactor CLA workflow to handle issue comments 2026-07-02 13:20:28 +08:00
Alexis Rolland
25c4250b2f Update BASE_ALLOWLIST in CLA workflow 2026-07-02 12:58:21 +08:00
Alexis Rolland
a84ec49f75 Update BASE_ALLOWLIST in CLA workflow 2026-07-02 12:57:22 +08:00
Alexis Rolland
4f1a9700c4 Update CLA workflow to build author-only allowlist 2026-07-02 12:53:02 +08:00
Benjamin Lu
2ec2a0e091 feat: attribute payment intent through paywall, checkout, and top-up telemetry (#13363)
## Summary

Answers "why did this user want to pay?" by capturing the triggering
product moment at every paywall/upsell entry point and carrying it
through checkout and success telemetry.

## Changes

- **What**:
- Widen `SubscriptionDialogReason` from 4 coarse values to 13 grounded
intent sources (`subscribe_to_run`, `upgrade_to_add_credits`,
`invite_member_upsell`, `settings_billing_panel`, etc.)
- Fire `app:subscription_required_modal_opened` from
`useSubscriptionDialog` (the choke point all dialog variants pass
through) — the workspace/unified path previously emitted nothing; remove
the now-duplicate emitters in `useSubscription` and
`usePricingTableUrlLoader`
- Add `payment_intent_source` to
`BeginCheckoutMetadata`/`SubscriptionSuccessMetadata`, threaded via the
existing `reason` prop: dialog → `PricingTable` →
`performSubscriptionCheckout` → pending-attempt record, so legacy
`app:monthly_subscription_succeeded` carries intent alongside
`checkout_attempt_id`
- Fire `begin_checkout` on the workspace checkout path
(`useSubscriptionCheckout`, personal + team confirm) and the team
deep-link util — both previously emitted nothing; `tier` widened to
`TierKey | 'team'`
- Implement `trackBeginCheckout` in `PostHogTelemetryProvider` (was
GTM/host-only, so `begin_checkout` never reached PostHog)
- Thread `showSubscriptionDialog(options)` through the billing-context
adapters and pass a reason at ~14 call sites; add `source` to
`app:add_api_credit_button_clicked`

## Review Focus

- `modal_opened` now fires once per dialog actually shown, so a
free-tier user clicking Upgrade emits two events (free-tier dialog, then
pricing table) where the legacy path emitted one
- Intent is threaded explicitly via props/params rather than shared
state; `useSubscriptionCheckout` gained an optional second parameter
2026-07-02 03:11:21 +00:00
93 changed files with 918 additions and 5413 deletions

View File

@@ -5,7 +5,6 @@ on:
types: [created]
pull_request_target:
types: [opened, synchronize, closed]
merge_group:
permissions:
actions: write
@@ -17,13 +16,41 @@ jobs:
cla-assistant:
runs-on: ubuntu-latest
steps:
# The CLA action normally requires every commit author in a PR to sign.
# We only want the PR author to sign, so we allowlist all other committers
# by computing them from the PR's commits and excluding the PR author.
- name: Build author-only allowlist
id: allowlist
if: >
github.event_name == 'pull_request_target' ||
(github.event_name == 'issue_comment' && github.event.issue.pull_request && (
github.event.comment.body == 'recheck' ||
github.event.comment.body == 'I have read and agree to the Contributor License Agreement'
))
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number || github.event.issue.number }}
PR_AUTHOR: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
BASE_ALLOWLIST: action@github.com,actions-user,ampagent,claude,comfy-pr-bot,GitHub Action,github-actions,github-actions[bot],Glary Bot,Glary-Bot,*[bot]
run: |
others=$(gh api "repos/${{ github.repository }}/pulls/${PR_NUMBER}/commits" --paginate \
--jq '.[] | (.author.login // empty), (.committer.login // empty)' \
| sort -u | grep -vix "${PR_AUTHOR}" | paste -sd, -)
if [ -n "$others" ]; then
echo "allowlist=${BASE_ALLOWLIST},${others}" >> "$GITHUB_OUTPUT"
else
echo "allowlist=${BASE_ALLOWLIST}" >> "$GITHUB_OUTPUT"
fi
- name: CLA Assistant
# Run on PR events, on "recheck" comment, or when someone posts the exact signing phrase.
# IMPORTANT: this phrase must match `custom-pr-sign-comment` below.
if: >
github.event_name == 'pull_request_target' ||
github.event.comment.body == 'recheck' ||
github.event.comment.body == 'I have read and agree to the Contributor License Agreement'
(github.event_name == 'issue_comment' && github.event.issue.pull_request && (
github.event.comment.body == 'recheck' ||
github.event.comment.body == 'I have read and agree to the Contributor License Agreement'
))
uses: contributor-assistant/github-action@ca4a40a7d1004f18d9960b404b97e5f30a505a08 # v2.6.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@@ -39,9 +66,10 @@ jobs:
path-to-signatures: signatures/cla.json
branch: main
# Allowlist bots so they don't need to sign (optional, comma-separated).
# Only the PR author must sign: bots plus every non-author committer
# are allowlisted via the "Build author-only allowlist" step above.
# *[bot] is a catch-all for any GitHub App bot account.
allowlist: action@github.com,actions-user,ampagent,claude,comfy-pr-bot,GitHub Action,github-actions,Glary Bot,Glary-Bot,*[bot]
allowlist: ${{ steps.allowlist.outputs.allowlist }}
# Custom PR comment messages
custom-notsigned-prcomment: |

View File

@@ -1,304 +0,0 @@
import { expect, mergeTests } from '@playwright/test'
import type { Page } from '@playwright/test'
import type { ComfyPage } from '@e2e/fixtures/ComfyPage'
import { comfyPageFixture } from '@e2e/fixtures/ComfyPage'
import { ConfirmDialog } from '@e2e/fixtures/components/ConfirmDialog'
import { jsonRoute } from '@e2e/fixtures/utils/jsonRoute'
import { webSocketFixture } from '@e2e/fixtures/ws'
import type {
DownloadStatus,
EnqueueResponse,
HostCredentialUpsert,
HostCredentialView
} from '@/platform/modelManager/types'
const test = mergeTests(comfyPageFixture, webSocketFixture)
const DOWNLOADS_ROUTE = /\/api\/download$/
const ENQUEUE_ROUTE = /\/api\/download\/enqueue$/
const CREDENTIALS_ROUTE = /\/api\/download\/credentials$/
const CREDENTIAL_ROUTE = /\/api\/download\/credentials\/([^/?]+)$/
const CLEAR_ROUTE = /\/api\/download\/clear$/
const DOWNLOAD_ID = 'e2e-download-1'
const MODEL_URL =
'https://huggingface.co/e2e/test/resolve/main/model.safetensors'
const MODEL_ID = 'checkpoints/e2e-test-model.safetensors'
function makeDownloadStatus(
overrides: Partial<DownloadStatus> = {}
): DownloadStatus {
const now = Math.floor(Date.now() / 1000)
return {
download_id: DOWNLOAD_ID,
model_id: MODEL_ID,
url: MODEL_URL,
status: 'queued',
priority: 0,
total_bytes: null,
bytes_done: 0,
progress: null,
speed_bps: null,
eta_seconds: null,
segments: null,
error: null,
created_at: now,
updated_at: now,
...overrides
}
}
async function enableServerSideModelDownloads(comfyPage: ComfyPage) {
await comfyPage.page.evaluate(() => {
window.app!.api.serverFeatureFlags.value = {
...window.app!.api.serverFeatureFlags.value,
server_side_model_downloads: true
}
})
}
/** Keeps GET /download consistent so the 5s stale-poll can't wipe test state. */
async function mockDownloadsList(
page: Page,
getDownloads: () => DownloadStatus[]
) {
await page.route(DOWNLOADS_ROUTE, async (route) => {
if (route.request().method() !== 'GET') return route.fallback()
await route.fulfill(jsonRoute({ downloads: getDownloads() }))
})
}
function modelDownloaderTabButton(comfyPage: ComfyPage) {
return comfyPage.page.locator('.model-manager-tab-button')
}
function modelDownloaderBadge(comfyPage: ComfyPage) {
return modelDownloaderTabButton(comfyPage).locator('.sidebar-icon-badge')
}
function modelDownloaderPanel(comfyPage: ComfyPage) {
return comfyPage.page.locator('.sidebar-content-container')
}
async function openModelDownloaderTab(comfyPage: ComfyPage) {
await modelDownloaderTabButton(comfyPage).click()
const panel = modelDownloaderPanel(comfyPage)
await expect(panel.getByText('Downloads', { exact: true })).toBeVisible()
// Toolbar buttons are only interactive while the tab panel is hovered
// (`group-hover/sidebar-tab`), so keep the cursor over it for later clicks.
await panel.hover()
}
test.describe('Model Downloader sidebar', { tag: '@ui' }, () => {
test.beforeEach(async ({ comfyPage }) => {
await comfyPage.modelLibrary.mockFoldersWithFiles({ checkpoints: [] })
// Isolate from any real downloads already tracked by the backend.
await mockDownloadsList(comfyPage.page, () => [])
await enableServerSideModelDownloads(comfyPage)
})
test('adds a model by URL and reflects live progress over the websocket', async ({
comfyPage,
getWebSocket
}) => {
let downloads: DownloadStatus[] = []
await mockDownloadsList(comfyPage.page, () => downloads)
await comfyPage.page.route(ENQUEUE_ROUTE, async (route) => {
downloads = [makeDownloadStatus()]
const response: EnqueueResponse = {
download_id: DOWNLOAD_ID,
accepted: true
}
await route.fulfill({
status: 202,
contentType: 'application/json',
body: JSON.stringify(response)
})
})
await openModelDownloaderTab(comfyPage)
const panel = modelDownloaderPanel(comfyPage)
await panel.getByTitle('Add model').click()
const addDialog = comfyPage.page.getByRole('dialog')
await expect(addDialog.getByText('Add model')).toBeVisible()
await addDialog.getByLabel('URL').fill(MODEL_URL)
await expect(addDialog.getByLabel('Filename')).toHaveValue(
'model.safetensors'
)
await addDialog.getByLabel('Filename').fill('e2e-test-model.safetensors')
await addDialog.getByRole('combobox', { name: 'Select a folder' }).click()
await comfyPage.page
.getByRole('option', { name: 'Checkpoints', exact: true })
.click()
await addDialog
.getByRole('button', { name: 'Download', exact: true })
.click()
await expect(addDialog).toBeHidden()
await expect(panel.getByText('e2e-test-model.safetensors')).toBeVisible()
await expect(panel.getByText('Queued', { exact: true })).toBeVisible()
await expect(modelDownloaderBadge(comfyPage)).toHaveText('1')
const ws = await getWebSocket()
function sendProgress(overrides: Partial<DownloadStatus>) {
const payload = makeDownloadStatus(overrides)
downloads = [payload]
ws.send(JSON.stringify({ type: 'download_progress', data: payload }))
}
sendProgress({
status: 'active',
progress: 0.4,
bytes_done: 400,
total_bytes: 1000,
speed_bps: 500_000
})
await expect(panel.getByText('Downloading', { exact: true })).toBeVisible()
await expect(panel.getByText(/40%/)).toBeVisible()
await expect(modelDownloaderBadge(comfyPage)).toHaveText('1')
sendProgress({
status: 'completed',
progress: 1,
bytes_done: 1000,
total_bytes: 1000,
speed_bps: null,
eta_seconds: null
})
await expect(panel.getByText('Completed', { exact: true })).toBeVisible()
await expect(panel.getByText('History', { exact: true })).toBeVisible()
await expect(modelDownloaderBadge(comfyPage)).toHaveCount(0)
})
test('clears history and the cleared rows stay gone after reopening the tab', async ({
comfyPage
}) => {
let downloads: DownloadStatus[] = [
makeDownloadStatus({ status: 'completed', progress: 1, bytes_done: 1000 })
]
await mockDownloadsList(comfyPage.page, () => downloads)
await comfyPage.page.route(CLEAR_ROUTE, async (route) => {
const count = downloads.length
downloads = []
await route.fulfill(jsonRoute({ deleted: count }))
})
await openModelDownloaderTab(comfyPage)
const panel = modelDownloaderPanel(comfyPage)
await expect(panel.getByText('History', { exact: true })).toBeVisible()
await expect(panel.getByText('e2e-test-model.safetensors')).toBeVisible()
await panel.getByRole('button', { name: 'Clear history' }).click()
await expect(panel.getByText('No downloads yet')).toBeVisible()
// Reopening the tab re-runs hydrate() -> GET /api/download. The bug was
// that the cleared row reappeared here; the persisted delete must prevent
// the backend list from returning it again.
await modelDownloaderTabButton(comfyPage).click()
await openModelDownloaderTab(comfyPage)
await expect(
modelDownloaderPanel(comfyPage).getByText('No downloads yet')
).toBeVisible()
await expect(
modelDownloaderPanel(comfyPage).getByText('e2e-test-model.safetensors')
).toBeHidden()
})
test('manages download credentials: add, edit, and delete', async ({
comfyPage
}) => {
let credentials: HostCredentialView[] = []
let nextId = 1
await comfyPage.page.route(CREDENTIALS_ROUTE, async (route) => {
const method = route.request().method()
if (method === 'GET') {
await route.fulfill(jsonRoute({ credentials }))
return
}
if (method === 'POST') {
const body = route.request().postDataJSON() as HostCredentialUpsert
const existing = credentials.find((c) => c.host === body.host)
const now = Math.floor(Date.now() / 1000)
const view: HostCredentialView = {
id: existing?.id ?? `cred-${nextId++}`,
host: body.host,
auth_scheme: body.auth_scheme ?? 'bearer',
header_name: body.header_name ?? null,
query_param: body.query_param ?? null,
label: body.label ?? null,
match_subdomains: body.match_subdomains ?? false,
enabled: body.enabled ?? true,
secret_last4: body.secret.slice(-4),
created_at: existing?.created_at ?? now,
updated_at: now
}
credentials = existing
? credentials.map((c) => (c.id === view.id ? view : c))
: [...credentials, view]
await route.fulfill(jsonRoute(view))
return
}
await route.fallback()
})
await comfyPage.page.route(CREDENTIAL_ROUTE, async (route) => {
if (route.request().method() !== 'DELETE') return route.fallback()
const id = route.request().url().match(CREDENTIAL_ROUTE)?.[1]
credentials = credentials.filter((c) => c.id !== id)
await route.fulfill(jsonRoute({ deleted: true }))
})
const dialog = comfyPage.page
.getByRole('dialog')
.filter({ hasText: 'Credentials Manager' })
// The success toast shown after saving can momentarily steal focus and
// close the dialog, so re-opening is retried until it sticks.
async function ensureCredentialsDialogOpen() {
await expect(async () => {
const panel = modelDownloaderPanel(comfyPage)
await panel.hover()
await panel.getByTitle('Credentials Manager').click()
await expect(dialog).toBeVisible({ timeout: 1000 })
}).toPass({ timeout: 10_000 })
}
await openModelDownloaderTab(comfyPage)
await ensureCredentialsDialogOpen()
await dialog.getByLabel('Host').fill('huggingface.co')
await dialog.getByLabel('API key').fill('secret-key-1234')
await dialog.getByRole('button', { name: 'Save', exact: true }).click()
await ensureCredentialsDialogOpen()
await expect(
dialog.getByText('huggingface.co · Bearer token · ••••1234')
).toBeVisible()
await dialog.getByTitle('Edit').click()
await expect(dialog.getByText('Update credential')).toBeVisible()
await expect(dialog.getByLabel('API key')).toHaveValue('')
await dialog.getByLabel('Label').fill('My HF Key')
await dialog.getByLabel('API key').fill('updated-secret-5678')
await dialog.getByRole('button', { name: 'Save', exact: true }).click()
await ensureCredentialsDialogOpen()
await expect(dialog.getByText('My HF Key')).toBeVisible()
await expect(
dialog.getByText('huggingface.co · Bearer token · ••••5678')
).toBeVisible()
await dialog.getByTitle('Delete').click()
const confirm = new ConfirmDialog(comfyPage.page)
await confirm.click('delete')
await expect(dialog.getByText('My HF Key')).toBeHidden()
})
})

View File

@@ -243,7 +243,7 @@ onMounted(() => {
--sidebar-padding: 4px;
--sidebar-icon-size: 1rem;
--sidebar-default-floating-width: 50px;
--sidebar-default-floating-width: 48px;
--sidebar-default-connected-width: calc(
var(--sidebar-default-floating-width) + var(--sidebar-padding) * 2
);

View File

@@ -224,7 +224,7 @@ const handleOpenUserSettings = () => {
}
const handleOpenPlansAndPricing = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
emit('close')
}
@@ -239,8 +239,7 @@ const handleOpenPlanAndCreditsSettings = () => {
}
const handleTopUp = () => {
// Track purchase credits entry from avatar popover
useTelemetry()?.trackAddApiCreditButtonClicked()
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
dialogService.showTopUpCreditsDialog()
emit('close')
}
@@ -254,7 +253,7 @@ const handleOpenPartnerNodesInfo = () => {
}
const handleUpgradeToAddCredits = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
emit('close')
}

View File

@@ -21,6 +21,6 @@ const { isFreeTier } = useBillingContext()
const subscriptionDialog = useSubscriptionDialog()
function handleClick() {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'subscribe_now_button' })
}
</script>

View File

@@ -1,17 +1,12 @@
<script setup lang="ts">
import type { DialogContentEmits, DialogContentProps } from 'reka-ui'
import {
DialogContent,
injectDialogRootContext,
useForwardPropsEmits
} from 'reka-ui'
import { DialogContent, useForwardPropsEmits } from 'reka-ui'
import type { HTMLAttributes } from 'vue'
import { cn } from '@comfyorg/tailwind-utils'
import type { DialogContentSize } from './dialog.variants'
import { dialogContentVariants } from './dialog.variants'
import { useModalPointerLock } from './useModalPointerLock'
const {
size,
@@ -28,11 +23,6 @@ const {
const emits = defineEmits<DialogContentEmits>()
const forwarded = useForwardPropsEmits(restProps, emits)
const dialogRootContext = injectDialogRootContext(null)
if (dialogRootContext?.modal.value) {
useModalPointerLock(() => dialogRootContext.open.value)
}
</script>
<template>

View File

@@ -1,114 +0,0 @@
import { render } from '@testing-library/vue'
import { describe, expect, it } from 'vitest'
import {
SelectContent,
SelectPortal,
SelectRoot,
SelectTrigger,
SelectViewport
} from 'reka-ui'
import { defineComponent, h, nextTick, ref } from 'vue'
import Dialog from './Dialog.vue'
import DialogContent from './DialogContent.vue'
import DialogPortal from './DialogPortal.vue'
async function flush() {
await nextTick()
await nextTick()
await new Promise((resolve) => setTimeout(resolve, 0))
await nextTick()
}
function mountDialogWithSelect(modal: boolean) {
const dialogOpen = ref(true)
const selectOpen = ref(false)
const Parent = defineComponent({
setup() {
return () =>
h(
Dialog,
{
modal,
open: dialogOpen.value,
'onUpdate:open': (value: boolean) => (dialogOpen.value = value)
},
() =>
h(DialogPortal, null, () =>
h(DialogContent, null, () =>
h(
SelectRoot,
{
open: selectOpen.value,
'onUpdate:open': (value: boolean) =>
(selectOpen.value = value)
},
() => [
h(SelectTrigger, null, () => 'trigger'),
h(SelectPortal, null, () =>
h(SelectContent, { position: 'popper' }, () =>
h(SelectViewport, null, () => 'items')
)
)
]
)
)
)
)
}
})
return { ...render(Parent), dialogOpen, selectOpen }
}
describe('modal dialog pointer lock', () => {
it('keeps body inert after a nested combobox popover opens and closes', async () => {
const { selectOpen } = mountDialogWithSelect(true)
await flush()
expect(document.body.style.pointerEvents).toBe('none')
selectOpen.value = true
await flush()
expect(document.body.style.pointerEvents).toBe('none')
// Reka restores body pointer events when the popover layer unmounts; the
// lock must re-assert it so the canvas behind the dialog stays inert.
selectOpen.value = false
await flush()
expect(document.body.style.pointerEvents).toBe('none')
})
it('restores body pointer events once the dialog closes', async () => {
const { dialogOpen } = mountDialogWithSelect(true)
await flush()
expect(document.body.style.pointerEvents).toBe('none')
dialogOpen.value = false
await flush()
expect(document.body.style.pointerEvents).toBe('')
})
it('does not lock body for a non-modal dialog', async () => {
mountDialogWithSelect(false)
await flush()
expect(document.body.style.pointerEvents).toBe('')
})
it('holds the body lock until every open modal dialog closes', async () => {
const first = mountDialogWithSelect(true)
const second = mountDialogWithSelect(true)
await flush()
expect(document.body.style.pointerEvents).toBe('none')
// With one modal still open the shared lock must keep the body inert.
first.dialogOpen.value = false
await flush()
expect(document.body.style.pointerEvents).toBe('none')
// Once the last modal closes the lock releases and stops forcing inert.
second.dialogOpen.value = false
await flush()
expect(document.body.style.pointerEvents).not.toBe('none')
})
})

View File

@@ -1,54 +0,0 @@
import { onScopeDispose, toValue, watch } from 'vue'
import type { MaybeRefOrGetter } from 'vue'
/**
* Keeps the canvas behind a modal dialog inert by holding `document.body`'s
* pointer-events lock for as long as at least one modal dialog is open.
*
* Reka-UI locks body pointer events per modal layer, but a nested dismissable
* layer that is portalled to the body — e.g. a `Select` popover inside the
* dialog — restores the body's pointer events when it closes, even while the
* outer modal dialog is still open. That momentarily re-enables the canvas, so
* combobox clicks leak through to it and can select a node or dismiss the
* dialog. Reka still performs the initial lock and final restore; the
* `MutationObserver` only re-asserts `none` if the lock is cleared while a
* modal dialog is still open.
*/
let openModalCount = 0
let observer: MutationObserver | null = null
function enforceLock() {
if (openModalCount > 0 && document.body.style.pointerEvents !== 'none') {
document.body.style.pointerEvents = 'none'
}
}
function acquire() {
openModalCount += 1
if (observer) return
observer = new MutationObserver(enforceLock)
observer.observe(document.body, {
attributes: true,
attributeFilter: ['style']
})
}
function release() {
openModalCount = Math.max(0, openModalCount - 1)
if (openModalCount > 0) return
observer?.disconnect()
observer = null
document.body.style.pointerEvents = ''
}
export function useModalPointerLock(isOpen: MaybeRefOrGetter<boolean>) {
let holding = false
const sync = (open: boolean) => {
if (open === holding) return
holding = open
if (open) acquire()
else release()
}
watch(() => toValue(isOpen), sync, { immediate: true })
onScopeDispose(() => sync(false))
}

View File

@@ -1,5 +1,6 @@
import type { ComputedRef, Ref } from 'vue'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type {
BillingStatus,
@@ -75,9 +76,10 @@ export interface BillingActions {
*/
requireActiveSubscription: () => Promise<void>
/**
* Shows the subscription dialog.
* Shows the subscription dialog. Pass a reason so the paywall open and any
* downstream checkout stay attributed to the triggering product moment.
*/
showSubscriptionDialog: () => void
showSubscriptionDialog: (options?: SubscriptionDialogOptions) => void
}
export interface BillingState {

View File

@@ -7,6 +7,7 @@ import {
getTierFeatures
} from '@/platform/cloud/subscription/constants/tierPricing'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type {
PreviewSubscribeOptions,
SubscribeOptions
@@ -281,8 +282,8 @@ function useBillingContextInternal(): BillingContext {
return activeContext.value.requireActiveSubscription()
}
function showSubscriptionDialog() {
return activeContext.value.showSubscriptionDialog()
function showSubscriptionDialog(options?: SubscriptionDialogOptions) {
return activeContext.value.showSubscriptionDialog(options)
}
return {

View File

@@ -2,6 +2,7 @@ import { computed, ref } from 'vue'
import { useAuthActions } from '@/composables/auth/useAuthActions'
import { useSubscription } from '@/platform/cloud/subscription/composables/useSubscription'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type {
BillingStatus,
BillingSubscriptionStatus,
@@ -189,12 +190,12 @@ export function useLegacyBilling(): BillingState & BillingActions {
async function requireActiveSubscription(): Promise<void> {
await fetchStatus()
if (!isActiveSubscription.value) {
legacyShowSubscriptionDialog()
legacyShowSubscriptionDialog({ reason: 'subscription_required' })
}
}
function showSubscriptionDialog(): void {
legacyShowSubscriptionDialog()
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
legacyShowSubscriptionDialog(options)
}
return {

View File

@@ -503,7 +503,7 @@ export function useCoreCommands(): ComfyCommand[] {
}) => {
trackRunButton(metadata)
if (!isActiveSubscription.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
return
}
@@ -526,7 +526,7 @@ export function useCoreCommands(): ComfyCommand[] {
}) => {
trackRunButton(metadata)
if (!isActiveSubscription.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
return
}
@@ -548,7 +548,7 @@ export function useCoreCommands(): ComfyCommand[] {
}) => {
trackRunButton(metadata)
if (!isActiveSubscription.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
return
}

View File

@@ -30,8 +30,7 @@ export enum ServerFeatureFlag {
COMFYHUB_PROFILE_GATE_ENABLED = 'comfyhub_profile_gate_enabled',
SHOW_SIGNIN_BUTTON = 'show_signin_button',
UNIFIED_CLOUD_AUTH = 'unified_cloud_auth',
SIGNUP_TURNSTILE = 'signup_turnstile',
SERVER_SIDE_MODEL_DOWNLOADS = 'server_side_model_downloads'
SIGNUP_TURNSTILE = 'signup_turnstile'
}
/**
@@ -182,12 +181,6 @@ export function useFeatureFlags() {
remoteConfig.value.signup_turnstile,
'off'
)
},
get serverSideModelDownloads() {
return api.getServerFeature(
ServerFeatureFlag.SERVER_SIDE_MODEL_DOWNLOADS,
false
)
}
})

View File

@@ -1069,74 +1069,6 @@
"update": "Update",
"description": "Check out the latest improvements and features in this update."
},
"modelManager": {
"title": "Downloads",
"active": "Active",
"history": "History",
"empty": "No downloads yet",
"clearHistory": "Clear history",
"addModel": "Add model",
"addModelDescription": "Download a model directly to your ComfyUI models folder.",
"url": "URL",
"urlPlaceholder": "https://huggingface.co/.../model.safetensors",
"hostNotAllowedHint": "This host may not be on the download allowlist.",
"folder": "Model folder",
"selectFolder": "Select a folder",
"filename": "Filename",
"filenamePlaceholder": "model.safetensors",
"allowAnyExtension": "Allow any file extension (advanced)",
"download": "Download",
"downloadQueued": "Download queued",
"alreadyInstalled": "Model already installed",
"alreadyDownloading": "Model is already downloading",
"resume": "Resume",
"raisePriority": "Raise priority",
"removeFromList": "Remove from list",
"addCredentials": "Add credentials",
"authErrorHint": "{host} needs an API key. Add one in the Credentials Manager, then resume.",
"authErrorHintNoHost": "This host/model needs an API key. Add one in the Credentials Manager, then resume.",
"gatedModelHint": "This model is gated. Accept its license on the model's page, then add an API key and resume.",
"openModelPage": "Accept license",
"actionFailed": "Action failed",
"cancelConfirmTitle": "Cancel download?",
"cancelConfirmMessage": "This will stop the download and delete the partial file for \"{name}\".",
"cancelConfirm": "Cancel download",
"status": {
"queued": "Queued",
"active": "Downloading",
"paused": "Paused",
"verifying": "Verifying",
"completed": "Completed",
"failed": "Failed",
"cancelled": "Cancelled"
},
"credentials": {
"title": "Credentials Manager",
"description": "Store API keys per host (e.g. HuggingFace, Civitai). Secrets are write-only and never shown again.",
"add": "Add credential",
"update": "Update credential",
"edit": "Edit",
"save": "Save",
"saved": "Credential saved",
"host": "Host",
"secret": "API key",
"secretPlaceholder": "Paste API key",
"authScheme": "Auth scheme",
"headerName": "Header name",
"queryParam": "Query parameter",
"label": "Label",
"disabled": "Disabled",
"matchSubdomains": "Match subdomains",
"matchSubdomainsWarning": "Not recommended: hubs redirect to sibling CDN hosts that must not receive your key.",
"deleteTitle": "Delete credential?",
"deleteMessage": "Delete the stored credential for \"{host}\"?",
"scheme": {
"bearer": "Bearer token",
"header": "Custom header",
"query": "Query parameter"
}
}
},
"menu": {
"hideMenu": "Hide Menu",
"showMenu": "Show Menu",

View File

@@ -25,6 +25,6 @@ function handleClose() {
}
function handleSubscribe() {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'upload_model_upgrade' })
}
</script>

View File

@@ -140,7 +140,10 @@ describe('CloudSubscriptionRedirectView', () => {
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
'creator',
'monthly',
false
{
openInNewTab: false,
paymentIntentSource: 'deep_link'
}
)
// Shows loading affordances
@@ -169,7 +172,10 @@ describe('CloudSubscriptionRedirectView', () => {
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
'creator',
'monthly',
false
{
openInNewTab: false,
paymentIntentSource: 'deep_link'
}
)
})
@@ -180,7 +186,8 @@ describe('CloudSubscriptionRedirectView', () => {
expect(screen.getByText('Subscribe to Team Plan')).toBeInTheDocument()
expect(mockPerformTeamSubscriptionCheckout).toHaveBeenCalledWith(
'team_700',
'yearly'
'yearly',
{ paymentIntentSource: 'deep_link' }
)
// Team never goes through the personal checkout path
expect(mockPerformSubscriptionCheckout).not.toHaveBeenCalled()

View File

@@ -94,7 +94,9 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
return
}
isTeamCheckout.value = true
await performTeamSubscriptionCheckout(stopId, billingCycle)
await performTeamSubscriptionCheckout(stopId, billingCycle, {
paymentIntentSource: 'deep_link'
})
return
}
@@ -112,7 +114,10 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
if (isActiveSubscription.value) {
await accessBillingPortal(undefined, false)
} else {
await performSubscriptionCheckout(tierKeyParam, billingCycle, false)
await performSubscriptionCheckout(tierKeyParam, billingCycle, {
openInNewTab: false,
paymentIntentSource: 'deep_link'
})
}
}, reportError)

View File

@@ -351,12 +351,12 @@ const handleRefresh = wrapWithErrorHandlingAsync(async () => {
})
function handleAddCredits() {
telemetry?.trackAddApiCreditButtonClicked()
telemetry?.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
void dialogService.showTopUpCreditsDialog()
}
function handleUpgradeToAddCredits() {
showPricingTable()
showPricingTable({ reason: 'upgrade_to_add_credits' })
}
async function handleWindowFocus() {

View File

@@ -5,6 +5,8 @@ import { render, screen } from '@testing-library/vue'
import enMessages from '@/locales/en/main.json' with { type: 'json' }
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import FreeTierDialogContent from './FreeTierDialogContent.vue'
const mockRenewalDate = vi.hoisted(() => ({ value: null as string | null }))
@@ -15,7 +17,7 @@ vi.mock('@/composables/billing/useBillingContext', () => ({
}))
}))
function renderComponent() {
function renderComponent(props?: { reason?: PaymentIntentSource }) {
const i18n = createI18n({
legacy: false,
locale: 'en',
@@ -23,6 +25,7 @@ function renderComponent() {
})
return render(FreeTierDialogContent, {
props,
global: {
plugins: [i18n]
}
@@ -43,4 +46,18 @@ describe('FreeTierDialogContent', () => {
renderComponent()
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
})
it('keeps the generic copy for intent reasons outside the credits variants', () => {
mockRenewalDate.value = '2026-07-15T10:00:00Z'
renderComponent({ reason: 'subscribe_to_run' })
expect(
screen.getByText('Your credits refresh on Jul 15, 2026.')
).toBeInTheDocument()
})
it('swaps to the out-of-credits copy without the refresh line', () => {
mockRenewalDate.value = '2026-07-15T10:00:00Z'
renderComponent({ reason: 'out_of_credits' })
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
})
})

View File

@@ -52,7 +52,7 @@
</p>
<p
v-if="!reason || reason === 'subscription_required'"
v-if="!isCreditsBlockedVariant"
class="m-0 text-sm text-text-secondary"
>
{{
@@ -65,10 +65,7 @@
</p>
<p
v-if="
(!reason || reason === 'subscription_required') &&
formattedRenewalDate
"
v-if="!isCreditsBlockedVariant && formattedRenewalDate"
class="m-0 text-sm text-text-secondary"
>
{{
@@ -88,7 +85,7 @@
@click="$emit('upgrade')"
>
{{
reason === 'out_of_credits' || reason === 'top_up_blocked'
isCreditsBlockedVariant
? $t('subscription.freeTier.upgradeCta')
: $t('subscription.freeTier.subscribeCta')
}}
@@ -103,12 +100,12 @@ import { computed } from 'vue'
import Button from '@/components/ui/button/Button.vue'
import { useBillingContext } from '@/composables/billing/useBillingContext'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import SubscriptionBenefits from '@/platform/cloud/subscription/components/SubscriptionBenefits.vue'
import { getTierCredits } from '@/platform/cloud/subscription/constants/tierPricing'
defineProps<{
reason?: SubscriptionDialogReason
const { reason } = defineProps<{
reason?: PaymentIntentSource
}>()
defineEmits<{
@@ -129,4 +126,10 @@ const formattedRenewalDate = computed(() => {
})
const freeTierCredits = computed(() => getTierCredits('free'))
// Only these two variants replace the generic free-tier copy; any other
// intent reason (subscribe_to_run, deep_link, ...) keeps the default pitch.
const isCreditsBlockedVariant = computed(
() => reason === 'out_of_credits' || reason === 'top_up_blocked'
)
</script>

View File

@@ -261,6 +261,7 @@ describe('PricingTable', () => {
tier: 'creator',
cycle: 'yearly',
checkout_type: 'change',
checkout_attempt_id: expect.any(String),
previous_tier: 'standard'
})
expect(mockAccessBillingPortal).toHaveBeenCalledWith('creator-yearly')
@@ -341,6 +342,7 @@ describe('PricingTable', () => {
expect(
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
).toBeNull()
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
it('should use the latest userId value when it changes after mount', async () => {
@@ -366,6 +368,7 @@ describe('PricingTable', () => {
tier: 'creator',
cycle: 'yearly',
checkout_type: 'change',
checkout_attempt_id: expect.any(String),
previous_tier: 'standard'
})
})

View File

@@ -277,13 +277,19 @@ import type {
TierKey,
TierPricing
} from '@/platform/cloud/subscription/constants/tierPricing'
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import {
recordPendingSubscriptionCheckoutAttempt,
withPendingCheckoutAttemptId
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import { performSubscriptionCheckout } from '@/platform/cloud/subscription/utils/subscriptionCheckoutUtil'
import { isPlanDowngrade } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
import type {
CheckoutAttributionMetadata,
PaymentIntentSource
} from '@/platform/telemetry/types'
import { useAuthStore } from '@/stores/authStore'
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
@@ -321,6 +327,10 @@ interface PricingTierConfig {
isPopular?: boolean
}
const { reason } = defineProps<{
reason?: PaymentIntentSource
}>()
const emit = defineEmits<{
chooseTeamWorkspace: []
}>()
@@ -463,16 +473,17 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
} as const
const previousPlan = currentPlanDescriptor.value
const checkoutAttribution = await getCheckoutAttributionForCloud()
if (userId.value) {
telemetry?.trackBeginCheckout({
user_id: userId.value,
tier: targetPlan.tierKey,
cycle: targetPlan.billingCycle,
checkout_type: 'change',
...checkoutAttribution,
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
})
}
const beginCheckoutMetadata = userId.value
? {
user_id: userId.value,
tier: targetPlan.tierKey,
cycle: targetPlan.billingCycle,
checkout_type: 'change' as const,
...(reason ? { payment_intent_source: reason } : {}),
...checkoutAttribution,
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
}
: null
// Pass the target tier to create a deep link to subscription update confirmation
const checkoutTier = getCheckoutTier(
targetPlan.tierKey,
@@ -487,29 +498,39 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
if (downgrade) {
// TODO(COMFY-StripeProration): Remove once backend checkout creation mirrors portal proration ("change at billing end")
await accessBillingPortal()
const didOpenPortal = await accessBillingPortal()
if (didOpenPortal && beginCheckoutMetadata) {
telemetry?.trackBeginCheckout(beginCheckoutMetadata)
}
} else {
const didOpenPortal = await accessBillingPortal(checkoutTier)
if (!didOpenPortal) {
return
}
recordPendingSubscriptionCheckoutAttempt({
const pendingAttempt = recordPendingSubscriptionCheckoutAttempt({
tier: targetPlan.tierKey,
cycle: targetPlan.billingCycle,
checkout_type: 'change',
payment_intent_source: reason,
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {}),
...(previousPlan
? { previous_cycle: previousPlan.billingCycle }
: {})
})
if (beginCheckoutMetadata) {
telemetry?.trackBeginCheckout(
withPendingCheckoutAttemptId(
beginCheckoutMetadata,
pendingAttempt
)
)
}
}
} else {
await performSubscriptionCheckout(
tierKey,
currentBillingCycle.value,
true
)
await performSubscriptionCheckout(tierKey, currentBillingCycle.value, {
paymentIntentSource: reason
})
}
} finally {
isLoading.value = false

View File

@@ -56,7 +56,7 @@ const handleSubscribe = () => {
current_tier: tier.value?.toLowerCase()
})
isAwaitingStripeSubscription.value = true
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_now_button' })
}
onBeforeUnmount(() => {

View File

@@ -54,6 +54,6 @@ function handleSubscribeToRun() {
trackRunButton({ subscribe_to_run: true })
}
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
}
</script>

View File

@@ -48,7 +48,9 @@
v-if="isActiveSubscription"
variant="primary"
class="rounded-lg px-4 py-2 text-sm font-normal text-text-primary"
@click="showSubscriptionDialog"
@click="
showSubscriptionDialog({ reason: 'settings_billing_panel' })
"
>
{{ $t('subscription.upgradePlan') }}
</Button>

View File

@@ -33,7 +33,11 @@
</i18n-t>
</div>
<PricingTable class="flex-1" @choose-team-workspace="handleChooseTeam" />
<PricingTable
:reason
class="flex-1"
@choose-team-workspace="handleChooseTeam"
/>
<!-- Contact and Enterprise Links -->
<div class="flex flex-col items-center gap-2">
@@ -157,11 +161,11 @@ import { useBillingContext } from '@/composables/billing/useBillingContext'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import { useCommandStore } from '@/stores/commandStore'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
const { onClose, reason, onChooseTeam } = defineProps<{
onClose: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
onChooseTeam?: () => void
}>()

View File

@@ -24,7 +24,9 @@ export function useAccountPreconditionDialog() {
)
return
case 'subscription':
void dialogService.showSubscriptionRequiredDialog()
void dialogService.showSubscriptionRequiredDialog({
reason: 'subscription_required'
})
return
case 'credits':
void dialogService.showTopUpCreditsDialog({

View File

@@ -55,12 +55,6 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
})
}))
const mockTrackSubscription = vi.hoisted(() => vi.fn())
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
}))
describe('usePricingTableUrlLoader', () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -96,9 +90,6 @@ describe('usePricingTableUrlLoader', () => {
reason: 'deep_link',
planMode: undefined
})
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
reason: 'deep_link'
})
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
})
@@ -150,7 +141,6 @@ describe('usePricingTableUrlLoader', () => {
await loadPricingTableFromUrl()
expect(mockShowPricingTable).not.toHaveBeenCalled()
expect(mockTrackSubscription).not.toHaveBeenCalled()
})
it('denies, strips, and clears together when the user is not eligible', async () => {
@@ -161,7 +151,6 @@ describe('usePricingTableUrlLoader', () => {
await loadPricingTableFromUrl()
expect(mockShowPricingTable).not.toHaveBeenCalled()
expect(mockTrackSubscription).not.toHaveBeenCalled()
expect(mockRouterReplace).toHaveBeenCalledWith({
query: { other: 'param' }
})
@@ -230,7 +219,6 @@ describe('usePricingTableUrlLoader', () => {
)
expect(mockShowPricingTable).not.toHaveBeenCalled()
expect(mockTrackSubscription).not.toHaveBeenCalled()
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
expect(preservedQueryMocks.clearPreservedQuery).toHaveBeenCalledWith(
'pricing'

View File

@@ -7,7 +7,6 @@ import {
mergePreservedQueryIntoQuery
} from '@/platform/navigation/preservedQueryManager'
import { PRESERVED_QUERY_NAMESPACES } from '@/platform/navigation/preservedQueryNamespaces'
import { useTelemetry } from '@/platform/telemetry'
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
@@ -62,7 +61,6 @@ export function usePricingTableUrlLoader() {
const planMode =
param === 'team' || param === 'personal' ? param : undefined
useTelemetry()?.trackSubscription('modal_opened', { reason: 'deep_link' })
subscriptionDialog.showPricingTable({ reason: 'deep_link', planMode })
}

View File

@@ -15,7 +15,7 @@ import { t } from '@/i18n'
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
import { useDialogService } from '@/services/dialogService'
@@ -237,14 +237,7 @@ function useSubscriptionInternal() {
})
}, reportError)
const showSubscriptionDialog = (options?: {
reason?: SubscriptionDialogReason
}) => {
useTelemetry()?.trackSubscription('modal_opened', {
current_tier: subscriptionTier.value?.toLowerCase(),
reason: options?.reason
})
const showSubscriptionDialog = (options?: SubscriptionDialogOptions) => {
void showSubscriptionRequiredDialog(options)
}
@@ -277,7 +270,7 @@ function useSubscriptionInternal() {
await fetchSubscriptionStatus()
if (!isSubscribedOrIsNotCloud.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscription_required' })
}
}

View File

@@ -39,15 +39,23 @@ vi.mock('@/stores/commandStore', () => ({
}))
// useTelemetry() returns null in OSS, a dispatcher in cloud — toggle via mockIsCloud.
const { mockIsCloud, mockTrackHelpResourceClicked } = vi.hoisted(() => ({
const {
mockIsCloud,
mockTrackHelpResourceClicked,
mockTrackAddApiCreditButtonClicked
} = vi.hoisted(() => ({
mockIsCloud: { value: true },
mockTrackHelpResourceClicked: vi.fn()
mockTrackHelpResourceClicked: vi.fn(),
mockTrackAddApiCreditButtonClicked: vi.fn()
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () =>
mockIsCloud.value
? { trackHelpResourceClicked: mockTrackHelpResourceClicked }
? {
trackHelpResourceClicked: mockTrackHelpResourceClicked,
trackAddApiCreditButtonClicked: mockTrackAddApiCreditButtonClicked
}
: null
}))
@@ -69,6 +77,9 @@ describe('useSubscriptionActions', () => {
const { handleAddApiCredits } = useSubscriptionActions()
handleAddApiCredits()
expect(mockShowTopUpCreditsDialog).toHaveBeenCalledOnce()
expect(mockTrackAddApiCreditButtonClicked).toHaveBeenCalledWith({
source: 'settings_billing_panel'
})
})
})

View File

@@ -21,6 +21,9 @@ export function useSubscriptionActions() {
})
const handleAddApiCredits = () => {
telemetry?.trackAddApiCreditButtonClicked({
source: 'settings_billing_panel'
})
void dialogService.showTopUpCreditsDialog()
}

View File

@@ -5,8 +5,10 @@ import { useSubscriptionDialog } from './useSubscriptionDialog'
const mockCloseDialog = vi.fn()
const mockShowLayoutDialog = vi.fn()
const mockShowTeamWorkspacesDialog = vi.fn()
const mockTrackSubscription = vi.hoisted(() => vi.fn())
const mockIsInPersonalWorkspace = vi.hoisted(() => ({ value: true }))
const mockIsFreeTier = vi.hoisted(() => ({ value: false }))
const mockTier = vi.hoisted(() => ({ value: 'FREE' as string | null }))
const mockTeamWorkspacesEnabled = vi.hoisted(() => ({ value: false }))
const mockIsCloud = vi.hoisted(() => ({ value: true }))
const mockIsLegacyTeamPlan = vi.hoisted(() => ({ value: false }))
@@ -60,10 +62,15 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
vi.mock('@/composables/billing/useBillingContext', () => ({
useBillingContext: () => ({
isFreeTier: mockIsFreeTier,
isLegacyTeamPlan: mockIsLegacyTeamPlan
isLegacyTeamPlan: mockIsLegacyTeamPlan,
tier: mockTier
})
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
}))
vi.mock('@/platform/workspace/composables/useWorkspaceUI', () => ({
useWorkspaceUI: () => ({
permissions: {
@@ -80,6 +87,7 @@ describe('useSubscriptionDialog', () => {
mockIsCloud.value = true
mockIsInPersonalWorkspace.value = true
mockIsFreeTier.value = false
mockTier.value = 'FREE'
mockTeamWorkspacesEnabled.value = false
mockIsLegacyTeamPlan.value = false
mockCanManageSubscription.value = true
@@ -198,6 +206,51 @@ describe('useSubscriptionDialog', () => {
const props = mockShowLayoutDialog.mock.calls[0][0].props
expect(props.initialPlanMode).toBe('team')
})
it('tracks modal_opened with the caller reason and current tier', () => {
mockTier.value = 'STANDARD'
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'upgrade_to_add_credits' })
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
current_tier: 'standard',
reason: 'upgrade_to_add_credits'
})
})
it('tracks modal_opened on the workspace (unified) path too', () => {
mockTeamWorkspacesEnabled.value = true
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'subscribe_to_run' })
expect(mockTrackSubscription).toHaveBeenCalledWith(
'modal_opened',
expect.objectContaining({ reason: 'subscribe_to_run' })
)
})
it('does not track modal_opened for the inactive member dialog', () => {
mockTeamWorkspacesEnabled.value = true
mockIsInPersonalWorkspace.value = false
mockCanManageSubscription.value = false
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'subscribe_to_run' })
expect(mockShowLayoutDialog).toHaveBeenCalledTimes(1)
expect(mockTrackSubscription).not.toHaveBeenCalled()
})
it('does not track on non-cloud', () => {
mockIsCloud.value = false
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'subscribe_to_run' })
expect(mockTrackSubscription).not.toHaveBeenCalled()
})
})
describe('show', () => {
@@ -235,6 +288,20 @@ describe('useSubscriptionDialog', () => {
expect.objectContaining({ key: 'subscription-required' })
)
})
it('tracks modal_opened with the reason for the free-tier dialog', () => {
mockIsFreeTier.value = true
mockIsInPersonalWorkspace.value = true
const { show } = useSubscriptionDialog()
show({ reason: 'out_of_credits' })
expect(mockTrackSubscription).toHaveBeenCalledTimes(1)
expect(mockTrackSubscription).toHaveBeenCalledWith(
'modal_opened',
expect.objectContaining({ reason: 'out_of_credits' })
)
})
})
describe('startTeamWorkspaceUpgradeFlow', () => {

View File

@@ -4,6 +4,8 @@ import { useDialogStore } from '@/stores/dialogStore'
import { useBillingContext } from '@/composables/billing/useBillingContext'
import { useFeatureFlags } from '@/composables/useFeatureFlags'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
@@ -11,14 +13,8 @@ const DIALOG_KEY = 'subscription-required'
const FREE_TIER_DIALOG_KEY = 'free-tier-info'
const RESUME_PRICING_KEY = 'comfy:resume-team-pricing'
export type SubscriptionDialogReason =
| 'subscription_required'
| 'out_of_credits'
| 'top_up_blocked'
| 'deep_link'
interface SubscriptionDialogOptions {
reason?: SubscriptionDialogReason
export interface SubscriptionDialogOptions {
reason?: PaymentIntentSource
/**
* Forces the unified pricing dialog to open on a specific plan tab,
* overriding the workspace-derived default (e.g. an "Upgrade to Team" CTA
@@ -38,6 +34,17 @@ export const useSubscriptionDialog = () => {
dialogStore.closeDialog({ key: FREE_TIER_DIALOG_KEY })
}
// Fired here — the choke point every paywall/pricing dialog variant passes
// through — so both the legacy and workspace billing paths emit it.
function trackModalOpened(reason?: PaymentIntentSource) {
// Resolved lazily to avoid the useBillingContext import cycle (see below).
const { tier } = useBillingContext()
useTelemetry()?.trackSubscription('modal_opened', {
current_tier: tier.value?.toLowerCase(),
reason
})
}
function showPricingTable(options?: SubscriptionDialogOptions) {
if (!isCloud) return
@@ -71,6 +78,8 @@ export const useSubscriptionDialog = () => {
return
}
trackModalOpened(options?.reason)
// Shared dialog shell styling for both variants.
const dialogComponentProps = {
style: 'width: min(1328px, 95vw); max-height: 958px;',
@@ -167,6 +176,8 @@ export const useSubscriptionDialog = () => {
// (not at composable setup) to avoid the useBillingContext import cycle.
const { isFreeTier } = useBillingContext()
if (isFreeTier.value && workspaceStore.isInPersonalWorkspace) {
trackModalOpened(options?.reason)
const component = defineAsyncComponent(
() =>
import('@/platform/cloud/subscription/components/FreeTierDialogContent.vue')
@@ -236,7 +247,7 @@ export const useSubscriptionDialog = () => {
sessionStorage.removeItem(RESUME_PRICING_KEY)
if (!workspaceStore.isInPersonalWorkspace) {
showPricingTable()
showPricingTable({ reason: 'team_upgrade_resume' })
}
} catch {
// sessionStorage may be unavailable

View File

@@ -0,0 +1,49 @@
import { beforeEach, describe, expect, it } from 'vitest'
import {
clearPendingSubscriptionCheckoutAttempt,
consumePendingSubscriptionCheckoutSuccess,
recordPendingSubscriptionCheckoutAttempt
} from './subscriptionCheckoutTracker'
const activeProStatus = {
is_active: true,
subscription_tier: 'PRO',
subscription_duration: 'MONTHLY'
} as const
describe('subscriptionCheckoutTracker', () => {
beforeEach(() => {
clearPendingSubscriptionCheckoutAttempt()
})
it('round-trips payment_intent_source from attempt to success metadata', () => {
recordPendingSubscriptionCheckoutAttempt({
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
})
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
expect(metadata).toMatchObject({
tier: 'pro',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
})
})
it('omits payment_intent_source when the attempt had none', () => {
recordPendingSubscriptionCheckoutAttempt({
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new'
})
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
expect(metadata).not.toBeNull()
expect(metadata).not.toHaveProperty('payment_intent_source')
})
})

View File

@@ -7,7 +7,12 @@ import type {
TierKey
} from '@/platform/cloud/subscription/constants/tierPricing'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import type { SubscriptionSuccessMetadata } from '@/platform/telemetry/types'
import type {
BeginCheckoutMetadata,
PaymentIntentSource,
SubscriptionCheckoutType,
SubscriptionSuccessMetadata
} from '@/platform/telemetry/types'
const PENDING_SUBSCRIPTION_CHECKOUT_MAX_AGE_MS = 6 * 60 * 60 * 1000
const VALID_TIER_KEYS = new Set<TierKey>([
@@ -23,7 +28,6 @@ export const PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY =
export const PENDING_SUBSCRIPTION_CHECKOUT_EVENT =
'comfy:subscription-checkout-attempt-changed'
type CheckoutType = 'new' | 'change'
type SubscriptionDuration = 'MONTHLY' | 'ANNUAL'
interface SubscriptionStatusSnapshot {
@@ -32,22 +36,24 @@ interface SubscriptionStatusSnapshot {
subscription_duration?: SubscriptionDuration | null
}
interface PendingSubscriptionCheckoutAttempt {
export interface PendingSubscriptionCheckoutAttempt {
attempt_id: string
started_at_ms: number
tier: TierKey
cycle: BillingCycle
checkout_type: CheckoutType
checkout_type: SubscriptionCheckoutType
previous_tier?: TierKey
previous_cycle?: BillingCycle
payment_intent_source?: PaymentIntentSource
}
interface RecordPendingSubscriptionCheckoutAttemptInput {
interface PendingSubscriptionCheckoutAttemptInput {
tier: TierKey
cycle: BillingCycle
checkout_type: CheckoutType
checkout_type: SubscriptionCheckoutType
previous_tier?: TierKey
previous_cycle?: BillingCycle
payment_intent_source?: PaymentIntentSource
}
const dispatchPendingCheckoutChangeEvent = () => {
@@ -168,6 +174,9 @@ const normalizeAttempt = (
...(candidate.previous_cycle === 'monthly' ||
candidate.previous_cycle === 'yearly'
? { previous_cycle: candidate.previous_cycle }
: {}),
...(typeof candidate.payment_intent_source === 'string'
? { payment_intent_source: candidate.payment_intent_source }
: {})
}
}
@@ -224,20 +233,27 @@ const getPendingSubscriptionCheckoutAttempt =
export const hasPendingSubscriptionCheckoutAttempt = (): boolean =>
getPendingSubscriptionCheckoutAttempt() !== null
export const recordPendingSubscriptionCheckoutAttempt = (
input: RecordPendingSubscriptionCheckoutAttemptInput
export const createPendingSubscriptionCheckoutAttempt = (
input: PendingSubscriptionCheckoutAttemptInput
): PendingSubscriptionCheckoutAttempt => {
const storage = getStorage()
const attempt: PendingSubscriptionCheckoutAttempt = {
return {
attempt_id: createAttemptId(),
started_at_ms: Date.now(),
tier: input.tier,
cycle: input.cycle,
checkout_type: input.checkout_type,
...(input.previous_tier ? { previous_tier: input.previous_tier } : {}),
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {})
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {}),
...(input.payment_intent_source
? { payment_intent_source: input.payment_intent_source }
: {})
}
}
export const persistPendingSubscriptionCheckoutAttempt = (
attempt: PendingSubscriptionCheckoutAttempt
): PendingSubscriptionCheckoutAttempt => {
const storage = getStorage()
if (!storage) {
return attempt
}
@@ -255,6 +271,21 @@ export const recordPendingSubscriptionCheckoutAttempt = (
return attempt
}
export const recordPendingSubscriptionCheckoutAttempt = (
input: PendingSubscriptionCheckoutAttemptInput
): PendingSubscriptionCheckoutAttempt =>
persistPendingSubscriptionCheckoutAttempt(
createPendingSubscriptionCheckoutAttempt(input)
)
export const withPendingCheckoutAttemptId = (
metadata: BeginCheckoutMetadata,
attempt: PendingSubscriptionCheckoutAttempt
): BeginCheckoutMetadata => ({
...metadata,
checkout_attempt_id: attempt.attempt_id
})
const didAttemptSucceed = (
attempt: PendingSubscriptionCheckoutAttempt,
status: SubscriptionStatusSnapshot
@@ -287,6 +318,9 @@ export const consumePendingSubscriptionCheckoutSuccess = (
cycle: attempt.cycle,
checkout_type: attempt.checkout_type,
...(attempt.previous_tier ? { previous_tier: attempt.previous_tier } : {}),
...(attempt.payment_intent_source
? { payment_intent_source: attempt.payment_intent_source }
: {}),
value,
currency: 'USD',
ecommerce: {

View File

@@ -132,13 +132,14 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'yearly', true)
await performSubscriptionCheckout('pro', 'yearly')
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith({
user_id: 'user-123',
tier: 'pro',
cycle: 'yearly',
checkout_type: 'new',
checkout_attempt_id: expect.any(String),
ga_client_id: 'ga-client-id',
ga_session_id: 'ga-session-id',
ga_session_number: 'ga-session-number',
@@ -150,6 +151,12 @@ describe('performSubscriptionCheckout', () => {
gbraid: 'gbraid-456',
wbraid: 'wbraid-789'
})
const beginCheckoutMetadata =
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
JSON.parse(storedAttempt).attempt_id
)
expect(global.fetch).toHaveBeenCalledWith(
expect.stringContaining(
'/customers/cloud-subscription-checkout/pro-yearly'
@@ -186,7 +193,7 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'monthly', true)
await performSubscriptionCheckout('pro', 'monthly')
expect(warnSpy).toHaveBeenCalledWith(
'[SubscriptionCheckout] Failed to collect checkout attribution',
@@ -203,11 +210,43 @@ describe('performSubscriptionCheckout', () => {
user_id: 'user-123',
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new'
checkout_type: 'new',
checkout_attempt_id: expect.any(String)
})
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
})
it('carries the payment intent source into begin_checkout and the pending attempt', async () => {
const checkoutUrl = 'https://checkout.stripe.com/test'
const openSpy = vi
.spyOn(window, 'open')
.mockImplementation(() => window as unknown as Window)
vi.mocked(global.fetch).mockResolvedValue({
ok: true,
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'monthly', {
paymentIntentSource: 'out_of_credits'
})
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({ payment_intent_source: 'out_of_credits' })
)
const beginCheckoutMetadata =
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
const pendingAttempt = JSON.parse(storedAttempt)
expect(pendingAttempt).toMatchObject({
payment_intent_source: 'out_of_credits'
})
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
pendingAttempt.attempt_id
)
openSpy.mockRestore()
})
it('uses the latest userId when it changes after checkout starts', async () => {
const checkoutUrl = 'https://checkout.stripe.com/test'
const openSpy = vi
@@ -222,7 +261,7 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly', true)
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly')
mockUserId.value = 'user-late'
authHeader.resolve({ Authorization: 'Bearer test-token' })
@@ -235,13 +274,14 @@ describe('performSubscriptionCheckout', () => {
user_id: 'user-late',
tier: 'pro',
cycle: 'yearly',
checkout_type: 'new'
checkout_type: 'new',
checkout_attempt_id: expect.any(String)
})
)
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
})
it('does not persist a pending attempt when the checkout popup is blocked', async () => {
it('does not persist the pending attempt when the checkout popup is blocked', async () => {
const checkoutUrl = 'https://checkout.stripe.com/test'
const openSpy = vi.spyOn(window, 'open').mockImplementation(() => null)
@@ -250,11 +290,18 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'monthly', true)
await performSubscriptionCheckout('pro', 'monthly')
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
expect(
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
).toBeNull()
const storedAttempt = window.localStorage.getItem(
PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY
)
expect(storedAttempt).toBeNull()
expect(mockLocalStorage.setItem).not.toHaveBeenCalled()
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({
checkout_attempt_id: expect.any(String)
})
)
})
})

View File

@@ -4,12 +4,19 @@ import { useFeatureFlags } from '@/composables/useFeatureFlags'
import { getComfyApiBaseUrl } from '@/config/comfyApi'
import { t } from '@/i18n'
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import {
createPendingSubscriptionCheckoutAttempt,
persistPendingSubscriptionCheckoutAttempt,
withPendingCheckoutAttemptId
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type {
CheckoutAttributionMetadata,
PaymentIntentSource
} from '@/platform/telemetry/types'
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import type { BillingCycle } from './subscriptionTierRank'
type CheckoutTier = TierKey | `${TierKey}-yearly`
@@ -31,6 +38,11 @@ const getCheckoutAttributionForCloud =
return getCheckoutAttribution()
}
interface PerformSubscriptionCheckoutOptions {
openInNewTab?: boolean
paymentIntentSource?: PaymentIntentSource
}
/**
* Core subscription checkout logic shared between PricingTable and
* SubscriptionRedirectView. Handles:
@@ -47,10 +59,12 @@ const getCheckoutAttributionForCloud =
export async function performSubscriptionCheckout(
tierKey: TierKey,
currentBillingCycle: BillingCycle,
openInNewTab: boolean = true
options: PerformSubscriptionCheckoutOptions = {}
): Promise<void> {
if (!isCloud) return
const { openInNewTab = true, paymentIntentSource } = options
const authStore = useAuthStore()
const { userId } = storeToRefs(authStore)
const telemetry = useTelemetry()
@@ -108,14 +122,29 @@ export async function performSubscriptionCheckout(
const data = await response.json()
if (data.checkout_url) {
const pendingAttempt = createPendingSubscriptionCheckoutAttempt({
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new',
payment_intent_source: paymentIntentSource
})
if (userId.value) {
telemetry?.trackBeginCheckout({
user_id: userId.value,
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new',
...checkoutAttribution
})
telemetry?.trackBeginCheckout(
withPendingCheckoutAttemptId(
{
user_id: userId.value,
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new',
...(paymentIntentSource
? { payment_intent_source: paymentIntentSource }
: {}),
...checkoutAttribution
},
pendingAttempt
)
)
}
if (openInNewTab) {
@@ -123,18 +152,9 @@ export async function performSubscriptionCheckout(
if (!checkoutWindow) {
return
}
recordPendingSubscriptionCheckoutAttempt({
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new'
})
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
} else {
recordPendingSubscriptionCheckoutAttempt({
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new'
})
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
globalThis.location.href = data.checkout_url
}
}

View File

@@ -1,9 +1,13 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { computed, reactive } from 'vue'
const { mockIsCloud, mockSubscribe } = vi.hoisted(() => ({
mockIsCloud: { value: true },
mockSubscribe: vi.fn()
}))
const { mockIsCloud, mockSubscribe, mockTrackBeginCheckout, mockUserId } =
vi.hoisted(() => ({
mockIsCloud: { value: true },
mockSubscribe: vi.fn(),
mockTrackBeginCheckout: vi.fn(),
mockUserId: { value: 'user-1' as string | null }
}))
vi.mock('@/platform/distribution/types', () => ({
get isCloud() {
@@ -16,6 +20,12 @@ vi.mock('@/config/comfyApi', () => ({
vi.mock('@/platform/workspace/api/workspaceApi', () => ({
workspaceApi: { subscribe: mockSubscribe }
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackBeginCheckout: mockTrackBeginCheckout })
}))
vi.mock('@/stores/authStore', () => ({
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
}))
import { performTeamSubscriptionCheckout } from './teamSubscriptionCheckoutUtil'
@@ -43,7 +53,9 @@ describe('performTeamSubscriptionCheckout', () => {
billing_op_id: 'op_1'
})
await performTeamSubscriptionCheckout('team_700', 'yearly')
await performTeamSubscriptionCheckout('team_700', 'yearly', {
paymentIntentSource: 'deep_link'
})
expect(mockSubscribe).toHaveBeenCalledWith('team_per_credit_annual', {
returnUrl: 'https://app.test/payment/success',
@@ -51,6 +63,14 @@ describe('performTeamSubscriptionCheckout', () => {
teamCreditStopId: 'team_700'
})
expect(assignedHref).toBe('https://stripe.test/pay')
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
user_id: 'user-1',
tier: 'team',
cycle: 'yearly',
checkout_type: 'new',
billing_op_id: 'op_1',
payment_intent_source: 'deep_link'
})
})
it('uses the monthly slug and lands in the app when no Stripe step is needed', async () => {
@@ -82,6 +102,16 @@ describe('performTeamSubscriptionCheckout', () => {
expect(assignedHref).toBeUndefined()
})
it('does not track begin_checkout when subscribe fails', async () => {
mockSubscribe.mockRejectedValueOnce(new Error('subscribe failed'))
await expect(
performTeamSubscriptionCheckout('team_700', 'yearly')
).rejects.toThrow('subscribe failed')
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
it('does nothing off cloud', async () => {
mockIsCloud.value = false

View File

@@ -1,10 +1,16 @@
import { getComfyPlatformBaseUrl } from '@/config/comfyApi'
import { getTeamPlanSlug } from '@/platform/cloud/subscription/constants/teamPlanCreditStops'
import { isCloud } from '@/platform/distribution/types'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { workspaceApi } from '@/platform/workspace/api/workspaceApi'
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
import type { BillingCycle } from './subscriptionTierRank'
interface PerformTeamSubscriptionCheckoutOptions {
paymentIntentSource?: PaymentIntentSource
}
/**
* Direct team-plan checkout for the marketing `/cloud/subscribe?tier=team` deep
* link: subscribes to the per-credit Team plan at the chosen slider stop and
@@ -22,7 +28,8 @@ import type { BillingCycle } from './subscriptionTierRank'
*/
export async function performTeamSubscriptionCheckout(
teamCreditStopId: string,
billingCycle: BillingCycle
billingCycle: BillingCycle,
options: PerformTeamSubscriptionCheckoutOptions = {}
): Promise<void> {
if (!isCloud) return
@@ -33,6 +40,14 @@ export async function performTeamSubscriptionCheckout(
teamCreditStopId
})
trackWorkspaceCheckoutStarted({
tier: 'team',
cycle: billingCycle,
checkoutType: 'new',
billingOpId: response.billing_op_id,
paymentIntentSource: options.paymentIntentSource
})
if (response.status === 'needs_payment_method') {
// A needs_payment_method response without a URL is unusable: surface it to
// the caller's error handling rather than silently dropping the user home

View File

@@ -1,7 +1,5 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { DownloadApiError } from '@/platform/modelManager/types'
import {
downloadModel,
fetchModelMetadata,
@@ -9,23 +7,13 @@ import {
toBrowsableUrl
} from './missingModelDownload'
const {
fetchMock,
mockIsDesktop,
mockSidebarTabStore,
mockStartDownload,
mockEnqueue,
mockToastAdd,
mockFlags
} = vi.hoisted(() => ({
fetchMock: vi.fn(),
mockIsDesktop: { value: false },
mockSidebarTabStore: { activeSidebarTabId: null as string | null },
mockStartDownload: vi.fn(),
mockEnqueue: vi.fn(),
mockToastAdd: vi.fn(),
mockFlags: { serverSideModelDownloads: false }
}))
const { fetchMock, mockIsDesktop, mockSidebarTabStore, mockStartDownload } =
vi.hoisted(() => ({
fetchMock: vi.fn(),
mockIsDesktop: { value: false },
mockSidebarTabStore: { activeSidebarTabId: null as string | null },
mockStartDownload: vi.fn()
}))
vi.stubGlobal('fetch', fetchMock)
@@ -45,38 +33,12 @@ vi.mock('@/stores/workspace/sidebarTabStore', () => ({
useSidebarTabStore: () => mockSidebarTabStore
}))
vi.mock('@/composables/useFeatureFlags', () => ({
useFeatureFlags: () => ({ flags: mockFlags })
}))
vi.mock('@/platform/modelManager/stores/modelDownloadStore', () => ({
useModelDownloadStore: () => ({ enqueue: mockEnqueue })
}))
vi.mock('@/platform/updates/common/toastStore', () => ({
useToastStore: () => ({ add: mockToastAdd })
}))
const mockRefreshModelFolder = vi.fn()
const mockRefreshMissingModels = vi.fn()
vi.mock('@/stores/modelStore', () => ({
useModelStore: () => ({ refreshModelFolder: mockRefreshModelFolder })
}))
vi.mock('@/platform/missingModel/missingModelStore', () => ({
useMissingModelStore: () => ({
refreshMissingModels: mockRefreshMissingModels
})
}))
let testId = 0
beforeEach(() => {
vi.restoreAllMocks()
vi.resetAllMocks()
delete window.__comfyDesktop2
mockFlags.serverSideModelDownloads = false
})
describe('fetchModelMetadata', () => {
@@ -155,26 +117,6 @@ describe('fetchModelMetadata', () => {
expect(fetchMock).not.toHaveBeenCalled()
})
it('returns null metadata when the Civitai request throws', async () => {
fetchMock.mockRejectedValueOnce(new Error('network down'))
const metadata = await fetchModelMetadata(
`https://civitai.com/api/download/models/${testId}`
)
expect(metadata).toEqual({ fileSize: null, gatedRepoUrl: null })
})
it('returns null metadata when the HEAD request throws', async () => {
fetchMock.mockRejectedValueOnce(new Error('network down'))
const metadata = await fetchModelMetadata(
`https://huggingface.co/org/model/resolve/main/throws-${testId}.safetensors`
)
expect(metadata).toEqual({ fileSize: null, gatedRepoUrl: null })
})
it('returns cached metadata on second call', async () => {
const url = `https://huggingface.co/org/model/resolve/main/cached-${testId}.safetensors`
@@ -298,16 +240,6 @@ describe('isModelDownloadable', () => {
})
).toBe(false)
})
it('allows explicitly whitelisted URLs from an otherwise disallowed host', () => {
expect(
isModelDownloadable({
name: 'RealESRGAN_x4plus.pth',
url: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
directory: 'upscale_models'
})
).toBe(true)
})
})
describe('downloadModel', () => {
@@ -447,152 +379,6 @@ describe('downloadModel', () => {
expect(anchorClick).toHaveBeenCalledTimes(1)
})
it('uses the browser fallback on web when server-side downloads are disabled', () => {
const anchorClick = vi
.spyOn(HTMLAnchorElement.prototype, 'click')
.mockImplementation(() => {})
downloadModel(
{
name: 'model.safetensors',
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
directory: 'checkpoints'
},
{ checkpoints: ['/models/checkpoints'] }
)
expect(anchorClick).toHaveBeenCalledTimes(1)
expect(mockEnqueue).not.toHaveBeenCalled()
})
it('enqueues a server-side download and reveals the manager when enabled', async () => {
mockFlags.serverSideModelDownloads = true
mockEnqueue.mockResolvedValue({ download_id: 'd1', accepted: true })
const anchorClick = vi
.spyOn(HTMLAnchorElement.prototype, 'click')
.mockImplementation(() => {})
downloadModel(
{
name: 'model.safetensors',
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
directory: 'checkpoints'
},
{ checkpoints: ['/models/checkpoints'] }
)
await vi.waitFor(() => {
expect(mockSidebarTabStore.activeSidebarTabId).toBe('model-manager')
})
expect(mockEnqueue).toHaveBeenCalledWith({
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
model_id: 'checkpoints/model.safetensors'
})
expect(anchorClick).not.toHaveBeenCalled()
})
it('shows a toast when a server-side enqueue fails', async () => {
mockFlags.serverSideModelDownloads = true
mockEnqueue.mockRejectedValue(new Error('boom'))
downloadModel(
{
name: 'model.safetensors',
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
directory: 'checkpoints'
},
{ checkpoints: ['/models/checkpoints'] }
)
await vi.waitFor(() => {
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'error', detail: 'boom' })
)
})
expect(mockSidebarTabStore.activeSidebarTabId).toBeNull()
})
it('reveals the download manager and shows an info toast for an in-progress download', async () => {
mockFlags.serverSideModelDownloads = true
mockEnqueue.mockRejectedValue(
new DownloadApiError('exists', 'ALREADY_DOWNLOADING', 409)
)
downloadModel(
{
name: 'model.safetensors',
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
directory: 'checkpoints'
},
{ checkpoints: ['/models/checkpoints'] }
)
await vi.waitFor(() => {
expect(mockSidebarTabStore.activeSidebarTabId).toBe('model-manager')
})
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({
severity: 'info',
detail: 'model.safetensors'
})
)
})
it('refreshes the model folder and re-scans missing models when already available', async () => {
mockFlags.serverSideModelDownloads = true
mockEnqueue.mockRejectedValue(
new DownloadApiError('already there', 'ALREADY_AVAILABLE', 409)
)
mockRefreshModelFolder.mockResolvedValue(undefined)
downloadModel(
{
name: 'model.safetensors',
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
directory: 'checkpoints'
},
{ checkpoints: ['/models/checkpoints'] }
)
await vi.waitFor(() => {
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({
severity: 'info',
detail: 'model.safetensors'
})
)
})
await vi.waitFor(() => {
expect(mockRefreshModelFolder).toHaveBeenCalledWith('checkpoints')
})
expect(mockRefreshMissingModels).toHaveBeenCalled()
expect(mockSidebarTabStore.activeSidebarTabId).toBeNull()
})
it('still re-scans missing models when the post-available folder refresh fails', async () => {
mockFlags.serverSideModelDownloads = true
mockEnqueue.mockRejectedValue(
new DownloadApiError('already there', 'ALREADY_AVAILABLE', 409)
)
const consoleWarn = vi.spyOn(console, 'warn').mockImplementation(() => {})
mockRefreshModelFolder.mockRejectedValue(new Error('boom'))
downloadModel(
{
name: 'model.safetensors',
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
directory: 'checkpoints'
},
{ checkpoints: ['/models/checkpoints'] }
)
await vi.waitFor(() => {
expect(mockRefreshMissingModels).toHaveBeenCalled()
})
expect(consoleWarn).toHaveBeenCalled()
consoleWarn.mockRestore()
})
it('opens the model library sidebar before starting a desktop download', () => {
mockIsDesktop.value = true

View File

@@ -1,11 +1,5 @@
import { downloadUrlToHfRepoUrl, isCivitaiModelUrl } from '@/utils/formatUtil'
import { useFeatureFlags } from '@/composables/useFeatureFlags'
import { t } from '@/i18n'
import { isDesktop } from '@/platform/distribution/types'
import { useModelDownloadStore } from '@/platform/modelManager/stores/modelDownloadStore'
import { DownloadApiError } from '@/platform/modelManager/types'
import { buildModelId } from '@/platform/modelManager/utils/modelId'
import { useToastStore } from '@/platform/updates/common/toastStore'
import { useElectronDownloadStore } from '@/stores/electronDownloadStore'
import { useSidebarTabStore } from '@/stores/workspace/sidebarTabStore'
import type { ComfyDesktop2Bridge } from '@/types'
@@ -35,7 +29,6 @@ const WHITE_LISTED_URLS: ReadonlySet<string> = new Set([
])
const MODEL_LIBRARY_TAB_ID = 'model-library'
const MODEL_MANAGER_TAB_ID = 'model-manager'
export interface ModelWithUrl {
name: string
@@ -54,96 +47,6 @@ async function startDesktop2ModelDownload(
}
}
function revealDownloadManager(): void {
useSidebarTabStore().activeSidebarTabId = MODEL_MANAGER_TAB_ID
}
/**
* Already on disk: surface a confirmation, refresh the model folder, and
* re-scan missing models so any node error for this model clears. Loaded
* lazily to keep this module's import graph (and its unit tests) light.
*/
async function refreshAfterModelAvailable(model: ModelWithUrl): Promise<void> {
try {
const [{ useModelStore }, { useMissingModelStore }] = await Promise.all([
import('@/stores/modelStore'),
import('@/platform/missingModel/missingModelStore')
])
if (model.directory) {
try {
await useModelStore().refreshModelFolder(model.directory)
} catch (error) {
console.warn(
'[MissingModel] Failed to refresh model folder after model available',
error
)
}
}
void useMissingModelStore().refreshMissingModels()
} catch (error) {
console.warn(
'[MissingModel] Failed to refresh after model available',
error
)
}
}
/**
* Enqueues a server-side download and reveals the Model Manager panel so the
* user can watch live progress, status, and completion. The two benign `409`
* cases get an info toast: `ALREADY_DOWNLOADING` links to the existing job,
* `ALREADY_AVAILABLE` confirms it's installed and clears the node error. Any
* other failure is reported via an error toast.
*/
async function startServerSideModelDownload(
model: ModelWithUrl
): Promise<void> {
const toast = useToastStore()
try {
await useModelDownloadStore().enqueue({
url: model.url,
model_id: buildModelId(model.directory, model.name)
})
revealDownloadManager()
} catch (error: unknown) {
if (error instanceof DownloadApiError && error.is('ALREADY_DOWNLOADING')) {
revealDownloadManager()
toast.add({
severity: 'info',
summary: t('modelManager.alreadyDownloading'),
detail: model.name,
life: 4000
})
return
}
if (error instanceof DownloadApiError && error.is('ALREADY_AVAILABLE')) {
toast.add({
severity: 'info',
summary: t('modelManager.alreadyInstalled'),
detail: model.name,
life: 4000
})
void refreshAfterModelAvailable(model)
return
}
toast.add({
severity: 'error',
summary: t('modelManager.actionFailed'),
detail: error instanceof Error ? error.message : String(error),
life: 5000
})
}
}
function startBrowserModelDownload(model: ModelWithUrl): void {
const link = document.createElement('a')
link.href = model.url
link.download = model.name
link.target = '_blank'
link.rel = 'noopener noreferrer'
link.click()
}
/**
* Converts a model download URL to a browsable page URL.
* - HuggingFace: `/resolve/` → `/blob/` (file page with model info)
@@ -179,11 +82,12 @@ export function downloadModel(
}
if (!isDesktop) {
if (useFeatureFlags().flags.serverSideModelDownloads) {
void startServerSideModelDownload(model)
} else {
startBrowserModelDownload(model)
}
const link = document.createElement('a')
link.href = model.url
link.download = model.name
link.target = '_blank'
link.rel = 'noopener noreferrer'
link.click()
return
}

View File

@@ -1,240 +0,0 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { api } from '@/scripts/api'
import { DownloadApiError } from '../types'
import {
cancelDownload,
checkAvailability,
clearDownloads,
deleteCredential,
deleteDownload,
enqueueDownload,
listCredentials,
listDownloads,
pauseDownload,
resumeDownload,
setDownloadPriority,
upsertCredential
} from './modelDownloadApi'
vi.mock('@/scripts/api', () => ({
api: {
fetchApi: vi.fn()
}
}))
const fetchApi = vi.mocked(api.fetchApi)
function jsonResponse(status: number, body: unknown): Response {
return {
ok: status >= 200 && status < 300,
status,
statusText: `status ${status}`,
json: () => Promise.resolve(body)
} as unknown as Response
}
function nonJsonErrorResponse(status: number, statusText: string): Response {
return {
ok: false,
status,
statusText,
json: () => Promise.reject(new Error('not json'))
} as unknown as Response
}
describe('modelDownloadApi', () => {
beforeEach(() => {
vi.resetAllMocks()
})
describe('enqueueDownload', () => {
it('returns the enqueue response on 202', async () => {
fetchApi.mockResolvedValue(
jsonResponse(202, { download_id: 'd1', accepted: true })
)
const result = await enqueueDownload({
url: 'https://huggingface.co/x.safetensors',
model_id: 'loras/x.safetensors'
})
expect(result).toEqual({ download_id: 'd1', accepted: true })
expect(fetchApi).toHaveBeenCalledWith(
'/download/enqueue',
expect.objectContaining({ method: 'POST' })
)
})
it('throws a DownloadApiError carrying the error code on 409', async () => {
fetchApi.mockResolvedValue(
jsonResponse(409, {
error: { code: 'ALREADY_DOWNLOADING', message: 'exists' }
})
)
await expect(
enqueueDownload({ url: 'u', model_id: 'loras/x.safetensors' })
).rejects.toMatchObject({
code: 'ALREADY_DOWNLOADING',
status: 409,
message: 'exists'
})
})
it('falls back to statusText and an UNKNOWN code for a non-JSON error body', async () => {
fetchApi.mockResolvedValue(nonJsonErrorResponse(502, 'Bad Gateway'))
await expect(
enqueueDownload({ url: 'u', model_id: 'loras/x.safetensors' })
).rejects.toMatchObject({
message: 'Bad Gateway',
code: 'UNKNOWN',
status: 502
})
})
it('exposes the code through DownloadApiError.is()', async () => {
fetchApi.mockResolvedValue(
jsonResponse(400, {
error: { code: 'URL_NOT_ALLOWED', message: 'nope' }
})
)
const error = await enqueueDownload({
url: 'u',
model_id: 'loras/x.safetensors'
}).catch((e) => e)
expect(error).toBeInstanceOf(DownloadApiError)
expect(error.is('URL_NOT_ALLOWED')).toBe(true)
})
})
describe('listDownloads', () => {
it('unwraps the downloads array', async () => {
fetchApi.mockResolvedValue(
jsonResponse(200, { downloads: [{ download_id: 'd1' }] })
)
const result = await listDownloads()
expect(result).toEqual([{ download_id: 'd1' }])
})
})
describe('actions', () => {
it('posts to the pause route', async () => {
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
await pauseDownload('d1')
expect(fetchApi).toHaveBeenCalledWith(
'/download/d1/pause',
expect.objectContaining({ method: 'POST' })
)
})
it('posts to the resume route', async () => {
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
await resumeDownload('d1')
expect(fetchApi).toHaveBeenCalledWith(
'/download/d1/resume',
expect.objectContaining({ method: 'POST' })
)
})
it('posts to the cancel route', async () => {
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
await cancelDownload('d1')
expect(fetchApi).toHaveBeenCalledWith(
'/download/d1/cancel',
expect.objectContaining({ method: 'POST' })
)
})
it('sends the priority in the body', async () => {
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
await setDownloadPriority('d1', 5)
expect(fetchApi).toHaveBeenCalledWith(
'/download/d1/priority',
expect.objectContaining({ body: JSON.stringify({ priority: 5 }) })
)
})
it('sends a DELETE to the download route', async () => {
fetchApi.mockResolvedValue(jsonResponse(200, { deleted: true }))
await deleteDownload('d1')
expect(fetchApi).toHaveBeenCalledWith(
'/download/d1',
expect.objectContaining({ method: 'DELETE' })
)
})
it('posts to the clear route and returns the deleted count', async () => {
fetchApi.mockResolvedValue(jsonResponse(200, { deleted: 3 }))
const result = await clearDownloads()
expect(result).toBe(3)
expect(fetchApi).toHaveBeenCalledWith(
'/download/clear',
expect.objectContaining({ method: 'POST' })
)
})
})
describe('checkAvailability', () => {
it('wraps the models map in the request body', async () => {
fetchApi.mockResolvedValue(jsonResponse(200, { models: {} }))
await checkAvailability({ 'loras/x.safetensors': 'https://h.co/x' })
expect(fetchApi).toHaveBeenCalledWith(
'/download/availability',
expect.objectContaining({
body: JSON.stringify({
models: { 'loras/x.safetensors': 'https://h.co/x' }
})
})
)
})
})
describe('credentials', () => {
it('unwraps the credentials list', async () => {
fetchApi.mockResolvedValue(
jsonResponse(200, { credentials: [{ id: 'c1' }] })
)
expect(await listCredentials()).toEqual([{ id: 'c1' }])
})
it('returns the created credential view', async () => {
fetchApi.mockResolvedValue(jsonResponse(201, { id: 'c1', host: 'h' }))
const result = await upsertCredential({ host: 'h', secret: 's' })
expect(result).toEqual({ id: 'c1', host: 'h' })
})
it('throws on a failed delete', async () => {
fetchApi.mockResolvedValue(
jsonResponse(404, { error: { code: 'NOT_FOUND', message: 'gone' } })
)
await expect(deleteCredential('c1')).rejects.toMatchObject({
code: 'NOT_FOUND'
})
})
})
})

View File

@@ -1,121 +0,0 @@
import { api } from '@/scripts/api'
import type {
AvailabilityResponse,
DownloadStatus,
EnqueueRequest,
EnqueueResponse,
HostCredentialUpsert,
HostCredentialView
} from '../types'
import { DownloadApiError } from '../types'
const BASE = '/download'
interface ErrorEnvelope {
error?: {
code?: string
message?: string
details?: Record<string, unknown>
}
}
async function throwFromResponse(response: Response): Promise<never> {
let body: ErrorEnvelope = {}
try {
body = await response.json()
} catch {
// Non-JSON error body
}
const error = body.error
throw new DownloadApiError(
error?.message ?? response.statusText,
error?.code ?? 'UNKNOWN',
response.status,
error?.details
)
}
async function parseJson<T>(response: Response): Promise<T> {
if (!response.ok) return throwFromResponse(response)
return response.json() as Promise<T>
}
function postJson(route: string, body: unknown): Promise<Response> {
return api.fetchApi(route, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(body)
})
}
export async function enqueueDownload(
body: EnqueueRequest
): Promise<EnqueueResponse> {
const response = await postJson(`${BASE}/enqueue`, body)
if (response.status === 202) {
return response.json() as Promise<EnqueueResponse>
}
return throwFromResponse(response)
}
export async function listDownloads(): Promise<DownloadStatus[]> {
const response = await api.fetchApi(BASE)
const data = await parseJson<{ downloads: DownloadStatus[] }>(response)
return data.downloads
}
async function postAction(id: string, action: string): Promise<void> {
const response = await postJson(`${BASE}/${id}/${action}`, undefined)
await parseJson<{ ok: boolean }>(response)
}
export const pauseDownload = (id: string) => postAction(id, 'pause')
export const resumeDownload = (id: string) => postAction(id, 'resume')
export const cancelDownload = (id: string) => postAction(id, 'cancel')
export async function deleteDownload(id: string): Promise<void> {
const response = await api.fetchApi(`${BASE}/${id}`, { method: 'DELETE' })
await parseJson<{ deleted: boolean }>(response)
}
export async function clearDownloads(): Promise<number> {
const response = await postJson(`${BASE}/clear`, undefined)
const data = await parseJson<{ deleted: number }>(response)
return data.deleted
}
export async function setDownloadPriority(
id: string,
priority: number
): Promise<void> {
const response = await postJson(`${BASE}/${id}/priority`, { priority })
await parseJson<{ ok: boolean }>(response)
}
export async function checkAvailability(
models: Record<string, string>
): Promise<AvailabilityResponse> {
const response = await postJson(`${BASE}/availability`, { models })
return parseJson<AvailabilityResponse>(response)
}
export async function listCredentials(): Promise<HostCredentialView[]> {
const response = await api.fetchApi(`${BASE}/credentials`)
const data = await parseJson<{ credentials: HostCredentialView[] }>(response)
return data.credentials
}
export async function upsertCredential(
body: HostCredentialUpsert
): Promise<HostCredentialView> {
const response = await postJson(`${BASE}/credentials`, body)
return parseJson<HostCredentialView>(response)
}
export async function deleteCredential(id: string): Promise<void> {
const response = await api.fetchApi(`${BASE}/credentials/${id}`, {
method: 'DELETE'
})
await parseJson<{ deleted: boolean }>(response)
}

View File

@@ -1,253 +0,0 @@
import { render, screen } from '@testing-library/vue'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { defineComponent, ref } from 'vue'
import { createI18n } from 'vue-i18n'
import enMessages from '@/locales/en/main.json' with { type: 'json' }
import { DownloadApiError } from '../types'
import AddModelByUrlDialog from './AddModelByUrlDialog.vue'
const mockEnqueue = vi.fn()
const mockToastAdd = vi.fn()
const mockFetchModelTypes = vi.fn()
const mockModelTypes = ref([
{ name: 'Checkpoints', value: 'checkpoints' },
{ name: 'LoRA', value: 'loras' }
])
vi.mock('../stores/modelDownloadStore', () => ({
useModelDownloadStore: () => ({ enqueue: mockEnqueue })
}))
vi.mock('@/platform/assets/composables/useModelTypes', () => ({
useModelTypes: () => ({
modelTypes: mockModelTypes,
isLoading: ref(false),
fetchModelTypes: mockFetchModelTypes
})
}))
vi.mock('@/platform/updates/common/toastStore', () => ({
useToastStore: () => ({ add: mockToastAdd })
}))
const i18n = createI18n({
legacy: false,
locale: 'en',
messages: { en: enMessages },
missingWarn: false,
fallbackWarn: false
})
const stubs = {
Dialog: { template: '<div><slot /></div>' },
DialogPortal: { template: '<div><slot /></div>' },
DialogOverlay: { template: '<div />' },
DialogContent: defineComponent({
emits: ['open-auto-focus'],
mounted() {
this.$emit('open-auto-focus')
},
template: '<div><slot /></div>'
}),
DialogHeader: { template: '<div><slot /></div>' },
DialogTitle: { template: '<div><slot /></div>' },
DialogDescription: { template: '<div><slot /></div>' },
DialogFooter: { template: '<div><slot /></div>' },
SingleSelect: {
props: ['modelValue', 'options', 'label', 'loading'],
emits: ['update:modelValue'],
template: `
<select
data-testid="folder-select"
:value="modelValue ?? ''"
@change="$emit('update:modelValue', $event.target.value)"
>
<option value="" disabled>{{ label }}</option>
<option v-for="opt in options" :key="opt.value" :value="opt.value">
{{ opt.name }}
</option>
</select>
`
}
}
function mountDialog(open = true) {
return render(AddModelByUrlDialog, {
props: { open },
global: { plugins: [i18n], stubs }
})
}
async function fillValidForm() {
await userEvent.type(
screen.getByLabelText('URL'),
'https://huggingface.co/org/model/resolve/main/model.safetensors'
)
await userEvent.selectOptions(
screen.getByTestId('folder-select'),
'checkpoints'
)
}
describe('AddModelByUrlDialog', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('auto-fills the filename from the url until edited manually', async () => {
mountDialog()
const urlInput = screen.getByLabelText('URL')
const filenameInput = screen.getByLabelText('Filename')
await userEvent.type(
urlInput,
'https://huggingface.co/org/model/resolve/main/model.safetensors'
)
expect(filenameInput).toHaveValue('model.safetensors')
await userEvent.clear(filenameInput)
await userEvent.type(filenameInput, 'custom.safetensors')
await userEvent.type(urlInput, '?download=true')
expect(filenameInput).toHaveValue('custom.safetensors')
})
it('shows a hint for a url on a non-allowlisted host', async () => {
mountDialog()
await userEvent.type(
screen.getByLabelText('URL'),
'https://example.com/model.safetensors'
)
expect(
screen.getByText('This host may not be on the download allowlist.')
).toBeInTheDocument()
})
it('disables submit until url, folder, and a valid filename are set', async () => {
mountDialog()
const submit = screen.getByRole('button', { name: 'Download' })
expect(submit).toBeDisabled()
await userEvent.type(
screen.getByLabelText('URL'),
'https://huggingface.co/org/model/resolve/main/model.safetensors'
)
expect(submit).toBeDisabled()
await userEvent.selectOptions(
screen.getByTestId('folder-select'),
'checkpoints'
)
expect(submit).toBeEnabled()
})
it('requires a known model extension', async () => {
mountDialog()
await userEvent.type(
screen.getByLabelText('URL'),
'https://huggingface.co/org/model/resolve/main/readme.txt'
)
await userEvent.selectOptions(
screen.getByTestId('folder-select'),
'checkpoints'
)
const submit = screen.getByRole('button', { name: 'Download' })
expect(submit).toBeDisabled()
})
it('closes without submitting when cancel is clicked', async () => {
const { emitted } = mountDialog()
await fillValidForm()
await userEvent.click(screen.getByRole('button', { name: 'Cancel' }))
expect(mockEnqueue).not.toHaveBeenCalled()
expect(emitted('update:open')?.at(-1)).toEqual([false])
})
it('enqueues the download and closes on success', async () => {
mockEnqueue.mockResolvedValue({ download_id: 'd1', accepted: true })
const { emitted } = mountDialog()
await fillValidForm()
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
expect(mockEnqueue).toHaveBeenCalledWith({
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
model_id: 'checkpoints/model.safetensors',
allow_any_extension: false
})
await vi.waitFor(() => {
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'success' })
)
})
expect(emitted('update:open')?.at(-1)).toEqual([false])
})
it('shows an info toast and closes when the model is already downloading', async () => {
mockEnqueue.mockRejectedValue(
new DownloadApiError('exists', 'ALREADY_DOWNLOADING', 409)
)
const { emitted } = mountDialog()
await fillValidForm()
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
await vi.waitFor(() => {
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'info' })
)
})
expect(emitted('update:open')?.at(-1)).toEqual([false])
})
it('shows an info toast and closes when the model is already installed', async () => {
mockEnqueue.mockRejectedValue(
new DownloadApiError('already there', 'ALREADY_AVAILABLE', 409)
)
const { emitted } = mountDialog()
await fillValidForm()
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
await vi.waitFor(() => {
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'info' })
)
})
expect(emitted('update:open')?.at(-1)).toEqual([false])
})
it('shows an inline error for other API failures and stays open', async () => {
mockEnqueue.mockRejectedValue(
new DownloadApiError('url not allowed', 'URL_NOT_ALLOWED', 400)
)
const { emitted } = mountDialog()
await fillValidForm()
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
await vi.waitFor(() => {
expect(screen.getByText('url not allowed')).toBeInTheDocument()
})
expect(emitted('update:open')).toBeUndefined()
})
it('shows a generic error message for a non-api error', async () => {
mockEnqueue.mockRejectedValue(new Error('network down'))
mountDialog()
await fillValidForm()
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
await vi.waitFor(() => {
expect(screen.getByText('network down')).toBeInTheDocument()
})
})
})

View File

@@ -1,237 +0,0 @@
<template>
<Dialog v-model:open="isOpen">
<DialogPortal>
<DialogOverlay class="bg-black/70" />
<DialogContent
size="md"
class="flex flex-col gap-4 p-6"
@open-auto-focus="onOpen"
>
<DialogHeader>
<DialogTitle>{{ $t('modelManager.addModel') }}</DialogTitle>
<DialogDescription>
{{ $t('modelManager.addModelDescription') }}
</DialogDescription>
</DialogHeader>
<div class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground" for="download-url">
{{ $t('modelManager.url') }}
</label>
<Input
id="download-url"
v-model="url"
:placeholder="$t('modelManager.urlPlaceholder')"
@update:model-value="onUrlChanged"
/>
<p v-if="hostHint" class="text-xs text-amber-400">{{ hostHint }}</p>
</div>
<div class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground">
{{ $t('modelManager.folder') }}
</label>
<SingleSelect
v-model="directory"
:options="folderOptions"
:label="$t('modelManager.selectFolder')"
:loading="isLoadingFolders"
/>
</div>
<div class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground" for="download-filename">
{{ $t('modelManager.filename') }}
</label>
<Input
id="download-filename"
v-model="filename"
:placeholder="$t('modelManager.filenamePlaceholder')"
@update:model-value="onFilenameEdited"
/>
</div>
<!-- TODO: re-enable once we think we'd want to allow any extension
<label class="flex w-fit items-center gap-2 text-xs text-muted-foreground">
<input v-model="allowAnyExtension" type="checkbox" class="size-4" />
{{ $t('modelManager.allowAnyExtension') }}
</label>
-->
<p
v-if="modelId"
class="truncate rounded-md bg-secondary-background px-2 py-1 text-xs text-muted-foreground"
>
{{ modelId }}
</p>
<p v-if="errorMessage" class="text-xs text-red-400">
{{ errorMessage }}
</p>
<DialogFooter>
<Button variant="secondary" @click="isOpen = false">
{{ $t('g.cancel') }}
</Button>
<Button
variant="primary"
:disabled="!canSubmit"
:loading="isSubmitting"
@click="submit"
>
{{ $t('modelManager.download') }}
</Button>
</DialogFooter>
</DialogContent>
</DialogPortal>
</Dialog>
</template>
<script setup lang="ts">
import { computed, ref } from 'vue'
import { useI18n } from 'vue-i18n'
import Button from '@/components/ui/button/Button.vue'
import Dialog from '@/components/ui/dialog/Dialog.vue'
import DialogContent from '@/components/ui/dialog/DialogContent.vue'
import DialogDescription from '@/components/ui/dialog/DialogDescription.vue'
import DialogFooter from '@/components/ui/dialog/DialogFooter.vue'
import DialogHeader from '@/components/ui/dialog/DialogHeader.vue'
import DialogOverlay from '@/components/ui/dialog/DialogOverlay.vue'
import DialogPortal from '@/components/ui/dialog/DialogPortal.vue'
import DialogTitle from '@/components/ui/dialog/DialogTitle.vue'
import Input from '@/components/ui/input/Input.vue'
import SingleSelect from '@/components/ui/single-select/SingleSelect.vue'
import { useModelTypes } from '@/platform/assets/composables/useModelTypes'
import { useToastStore } from '@/platform/updates/common/toastStore'
import { useModelDownloadStore } from '../stores/modelDownloadStore'
import { DownloadApiError } from '../types'
import {
buildModelId,
filenameFromUrl,
hasModelExtension,
isLikelyAllowedHost,
isValidPathSegment
} from '../utils/modelId'
const isOpen = defineModel<boolean>('open', { required: true })
const { t } = useI18n()
const store = useModelDownloadStore()
const {
modelTypes,
isLoading: isLoadingFolders,
fetchModelTypes
} = useModelTypes()
const url = ref('')
const directory = ref<string | undefined>(undefined)
const filename = ref('')
const isFilenameUserEdited = ref(false)
const allowAnyExtension = ref(false)
const isSubmitting = ref(false)
const errorMessage = ref('')
const folderOptions = computed(() => modelTypes.value)
const modelId = computed(() =>
directory.value && filename.value
? buildModelId(directory.value, filename.value)
: ''
)
const hostHint = computed(() =>
url.value && !isLikelyAllowedHost(url.value)
? t('modelManager.hostNotAllowedHint')
: ''
)
const canSubmit = computed(
() =>
!!url.value &&
!!directory.value &&
isValidPathSegment(filename.value) &&
(allowAnyExtension.value || hasModelExtension(filename.value))
)
function onOpen() {
void fetchModelTypes()
errorMessage.value = ''
}
function onUrlChanged() {
if (!isFilenameUserEdited.value) {
filename.value = filenameFromUrl(url.value)
}
}
function onFilenameEdited() {
isFilenameUserEdited.value = true
}
function reset() {
url.value = ''
directory.value = undefined
filename.value = ''
isFilenameUserEdited.value = false
allowAnyExtension.value = false
errorMessage.value = ''
}
async function submit() {
if (!canSubmit.value) return
isSubmitting.value = true
errorMessage.value = ''
try {
await store.enqueue({
url: url.value,
model_id: modelId.value,
allow_any_extension: allowAnyExtension.value
})
useToastStore().add({
severity: 'success',
summary: t('modelManager.downloadQueued'),
detail: modelId.value,
life: 4000
})
reset()
isOpen.value = false
} catch (error) {
handleEnqueueError(error)
} finally {
isSubmitting.value = false
}
}
function handleEnqueueError(error: unknown) {
if (error instanceof DownloadApiError) {
if (error.is('ALREADY_AVAILABLE')) {
useToastStore().add({
severity: 'info',
summary: t('modelManager.alreadyInstalled'),
detail: modelId.value,
life: 4000
})
reset()
isOpen.value = false
return
}
if (error.is('ALREADY_DOWNLOADING')) {
useToastStore().add({
severity: 'info',
summary: t('modelManager.alreadyDownloading'),
detail: modelId.value,
life: 4000
})
reset()
isOpen.value = false
return
}
errorMessage.value = error.message
return
}
errorMessage.value =
error instanceof Error ? error.message : t('modelManager.actionFailed')
}
</script>

View File

@@ -1,313 +0,0 @@
import { render, screen } from '@testing-library/vue'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { defineComponent, nextTick, reactive, ref } from 'vue'
import { createI18n } from 'vue-i18n'
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
import enMessages from '@/locales/en/main.json' with { type: 'json' }
import type { HostCredentialView } from '../types'
import HostCredentialsDialog from './HostCredentialsDialog.vue'
const mockCloseDialog = vi.fn()
const mockToastAdd = vi.fn()
const mockCredentialsStore = reactive({
credentials: ref<HostCredentialView[]>([]),
isLoading: ref(false),
fetchCredentials: vi.fn(),
upsert: vi.fn(),
remove: vi.fn()
})
vi.mock('../stores/hostCredentialsStore', () => ({
useHostCredentialsStore: () => mockCredentialsStore
}))
vi.mock('@/stores/dialogStore', () => ({
useDialogStore: () => ({ closeDialog: mockCloseDialog })
}))
vi.mock('@/platform/updates/common/toastStore', () => ({
useToastStore: () => ({ add: mockToastAdd })
}))
vi.mock('@/components/dialog/confirm/confirmDialog')
const mockShowConfirmDialog = vi.mocked(showConfirmDialog)
interface CapturedConfirmOptions {
footerProps?: {
onConfirm?: () => void | Promise<void>
onCancel?: () => void
}
}
function capturedOptions(): CapturedConfirmOptions {
return mockShowConfirmDialog.mock
.calls[0][0] as unknown as CapturedConfirmOptions
}
const i18n = createI18n({
legacy: false,
locale: 'en',
messages: { en: enMessages },
missingWarn: false,
fallbackWarn: false
})
const stubs = {
Dialog: { template: '<div><slot /></div>' },
DialogPortal: { template: '<div><slot /></div>' },
DialogOverlay: { template: '<div />' },
DialogContent: defineComponent({
emits: ['open-auto-focus'],
mounted() {
this.$emit('open-auto-focus')
},
template: '<div><slot /></div>'
}),
DialogHeader: { template: '<div><slot /></div>' },
DialogTitle: { template: '<div><slot /></div>' },
DialogDescription: { template: '<div><slot /></div>' },
SingleSelect: {
props: ['modelValue', 'options', 'label'],
emits: ['update:modelValue'],
template: `
<select
data-testid="scheme-select"
:value="modelValue ?? ''"
@change="$emit('update:modelValue', $event.target.value)"
>
<option v-for="opt in options" :key="opt.value" :value="opt.value">
{{ opt.name }}
</option>
</select>
`
}
}
function createCredential(
overrides: Partial<HostCredentialView> = {}
): HostCredentialView {
return {
id: 'c1',
host: 'huggingface.co',
auth_scheme: 'bearer',
header_name: null,
query_param: null,
label: null,
match_subdomains: false,
enabled: true,
secret_last4: '1234',
created_at: 0,
updated_at: 0,
...overrides
}
}
function mountDialog(props: { open?: boolean; prefillHost?: string } = {}) {
return render(HostCredentialsDialog, {
props: { open: true, ...props },
global: { plugins: [i18n], stubs }
})
}
describe('HostCredentialsDialog', () => {
beforeEach(() => {
vi.clearAllMocks()
mockCredentialsStore.credentials = []
mockCredentialsStore.fetchCredentials.mockResolvedValue(undefined)
mockShowConfirmDialog.mockReturnValue(
{} as ReturnType<typeof showConfirmDialog>
)
})
it('renders existing credentials with host, scheme, and last 4 of the secret', () => {
mockCredentialsStore.credentials = [
createCredential({
host: 'huggingface.co',
auth_scheme: 'bearer',
secret_last4: '9f21'
})
]
mountDialog()
expect(
screen.getByText('huggingface.co · Bearer token · ••••9f21')
).toBeInTheDocument()
})
it('shows a disabled marker for a disabled credential', () => {
mockCredentialsStore.credentials = [createCredential({ enabled: false })]
mountDialog()
expect(screen.getByText(/Disabled/)).toBeInTheDocument()
})
it('prefills the host from the prefillHost prop on open', async () => {
mountDialog({ prefillHost: 'civitai.com' })
await nextTick()
expect(screen.getByLabelText('Host')).toHaveValue('civitai.com')
})
it('populates the form when editing an existing credential', async () => {
mockCredentialsStore.credentials = [
createCredential({
id: 'c1',
host: 'civitai.com',
auth_scheme: 'header',
header_name: 'X-Api-Key',
label: 'My Civitai key'
})
]
mountDialog()
await userEvent.click(screen.getByTitle('Edit'))
expect(screen.getByLabelText('Host')).toHaveValue('civitai.com')
expect(screen.getByLabelText('Label')).toHaveValue('My Civitai key')
expect(screen.getByText('Update credential')).toBeInTheDocument()
})
it('disables submit until host and secret are filled', async () => {
mountDialog()
const submit = screen.getByRole('button', { name: 'Save' })
expect(submit).toBeDisabled()
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
expect(submit).toBeDisabled()
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
expect(submit).toBeEnabled()
})
it('requires a query parameter name when the scheme is query', async () => {
mountDialog()
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
await userEvent.selectOptions(screen.getByTestId('scheme-select'), 'query')
const submit = screen.getByRole('button', { name: 'Save' })
expect(submit).toBeDisabled()
await userEvent.type(screen.getByLabelText('Query parameter'), 'token')
expect(submit).toBeEnabled()
})
it('shows the header name field and a subdomain warning for the header scheme', async () => {
mountDialog()
await userEvent.selectOptions(screen.getByTestId('scheme-select'), 'header')
expect(screen.getByLabelText('Header name')).toBeInTheDocument()
await userEvent.click(screen.getByLabelText('Match subdomains'))
expect(
screen.getByText(
'Not recommended: hubs redirect to sibling CDN hosts that must not receive your key.'
)
).toBeInTheDocument()
})
it('cancels an in-progress edit and resets the form', async () => {
mockCredentialsStore.credentials = [
createCredential({ id: 'c1', host: 'civitai.com' })
]
mountDialog()
await userEvent.click(screen.getByTitle('Edit'))
expect(screen.getByText('Update credential')).toBeInTheDocument()
await userEvent.click(screen.getByRole('button', { name: 'Cancel' }))
expect(screen.getByText('Add credential')).toBeInTheDocument()
expect(screen.getByLabelText('Host')).toHaveValue('')
})
it('submits the form payload and resets on success', async () => {
mockCredentialsStore.upsert.mockResolvedValue(createCredential())
mountDialog()
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
await userEvent.click(screen.getByRole('button', { name: 'Save' }))
expect(mockCredentialsStore.upsert).toHaveBeenCalledWith({
host: 'huggingface.co',
secret: 's3cret',
auth_scheme: 'bearer',
header_name: null,
query_param: null,
label: null,
match_subdomains: false
})
await vi.waitFor(() => {
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'success' })
)
})
expect(screen.getByLabelText('Host')).toHaveValue('')
})
it('shows an inline error message when saving fails', async () => {
mockCredentialsStore.upsert.mockRejectedValue(new Error('save failed'))
mountDialog()
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
await userEvent.click(screen.getByRole('button', { name: 'Save' }))
await vi.waitFor(() => {
expect(screen.getByText('save failed')).toBeInTheDocument()
})
})
it('shows an inline error when the initial credentials fetch fails', async () => {
mockCredentialsStore.fetchCredentials.mockRejectedValue(
new Error('offline')
)
mountDialog()
await vi.waitFor(() => {
expect(screen.getByText('offline')).toBeInTheDocument()
})
})
it('deletes a credential through a confirm dialog', async () => {
mockCredentialsStore.remove.mockResolvedValue(undefined)
mockCredentialsStore.credentials = [createCredential({ id: 'c1' })]
mountDialog()
await userEvent.click(screen.getByTitle('Delete'))
await capturedOptions().footerProps?.onConfirm?.()
expect(mockCredentialsStore.remove).toHaveBeenCalledWith('c1')
expect(mockCloseDialog).toHaveBeenCalled()
})
it('shows an error toast when deletion fails', async () => {
mockCredentialsStore.remove.mockRejectedValue(new Error('boom'))
mockCredentialsStore.credentials = [createCredential({ id: 'c1' })]
mountDialog()
await userEvent.click(screen.getByTitle('Delete'))
await capturedOptions().footerProps?.onConfirm?.()
await vi.waitFor(() => {
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'error' })
)
})
})
it('does not delete when the confirm dialog is dismissed', async () => {
mockCredentialsStore.credentials = [createCredential({ id: 'c1' })]
mountDialog()
await userEvent.click(screen.getByTitle('Delete'))
capturedOptions().footerProps?.onCancel?.()
expect(mockCredentialsStore.remove).not.toHaveBeenCalled()
})
})

View File

@@ -1,333 +0,0 @@
<template>
<Dialog v-model:open="isOpen">
<DialogPortal>
<DialogOverlay class="bg-black/70" />
<DialogContent
size="md"
class="flex flex-col gap-4 p-6"
@open-auto-focus="onOpen"
>
<DialogHeader>
<DialogTitle>{{ $t('modelManager.credentials.title') }}</DialogTitle>
<DialogDescription>
{{ $t('modelManager.credentials.description') }}
</DialogDescription>
</DialogHeader>
<div v-if="credentials.length" class="flex flex-col gap-2">
<div
v-for="credential in credentials"
:key="credential.id"
class="flex items-center gap-2 rounded-lg border border-border-default bg-secondary-background px-3 py-2"
>
<div class="flex min-w-0 flex-1 flex-col">
<span class="truncate text-sm font-medium text-base-foreground">
{{ credential.label || credential.host }}
</span>
<span class="truncate text-xs text-muted-foreground">
{{ credential.host }} ·
{{
$t(
`modelManager.credentials.scheme.${credential.auth_scheme}`
)
}}
<template v-if="credential.secret_last4">
· {{ credential.secret_last4 }}
</template>
<template v-if="!credential.enabled">
· {{ $t('modelManager.credentials.disabled') }}
</template>
</span>
</div>
<Button
variant="textonly"
size="icon"
:title="$t('modelManager.credentials.edit')"
@click="editCredential(credential)"
>
<i class="icon-[lucide--pencil] size-4" />
</Button>
<Button
variant="textonly"
size="icon"
:title="$t('g.delete')"
@click="confirmDelete(credential)"
>
<i class="icon-[lucide--trash-2] size-4 text-red-400" />
</Button>
</div>
</div>
<div class="flex flex-col gap-3 border-t border-border-default pt-3">
<span class="text-sm font-medium text-base-foreground">
{{
form.id
? $t('modelManager.credentials.update')
: $t('modelManager.credentials.add')
}}
</span>
<div class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground" for="cred-host">
{{ $t('modelManager.credentials.host') }}
</label>
<Input
id="cred-host"
v-model="form.host"
:placeholder="HOST_EXAMPLE"
/>
</div>
<div class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground" for="cred-secret">
{{ $t('modelManager.credentials.secret') }}
</label>
<Input
id="cred-secret"
v-model="form.secret"
type="password"
:placeholder="$t('modelManager.credentials.secretPlaceholder')"
/>
</div>
<div class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground">
{{ $t('modelManager.credentials.authScheme') }}
</label>
<SingleSelect v-model="form.auth_scheme" :options="schemeOptions" />
</div>
<div v-if="form.auth_scheme === 'header'" class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground" for="cred-header">
{{ $t('modelManager.credentials.headerName') }}
</label>
<Input
id="cred-header"
v-model="form.header_name"
:placeholder="HEADER_NAME_EXAMPLE"
/>
</div>
<div v-if="form.auth_scheme === 'query'" class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground" for="cred-query">
{{ $t('modelManager.credentials.queryParam') }}
</label>
<Input
id="cred-query"
v-model="form.query_param"
:placeholder="QUERY_PARAM_EXAMPLE"
/>
</div>
<div class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground" for="cred-label">
{{ $t('modelManager.credentials.label') }}
</label>
<Input id="cred-label" v-model="form.label" />
</div>
<label class="flex items-center gap-2 text-xs text-muted-foreground">
<input
v-model="form.match_subdomains"
type="checkbox"
class="size-4"
/>
{{ $t('modelManager.credentials.matchSubdomains') }}
</label>
<p v-if="form.match_subdomains" class="text-xs text-amber-400">
{{ $t('modelManager.credentials.matchSubdomainsWarning') }}
</p>
<p v-if="errorMessage" class="text-xs text-red-400">
{{ errorMessage }}
</p>
<div class="flex justify-end gap-2">
<Button v-if="form.id" variant="secondary" @click="resetForm">
{{ $t('g.cancel') }}
</Button>
<Button
variant="primary"
:disabled="!canSubmit"
:loading="isSubmitting"
@click="submit"
>
{{ $t('modelManager.credentials.save') }}
</Button>
</div>
</div>
</DialogContent>
</DialogPortal>
</Dialog>
</template>
<script setup lang="ts">
import { computed, reactive, ref } from 'vue'
import { useI18n } from 'vue-i18n'
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
import Button from '@/components/ui/button/Button.vue'
import Dialog from '@/components/ui/dialog/Dialog.vue'
import DialogContent from '@/components/ui/dialog/DialogContent.vue'
import DialogDescription from '@/components/ui/dialog/DialogDescription.vue'
import DialogHeader from '@/components/ui/dialog/DialogHeader.vue'
import DialogOverlay from '@/components/ui/dialog/DialogOverlay.vue'
import DialogPortal from '@/components/ui/dialog/DialogPortal.vue'
import DialogTitle from '@/components/ui/dialog/DialogTitle.vue'
import Input from '@/components/ui/input/Input.vue'
import SingleSelect from '@/components/ui/single-select/SingleSelect.vue'
import { useToastStore } from '@/platform/updates/common/toastStore'
import { storeToRefs } from 'pinia'
import { useDialogStore } from '@/stores/dialogStore'
import { useHostCredentialsStore } from '../stores/hostCredentialsStore'
import { AUTH_SCHEMES } from '../types'
import type { AuthScheme, HostCredentialView } from '../types'
const HOST_EXAMPLE = 'huggingface.co'
const HEADER_NAME_EXAMPLE = 'Authorization'
const QUERY_PARAM_EXAMPLE = 'token'
const { prefillHost = '' } = defineProps<{ prefillHost?: string }>()
const isOpen = defineModel<boolean>('open', { required: true })
const { t } = useI18n()
const store = useHostCredentialsStore()
const { credentials } = storeToRefs(store)
const dialogStore = useDialogStore()
const isSubmitting = ref(false)
const errorMessage = ref('')
interface CredentialForm {
id: string | null
host: string
secret: string
auth_scheme: AuthScheme
header_name: string
query_param: string
label: string
match_subdomains: boolean
}
function emptyForm(): CredentialForm {
return {
id: null,
host: '',
secret: '',
auth_scheme: 'bearer',
header_name: '',
query_param: '',
label: '',
match_subdomains: false
}
}
const form = reactive<CredentialForm>(emptyForm())
const schemeOptions = computed(() =>
AUTH_SCHEMES.map((scheme) => ({
value: scheme,
name: t(`modelManager.credentials.scheme.${scheme}`)
}))
)
const canSubmit = computed(
() =>
!!form.host.trim() &&
!!form.secret &&
(form.auth_scheme !== 'query' || !!form.query_param.trim()) &&
(form.auth_scheme !== 'header' || !!form.header_name.trim())
)
async function onOpen() {
resetForm()
if (prefillHost) {
form.host = prefillHost
}
try {
await store.fetchCredentials()
} catch (error) {
errorMessage.value =
error instanceof Error ? error.message : t('modelManager.actionFailed')
}
}
function resetForm() {
Object.assign(form, emptyForm())
errorMessage.value = ''
}
function editCredential(credential: HostCredentialView) {
Object.assign(form, {
id: credential.id,
host: credential.host,
secret: '',
auth_scheme: credential.auth_scheme,
header_name: credential.header_name ?? '',
query_param: credential.query_param ?? '',
label: credential.label ?? '',
match_subdomains: credential.match_subdomains
})
errorMessage.value = ''
}
async function submit() {
if (!canSubmit.value) return
isSubmitting.value = true
errorMessage.value = ''
try {
await store.upsert({
host: form.host,
secret: form.secret,
auth_scheme: form.auth_scheme,
header_name: form.auth_scheme === 'header' ? form.header_name : null,
query_param: form.auth_scheme === 'query' ? form.query_param : null,
label: form.label || null,
match_subdomains: form.match_subdomains
})
useToastStore().add({
severity: 'success',
summary: t('modelManager.credentials.saved'),
detail: form.host,
life: 4000
})
resetForm()
} catch (error) {
errorMessage.value =
error instanceof Error ? error.message : t('modelManager.actionFailed')
} finally {
isSubmitting.value = false
}
}
function confirmDelete(credential: HostCredentialView) {
const dialog = showConfirmDialog({
headerProps: { title: t('modelManager.credentials.deleteTitle') },
props: {
promptText: t('modelManager.credentials.deleteMessage', {
host: credential.host
})
},
footerProps: {
confirmText: t('g.delete'),
confirmVariant: 'destructive' as const,
onCancel: () => dialogStore.closeDialog(dialog),
onConfirm: async () => {
dialogStore.closeDialog(dialog)
try {
await store.remove(credential.id)
} catch (error) {
useToastStore().add({
severity: 'error',
summary: t('modelManager.actionFailed'),
detail: error instanceof Error ? error.message : String(error),
life: 5000
})
}
}
}
})
}
</script>

View File

@@ -1,336 +0,0 @@
import { render, screen } from '@testing-library/vue'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createI18n } from 'vue-i18n'
import enMessages from '@/locales/en/main.json' with { type: 'json' }
import type { DownloadStatus } from '../types'
import ModelDownloadRow from './ModelDownloadRow.vue'
const mockPause = vi.fn()
const mockResume = vi.fn()
const mockCancel = vi.fn()
const mockRaisePriority = vi.fn()
const mockRemove = vi.fn()
vi.mock('../composables/useModelDownloadActions', () => ({
useModelDownloadActions: () => ({
pause: mockPause,
resume: mockResume,
cancel: mockCancel,
raisePriority: mockRaisePriority,
remove: mockRemove,
toastError: vi.fn()
})
}))
const i18n = createI18n({
legacy: false,
locale: 'en',
messages: { en: enMessages },
missingWarn: false,
fallbackWarn: false
})
function createDownload(
overrides: Partial<DownloadStatus> = {}
): DownloadStatus {
return {
download_id: 'd1',
model_id: 'loras/x.safetensors',
url: 'https://huggingface.co/org/x.safetensors',
status: 'active',
priority: 0,
total_bytes: 2048,
bytes_done: 1024,
progress: 0.5,
speed_bps: 512,
eta_seconds: 125,
segments: null,
error: null,
created_at: 0,
updated_at: 0,
...overrides
}
}
function mountRow(
download: DownloadStatus,
onOpenCredentials?: (host: string) => void
) {
return render(ModelDownloadRow, {
props: {
download,
...(onOpenCredentials ? { onOpenCredentials } : {})
},
global: { plugins: [i18n] }
})
}
describe('ModelDownloadRow', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('splits the model id into directory and filename', () => {
mountRow(createDownload({ model_id: 'loras/x.safetensors' }))
expect(screen.getByText('x.safetensors')).toBeInTheDocument()
expect(screen.getByText('loras')).toBeInTheDocument()
})
it('renders an empty directory when the model id has no folder', () => {
mountRow(createDownload({ model_id: 'x.safetensors' }))
expect(screen.getByText('x.safetensors')).toBeInTheDocument()
})
it('formats the meta line with percent, size, speed, and eta while active', () => {
mountRow(createDownload({ status: 'active' }))
expect(
screen.getByText('50% · 1 KB / 2 KB · 512 B/s · 2:05')
).toBeInTheDocument()
})
it('renders an empty meta line when no progress metrics are known yet', () => {
mountRow(
createDownload({
status: 'queued',
progress: null,
total_bytes: null,
bytes_done: 0,
speed_bps: null,
eta_seconds: null
})
)
expect(screen.getByTestId('meta-line')).toBeEmptyDOMElement()
})
it('omits the eta once a download is no longer active', () => {
mountRow(
createDownload({
status: 'paused',
progress: 0.5,
eta_seconds: 125
})
)
expect(screen.getByText('50% · 1 KB / 2 KB · 512 B/s')).toBeInTheDocument()
})
describe('progress bar visibility', () => {
it('shows a progress bar for an active download with known progress', () => {
mountRow(createDownload({ status: 'active', progress: 0.5 }))
expect(screen.getByTestId('progress-bar')).toBeInTheDocument()
})
it('hides the progress bar for a cancelled download', () => {
mountRow(createDownload({ status: 'cancelled', progress: 0.5 }))
expect(screen.queryByTestId('progress-bar')).not.toBeInTheDocument()
expect(screen.getByText('Cancelled')).toBeInTheDocument()
})
})
describe('action buttons by status', () => {
it('shows pause and cancel for a queued download, plus raise priority', () => {
mountRow(createDownload({ status: 'queued' }))
expect(screen.getByTitle('Raise priority')).toBeInTheDocument()
expect(screen.getByTitle('Pause')).toBeInTheDocument()
expect(screen.getByTitle('Cancel')).toBeInTheDocument()
expect(screen.queryByTitle('Resume')).not.toBeInTheDocument()
expect(screen.queryByTitle('Remove from list')).not.toBeInTheDocument()
})
it('shows only resume for a failed download without an auth error', () => {
mountRow(createDownload({ status: 'failed', error: 'disk full' }))
expect(screen.getByTitle('Resume')).toBeInTheDocument()
expect(screen.queryByTitle('Pause')).not.toBeInTheDocument()
expect(screen.queryByTitle('Cancel')).not.toBeInTheDocument()
expect(screen.queryByTitle('Remove from list')).not.toBeInTheDocument()
expect(screen.queryByTitle('Add credentials')).not.toBeInTheDocument()
})
it('shows the remove action for terminal downloads', () => {
mountRow(createDownload({ status: 'completed' }))
expect(screen.getByTitle('Remove from list')).toBeInTheDocument()
expect(screen.queryByTitle('Cancel')).not.toBeInTheDocument()
expect(screen.queryByTitle('Resume')).not.toBeInTheDocument()
})
})
describe('auth errors', () => {
it('shows the credentials button and a host-specific hint', async () => {
const onOpenCredentials = vi.fn()
mountRow(
createDownload({
status: 'failed',
url: 'https://huggingface.co/org/x.safetensors',
error: '401 Unauthorized'
}),
onOpenCredentials
)
expect(
screen.getByText(
'huggingface.co needs an API key. Add one in the Credentials Manager, then resume.'
)
).toBeInTheDocument()
await userEvent.click(screen.getByTitle('Add credentials'))
expect(onOpenCredentials).toHaveBeenCalledWith('huggingface.co')
})
it('falls back to a hostless hint when the url cannot be parsed', () => {
mountRow(
createDownload({
status: 'failed',
url: 'not-a-url',
error: '403 forbidden'
})
)
expect(
screen.getByText(
'This host/model needs an API key. Add one in the Credentials Manager, then resume.'
)
).toBeInTheDocument()
})
it('shows the raw error message for a non-auth failure', () => {
mountRow(createDownload({ status: 'failed', error: 'disk full' }))
expect(screen.getByText('disk full')).toBeInTheDocument()
expect(screen.queryByTitle('Add credentials')).not.toBeInTheDocument()
})
it('hides a leftover error while the download is not yet terminal', () => {
mountRow(createDownload({ status: 'active', error: 'disk full' }))
expect(screen.queryByText('disk full')).not.toBeInTheDocument()
})
})
describe('gated models', () => {
const gatedError =
'https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/vae/diffusion_pytorch_model.safetensors is a gated model — Access to model black-forest-labs/FLUX.2-dev is restricted. You must have access to it and be authenticated to access it.'
it('shows the gated hint, credentials button, and a link to accept the license on the model page', () => {
mountRow(
createDownload({
status: 'failed',
url: 'https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/vae/diffusion_pytorch_model.safetensors',
error: gatedError
})
)
expect(
screen.getByText(
"This model is gated. Accept its license on the model's page, then add an API key and resume."
)
).toBeInTheDocument()
expect(screen.getByTitle('Add credentials')).toBeInTheDocument()
const link = screen.getByRole('link', { name: 'Accept license' })
expect(link).toHaveAttribute(
'href',
'https://huggingface.co/black-forest-labs/FLUX.2-dev'
)
})
it('derives the model page from the error when the download url is a cdn link', () => {
mountRow(
createDownload({
status: 'failed',
url: 'https://cas-bridge.xethub.hf.co/xet-bridge-us/abc/def',
error: gatedError
})
)
expect(
screen.getByRole('link', { name: 'Accept license' })
).toHaveAttribute(
'href',
'https://huggingface.co/black-forest-labs/FLUX.2-dev'
)
})
it('hides the raw backend error text for gated failures', () => {
mountRow(
createDownload({
status: 'failed',
url: 'https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/model.safetensors',
error: gatedError
})
)
expect(screen.queryByText(gatedError)).not.toBeInTheDocument()
})
it('omits the accept-license link when no huggingface url is present', () => {
mountRow(
createDownload({
status: 'failed',
url: 'https://example.com/model.safetensors',
error: 'Access to this model is restricted, request access.'
})
)
expect(
screen.queryByRole('link', { name: 'Accept license' })
).not.toBeInTheDocument()
})
})
describe('user actions', () => {
it('pauses on click', async () => {
const download = createDownload({ status: 'active' })
mountRow(download)
await userEvent.click(screen.getByTitle('Pause'))
expect(mockPause).toHaveBeenCalledWith(download)
})
it('resumes on click', async () => {
const download = createDownload({ status: 'paused' })
mountRow(download)
await userEvent.click(screen.getByTitle('Resume'))
expect(mockResume).toHaveBeenCalledWith(download)
})
it('cancels on click', async () => {
const download = createDownload({ status: 'active' })
mountRow(download)
await userEvent.click(screen.getByTitle('Cancel'))
expect(mockCancel).toHaveBeenCalledWith(download)
})
it('raises priority by 1 on click', async () => {
const download = createDownload({ status: 'queued', priority: 2 })
mountRow(download)
await userEvent.click(screen.getByTitle('Raise priority'))
expect(mockRaisePriority).toHaveBeenCalledWith(download, 1)
})
it('removes the download on click', async () => {
const download = createDownload({
download_id: 'd1',
status: 'completed'
})
mountRow(download)
await userEvent.click(screen.getByTitle('Remove from list'))
expect(mockRemove).toHaveBeenCalledWith(download)
})
})
})

View File

@@ -1,245 +0,0 @@
<template>
<div
:class="
cn(
'relative flex flex-col gap-1 overflow-hidden rounded-lg border border-border-default bg-secondary-background px-3 py-2',
isCancelled && 'opacity-60'
)
"
>
<div
v-if="showProgressBar"
:class="progressBarContainerClass"
data-testid="progress-bar"
>
<div :class="progressBarPrimaryClass" :style="barStyle" />
</div>
<div class="relative flex items-center gap-2">
<div class="flex min-w-0 flex-1 flex-col">
<span class="truncate text-sm font-medium text-base-foreground">
{{ filename }}
</span>
<span class="truncate text-xs text-muted-foreground">
{{ directory }}
</span>
</div>
<div class="flex shrink-0 items-center gap-0.5">
<template v-if="canRaisePriority">
<Button
variant="textonly"
size="icon"
:title="$t('modelManager.raisePriority')"
@click="actions.raisePriority(download, 1)"
>
<i class="icon-[lucide--chevron-up] size-4" />
</Button>
</template>
<Button
v-if="canPause"
variant="textonly"
size="icon"
:title="$t('g.pause')"
@click="actions.pause(download)"
>
<i class="icon-[lucide--pause] size-4" />
</Button>
<Button
v-if="isAuthError"
variant="textonly"
size="icon"
:title="$t('modelManager.addCredentials')"
@click="emit('openCredentials', host)"
>
<i class="icon-[lucide--key-round] size-4" />
</Button>
<Button
v-if="canResume"
variant="textonly"
size="icon"
:title="$t('modelManager.resume')"
@click="actions.resume(download)"
>
<i class="icon-[lucide--play] size-4" />
</Button>
<Button
v-if="canCancel"
variant="textonly"
size="icon"
:title="$t('g.cancel')"
@click="actions.cancel(download)"
>
<i class="icon-[lucide--x] size-4 text-red-400" />
</Button>
<Button
v-if="isTerminal"
variant="textonly"
size="icon"
:title="$t('modelManager.removeFromList')"
@click="actions.remove(download)"
>
<i class="icon-[lucide--x] size-4" />
</Button>
</div>
</div>
<div
class="relative flex items-center justify-between gap-2 text-xs text-muted-foreground"
>
<span class="flex items-center gap-1">
<i v-if="isCancelled" class="icon-[lucide--ban] size-3.5" />
{{ statusLabel }}
</span>
<span class="truncate" data-testid="meta-line">{{ metaLine }}</span>
</div>
<p
v-if="isGatedModel"
class="relative text-xs wrap-break-word text-amber-400"
>
{{ $t('modelManager.gatedModelHint') }}
<a
v-if="modelPageUrl"
:href="modelPageUrl"
target="_blank"
rel="noopener noreferrer"
class="underline"
>
{{ $t('modelManager.openModelPage') }}
</a>
</p>
<p
v-else-if="isAuthError"
class="relative text-xs wrap-break-word text-amber-400"
>
{{
host
? $t('modelManager.authErrorHint', { host })
: $t('modelManager.authErrorHintNoHost')
}}
</p>
<p
v-else-if="isFailed && download.error"
class="relative text-xs wrap-break-word text-red-400"
>
{{ download.error }}
</p>
</div>
</template>
<script setup lang="ts">
import { cn } from '@comfyorg/tailwind-utils'
import { computed } from 'vue'
import { useI18n } from 'vue-i18n'
import Button from '@/components/ui/button/Button.vue'
import { useProgressBarBackground } from '@/composables/useProgressBarBackground'
import { formatSize } from '@/utils/formatUtil'
import { useModelDownloadActions } from '../composables/useModelDownloadActions'
import { downloadProgressFraction } from '../stores/modelDownloadStore'
import type { DownloadStatus } from '../types'
import { directoryOf, filenameOf } from '../utils/modelId'
const { download } = defineProps<{ download: DownloadStatus }>()
const emit = defineEmits<{ openCredentials: [host: string] }>()
const { t } = useI18n()
const actions = useModelDownloadActions()
const {
progressBarContainerClass,
progressBarPrimaryClass,
progressPercentStyle
} = useProgressBarBackground()
const directory = computed(() => directoryOf(download.model_id))
const filename = computed(() => filenameOf(download.model_id))
const host = computed(() => {
try {
return new URL(download.url).hostname
} catch {
return ''
}
})
const percent = computed(() => {
const fraction = downloadProgressFraction(download)
return fraction == null ? undefined : Math.round(fraction * 100)
})
const barStyle = computed(() => progressPercentStyle(percent.value))
const isTerminal = computed(() =>
['completed', 'cancelled'].includes(download.status)
)
const isCancelled = computed(() => download.status === 'cancelled')
const showProgressBar = computed(
() =>
download.status !== 'failed' &&
!isCancelled.value &&
percent.value !== undefined
)
const canPause = computed(() => ['queued', 'active'].includes(download.status))
const canResume = computed(() => ['paused', 'failed'].includes(download.status))
const canCancel = computed(
() => !isTerminal.value && download.status !== 'failed'
)
const canRaisePriority = computed(() => download.status === 'queued')
const GATED_MODEL_PATTERN = /gated|restricted|request access|must have access/i
const AUTH_ERROR_PATTERN =
/api key|credential|token|unauthor|forbidden|\b401\b|\b403\b/i
const isFailed = computed(() => download.status === 'failed')
const isGatedModel = computed(
() => isFailed.value && !!download.error?.match(GATED_MODEL_PATTERN)
)
const isAuthError = computed(
() =>
isFailed.value &&
(isGatedModel.value || !!download.error?.match(AUTH_ERROR_PATTERN))
)
const HF_MODEL_URL_PATTERN =
/https?:\/\/huggingface\.co\/([^/\s?#]+)\/([^/\s?#]+)/i
const modelPageUrl = computed(() => {
const match = `${download.url} ${download.error ?? ''}`.match(
HF_MODEL_URL_PATTERN
)
if (!match) return ''
const [, owner, repo] = match
return `https://huggingface.co/${owner}/${repo}`
})
const statusLabel = computed(() => t(`modelManager.status.${download.status}`))
const metaLine = computed(() => {
const parts: string[] = []
if (percent.value !== undefined) parts.push(`${percent.value}%`)
if (download.total_bytes != null) {
parts.push(
`${formatSize(download.bytes_done)} / ${formatSize(download.total_bytes)}`
)
}
if (download.speed_bps) parts.push(`${formatSize(download.speed_bps)}/s`)
if (download.eta_seconds != null && download.status === 'active') {
parts.push(formatEta(download.eta_seconds))
}
return parts.join(' · ')
})
function formatEta(seconds: number): string {
const total = Math.max(0, Math.round(seconds))
const hours = Math.floor(total / 3600)
const minutes = Math.floor((total % 3600) / 60)
const secs = total % 60
const paddedSecs = secs.toString().padStart(2, '0')
if (hours > 0) {
return `${hours}:${minutes.toString().padStart(2, '0')}:${paddedSecs}`
}
return `${minutes}:${paddedSecs}`
}
</script>

View File

@@ -1,178 +0,0 @@
import { render, screen } from '@testing-library/vue'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { reactive, ref } from 'vue'
import { createI18n } from 'vue-i18n'
import enMessages from '@/locales/en/main.json' with { type: 'json' }
import type { DownloadStatus } from '../types'
import ModelManagerSidebarTab from './ModelManagerSidebarTab.vue'
const mockHydrate = vi.fn()
const mockClearHistory = vi.fn()
const mockStore = reactive({
downloadList: ref<DownloadStatus[]>([]),
activeDownloads: ref<DownloadStatus[]>([]),
historyDownloads: ref<DownloadStatus[]>([]),
hydrate: mockHydrate
})
vi.mock('../stores/modelDownloadStore', () => ({
useModelDownloadStore: () => mockStore
}))
vi.mock('../composables/useModelDownloadActions', () => ({
useModelDownloadActions: () => ({ clearHistory: mockClearHistory })
}))
vi.mock('./ModelDownloadRow.vue', () => ({
default: {
name: 'ModelDownloadRow',
props: ['download'],
emits: ['openCredentials'],
template:
'<div><span>{{ download.download_id }}</span>' +
'<button @click="$emit(\'openCredentials\', download.model_id)">open-credentials</button></div>'
}
}))
vi.mock('./AddModelByUrlDialog.vue', () => ({
default: {
name: 'AddModelByUrlDialog',
props: ['open'],
template: '<div data-testid="add-model-dialog">{{ open }}</div>'
}
}))
vi.mock('./HostCredentialsDialog.vue', () => ({
default: {
name: 'HostCredentialsDialog',
props: ['open', 'prefillHost'],
template:
'<div data-testid="credentials-dialog">{{ open }}:{{ prefillHost }}</div>'
}
}))
vi.mock('@/components/sidebar/tabs/SidebarTabTemplate.vue', () => ({
default: {
name: 'SidebarTabTemplate',
template: '<div><slot name="tool-buttons" /><slot name="body" /></div>'
}
}))
const i18n = createI18n({
legacy: false,
locale: 'en',
messages: { en: enMessages },
missingWarn: false,
fallbackWarn: false
})
function createDownload(
overrides: Partial<DownloadStatus> = {}
): DownloadStatus {
return {
download_id: 'd1',
model_id: 'loras/x.safetensors',
url: 'https://huggingface.co/org/x.safetensors',
status: 'active',
priority: 0,
total_bytes: null,
bytes_done: 0,
progress: null,
speed_bps: null,
eta_seconds: null,
segments: null,
error: null,
created_at: 0,
updated_at: 0,
...overrides
}
}
function mountTab() {
return render(ModelManagerSidebarTab, { global: { plugins: [i18n] } })
}
describe('ModelManagerSidebarTab', () => {
beforeEach(() => {
vi.clearAllMocks()
mockStore.downloadList = []
mockStore.activeDownloads = []
mockStore.historyDownloads = []
})
it('hydrates the store on mount', () => {
mountTab()
expect(mockHydrate).toHaveBeenCalled()
})
it('shows the empty state when there are no downloads', () => {
mountTab()
expect(screen.getByText('No downloads yet')).toBeInTheDocument()
})
it('renders active downloads under the Active section', () => {
const download = createDownload({ download_id: 'd1' })
mockStore.activeDownloads = [download]
mockStore.downloadList = [download]
mountTab()
expect(screen.getByText('Active')).toBeInTheDocument()
expect(screen.getByText('d1')).toBeInTheDocument()
expect(screen.queryByText('No downloads yet')).not.toBeInTheDocument()
})
it('renders history downloads under the History section with a clear action', async () => {
const download = createDownload({ download_id: 'd2' })
mockStore.historyDownloads = [download]
mockStore.downloadList = [download]
mountTab()
expect(screen.getByText('History')).toBeInTheDocument()
expect(screen.getByText('d2')).toBeInTheDocument()
await userEvent.click(screen.getByText('Clear history'))
expect(mockClearHistory).toHaveBeenCalled()
})
it('opens the add-model dialog from the toolbar button', async () => {
mountTab()
expect(screen.getByTestId('add-model-dialog')).toHaveTextContent('false')
await userEvent.click(screen.getByTitle('Add model'))
expect(screen.getByTestId('add-model-dialog')).toHaveTextContent('true')
})
it('opens the add-model dialog from the empty state button', async () => {
mountTab()
await userEvent.click(screen.getByText('Add model'))
expect(screen.getByTestId('add-model-dialog')).toHaveTextContent('true')
})
it('opens the credentials dialog with an empty prefill from the toolbar button', async () => {
mountTab()
await userEvent.click(screen.getByTitle('Credentials Manager'))
expect(screen.getByTestId('credentials-dialog')).toHaveTextContent('true:')
})
it('opens the credentials dialog prefilled with the row host', async () => {
const download = createDownload({
download_id: 'd1',
model_id: 'loras/x.safetensors'
})
mockStore.activeDownloads = [download]
mockStore.downloadList = [download]
mountTab()
await userEvent.click(screen.getByText('open-credentials'))
expect(screen.getByTestId('credentials-dialog')).toHaveTextContent(
'true:loras/x.safetensors'
)
})
})

View File

@@ -1,103 +0,0 @@
<template>
<SidebarTabTemplate :title="$t('modelManager.title')">
<template #tool-buttons>
<Button
variant="textonly"
size="icon"
:title="$t('modelManager.credentials.title')"
@click="openCredentials('')"
>
<i class="icon-[lucide--key-round] size-4" />
</Button>
<Button
variant="textonly"
size="icon"
:title="$t('modelManager.addModel')"
@click="addOpen = true"
>
<i class="icon-[lucide--plus] size-4" />
</Button>
</template>
<template #body>
<div class="flex flex-col gap-4 p-3">
<section v-if="activeDownloads.length" class="flex flex-col gap-2">
<h3 class="text-xs font-semibold text-muted-foreground uppercase">
{{ $t('modelManager.active') }}
</h3>
<ModelDownloadRow
v-for="download in activeDownloads"
:key="download.download_id"
:download
@open-credentials="openCredentials"
/>
</section>
<section v-if="historyDownloads.length" class="flex flex-col gap-2">
<div class="flex items-center justify-between">
<h3 class="text-xs font-semibold text-muted-foreground uppercase">
{{ $t('modelManager.history') }}
</h3>
<Button variant="link" size="sm" @click="actions.clearHistory()">
{{ $t('modelManager.clearHistory') }}
</Button>
</div>
<ModelDownloadRow
v-for="download in historyDownloads"
:key="download.download_id"
:download
@open-credentials="openCredentials"
/>
</section>
<div
v-if="!store.downloadList.length"
class="flex flex-col items-center gap-3 py-10 text-center text-muted-foreground"
>
<i class="icon-[lucide--download] size-8" />
<p class="text-sm">{{ $t('modelManager.empty') }}</p>
<Button variant="primary" size="sm" @click="addOpen = true">
{{ $t('modelManager.addModel') }}
</Button>
</div>
</div>
</template>
</SidebarTabTemplate>
<AddModelByUrlDialog v-model:open="addOpen" />
<HostCredentialsDialog
v-model:open="credentialsOpen"
:prefill-host="prefillHost"
/>
</template>
<script setup lang="ts">
import { storeToRefs } from 'pinia'
import { onMounted, ref } from 'vue'
import SidebarTabTemplate from '@/components/sidebar/tabs/SidebarTabTemplate.vue'
import Button from '@/components/ui/button/Button.vue'
import AddModelByUrlDialog from './AddModelByUrlDialog.vue'
import HostCredentialsDialog from './HostCredentialsDialog.vue'
import ModelDownloadRow from './ModelDownloadRow.vue'
import { useModelDownloadActions } from '../composables/useModelDownloadActions'
import { useModelDownloadStore } from '../stores/modelDownloadStore'
const store = useModelDownloadStore()
const actions = useModelDownloadActions()
const { activeDownloads, historyDownloads } = storeToRefs(store)
const addOpen = ref(false)
const credentialsOpen = ref(false)
const prefillHost = ref('')
function openCredentials(host: string) {
prefillHost.value = host
credentialsOpen.value = true
}
onMounted(() => {
void store.hydrate()
})
</script>

View File

@@ -1,51 +0,0 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import * as downloadApi from '../api/modelDownloadApi'
import { useModelAvailability } from './useModelAvailability'
vi.mock('../api/modelDownloadApi', () => ({
checkAvailability: vi.fn()
}))
describe('useModelAvailability', () => {
beforeEach(() => {
vi.resetAllMocks()
})
it('short-circuits without calling the API for an empty map', async () => {
const { check, results } = useModelAvailability()
await check({})
expect(downloadApi.checkAvailability).not.toHaveBeenCalled()
expect(results.value).toEqual({})
})
it('stores the availability map from the response', async () => {
vi.mocked(downloadApi.checkAvailability).mockResolvedValue({
models: {
'loras/x.safetensors': { state: 'available', url_allowed: true }
}
})
const { check, results, isChecking } = useModelAvailability()
await check({ 'loras/x.safetensors': 'https://h.co/x' })
expect(results.value['loras/x.safetensors']).toEqual({
state: 'available',
url_allowed: true
})
expect(isChecking.value).toBe(false)
})
it('captures errors and resets the checking flag', async () => {
vi.mocked(downloadApi.checkAvailability).mockRejectedValue(
new Error('boom')
)
const { check, error, isChecking } = useModelAvailability()
await expect(check({ 'loras/x.safetensors': 'u' })).rejects.toThrow('boom')
expect(error.value).toBeInstanceOf(Error)
expect(isChecking.value).toBe(false)
})
})

View File

@@ -1,41 +0,0 @@
import { ref, shallowRef } from 'vue'
import { checkAvailability } from '../api/modelDownloadApi'
import type { AvailabilityEntry } from '../types'
/**
* Bulk-checks whether the models declared by a workflow are present,
* downloading, or missing. Pass a `model_id -> source URL` map and badge
* each model from the returned entries.
*/
export function useModelAvailability() {
const results = shallowRef<Record<string, AvailabilityEntry>>({})
const isChecking = ref(false)
const error = ref<unknown>(null)
async function check(models: Record<string, string>) {
if (Object.keys(models).length === 0) {
results.value = {}
return results.value
}
isChecking.value = true
error.value = null
try {
const response = await checkAvailability(models)
results.value = response.models
return response.models
} catch (e) {
error.value = e
throw e
} finally {
isChecking.value = false
}
}
return {
results,
isChecking,
error,
check
}
}

View File

@@ -1,168 +0,0 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
import type { DownloadStatus } from '../types'
import { DownloadApiError } from '../types'
import { useModelDownloadActions } from './useModelDownloadActions'
const mockStore = {
pause: vi.fn(),
resume: vi.fn(),
cancel: vi.fn(),
setPriority: vi.fn(),
hydrate: vi.fn()
}
const mockToastAdd = vi.fn()
const mockCloseDialog = vi.fn()
vi.mock('vue-i18n', () => ({
useI18n: () => ({ t: (key: string) => key })
}))
vi.mock('../stores/modelDownloadStore', () => ({
useModelDownloadStore: () => mockStore
}))
vi.mock('@/platform/updates/common/toastStore', () => ({
useToastStore: () => ({ add: mockToastAdd })
}))
vi.mock('@/stores/dialogStore', () => ({
useDialogStore: () => ({ closeDialog: mockCloseDialog })
}))
vi.mock('@/components/dialog/confirm/confirmDialog')
const mockShowConfirmDialog = vi.mocked(showConfirmDialog)
interface CapturedConfirmOptions {
footerProps?: {
confirmVariant?: string
onConfirm?: () => void | Promise<void>
onCancel?: () => void
}
}
function capturedOptions(): CapturedConfirmOptions {
return mockShowConfirmDialog.mock
.calls[0][0] as unknown as CapturedConfirmOptions
}
function createDownload(
overrides: Partial<DownloadStatus> = {}
): DownloadStatus {
return {
download_id: 'd1',
model_id: 'loras/x.safetensors',
url: 'https://huggingface.co/x.safetensors',
status: 'active',
priority: 0,
total_bytes: null,
bytes_done: 0,
progress: null,
speed_bps: null,
eta_seconds: null,
segments: null,
error: null,
created_at: 0,
updated_at: 0,
...overrides
}
}
describe('useModelDownloadActions', () => {
beforeEach(() => {
vi.clearAllMocks()
mockShowConfirmDialog.mockReturnValue(
{} as ReturnType<typeof showConfirmDialog>
)
})
it('pauses a download by id', async () => {
mockStore.pause.mockResolvedValue(undefined)
const { pause } = useModelDownloadActions()
await pause(createDownload({ download_id: 'd1' }))
expect(mockStore.pause).toHaveBeenCalledWith('d1')
})
it('resumes a download by id', async () => {
mockStore.resume.mockResolvedValue(undefined)
const { resume } = useModelDownloadActions()
await resume(createDownload({ download_id: 'd1' }))
expect(mockStore.resume).toHaveBeenCalledWith('d1')
})
it('raises priority relative to the current value', async () => {
mockStore.setPriority.mockResolvedValue(undefined)
const { raisePriority } = useModelDownloadActions()
await raisePriority(createDownload({ download_id: 'd1', priority: 2 }), 1)
expect(mockStore.setPriority).toHaveBeenCalledWith('d1', 3)
})
it('shows an error toast and re-hydrates the store when an action fails', async () => {
mockStore.pause.mockRejectedValue(new Error('offline'))
mockStore.hydrate.mockResolvedValue(undefined)
const { pause } = useModelDownloadActions()
await pause(createDownload())
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'error', detail: 'offline' })
)
expect(mockStore.hydrate).toHaveBeenCalled()
})
it('uses the DownloadApiError message in the toast', async () => {
mockStore.pause.mockRejectedValue(
new DownloadApiError('nope', 'URL_NOT_ALLOWED', 400)
)
mockStore.hydrate.mockResolvedValue(undefined)
const { pause } = useModelDownloadActions()
await pause(createDownload())
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ detail: 'nope' })
)
})
it('swallows a failed re-hydrate after an action error', async () => {
mockStore.pause.mockRejectedValue(new Error('offline'))
mockStore.hydrate.mockRejectedValue(new Error('still offline'))
const { pause } = useModelDownloadActions()
await expect(pause(createDownload())).resolves.toBeUndefined()
})
describe('cancel', () => {
it('opens a destructive confirm dialog and only cancels on confirm', async () => {
mockStore.cancel.mockResolvedValue(undefined)
const { cancel } = useModelDownloadActions()
cancel(createDownload({ download_id: 'd1' }))
expect(capturedOptions().footerProps?.confirmVariant).toBe('destructive')
await capturedOptions().footerProps?.onConfirm?.()
expect(mockStore.cancel).toHaveBeenCalledWith('d1')
expect(mockCloseDialog).toHaveBeenCalled()
})
it('does not cancel when the confirm dialog is dismissed', () => {
const { cancel } = useModelDownloadActions()
cancel(createDownload({ download_id: 'd1' }))
capturedOptions().footerProps?.onCancel?.()
expect(mockStore.cancel).not.toHaveBeenCalled()
expect(mockCloseDialog).toHaveBeenCalled()
})
})
})

View File

@@ -1,97 +0,0 @@
import { useI18n } from 'vue-i18n'
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
import { useToastStore } from '@/platform/updates/common/toastStore'
import { useDialogStore } from '@/stores/dialogStore'
import { useModelDownloadStore } from '../stores/modelDownloadStore'
import type { DownloadStatus } from '../types'
import { DownloadApiError } from '../types'
/**
* Wraps download store mutations with user feedback: optimistic-friendly error
* toasts and a confirmation prompt for the destructive cancel action (which
* deletes the partial file).
*/
export function useModelDownloadActions() {
const store = useModelDownloadStore()
const dialogStore = useDialogStore()
const { t } = useI18n()
function toastError(error: unknown) {
const detail =
error instanceof DownloadApiError || error instanceof Error
? error.message
: String(error)
useToastStore().add({
severity: 'error',
summary: t('modelManager.actionFailed'),
detail,
life: 5000
})
}
async function run(action: () => Promise<void>) {
try {
await action()
} catch (error) {
toastError(error)
// The store applies optimistic status/priority patches before the API
// call; re-fetch the authoritative state so a failed mutation doesn't
// leave the row stuck in the wrong local state.
try {
await store.hydrate()
} catch {
// Server unreachable; the next poll will reconcile.
}
}
}
const pause = (download: DownloadStatus) =>
run(() => store.pause(download.download_id))
const resume = (download: DownloadStatus) =>
run(() => store.resume(download.download_id))
const raisePriority = (download: DownloadStatus, delta: number) =>
run(() =>
store.setPriority(download.download_id, download.priority + delta)
)
const remove = (download: DownloadStatus) =>
run(() => store.remove(download.download_id))
const clearHistory = () => run(() => store.clearHistory())
function cancel(download: DownloadStatus) {
const dialog = showConfirmDialog({
headerProps: { title: t('modelManager.cancelConfirmTitle') },
props: {
promptText: t(
'modelManager.cancelConfirmMessage',
{ name: download.model_id },
{ escapeParameter: false }
)
},
footerProps: {
confirmText: t('modelManager.cancelConfirm'),
confirmVariant: 'destructive' as const,
onCancel: () => dialogStore.closeDialog(dialog),
onConfirm: async () => {
dialogStore.closeDialog(dialog)
await run(() => store.cancel(download.download_id))
}
}
})
}
return {
pause,
resume,
cancel,
raisePriority,
remove,
clearHistory,
toastError
}
}

View File

@@ -1,93 +0,0 @@
import { ref } from 'vue'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { CompletedDownload } from '../stores/modelDownloadStore'
import { useModelDownloadEffects } from './useModelDownloadEffects'
const mockLastCompletedDownload = ref<CompletedDownload | null>(null)
const mockRefreshModelFolder = vi.fn()
const mockRefreshMissingModels = vi.fn()
vi.mock('../stores/modelDownloadStore', () => ({
useModelDownloadStore: () => ({
get lastCompletedDownload() {
return mockLastCompletedDownload.value
}
})
}))
vi.mock('@/stores/modelStore', () => ({
useModelStore: () => ({ refreshModelFolder: mockRefreshModelFolder })
}))
vi.mock('@/platform/missingModel/missingModelStore', () => ({
useMissingModelStore: () => ({
refreshMissingModels: mockRefreshMissingModels
})
}))
describe('useModelDownloadEffects', () => {
beforeEach(() => {
vi.clearAllMocks()
mockLastCompletedDownload.value = null
mockRefreshModelFolder.mockResolvedValue(undefined)
})
it('does nothing until a download completes', () => {
useModelDownloadEffects()
expect(mockRefreshModelFolder).not.toHaveBeenCalled()
expect(mockRefreshMissingModels).not.toHaveBeenCalled()
})
it('refreshes the model folder and re-scans missing models on completion', async () => {
useModelDownloadEffects()
mockLastCompletedDownload.value = {
downloadId: 'd1',
modelId: 'loras/x.safetensors',
directory: 'loras',
timestamp: 1
}
await vi.waitFor(() => {
expect(mockRefreshModelFolder).toHaveBeenCalledWith('loras')
})
expect(mockRefreshMissingModels).toHaveBeenCalled()
})
it('skips the folder refresh when the directory is unknown', async () => {
useModelDownloadEffects()
mockLastCompletedDownload.value = {
downloadId: 'd1',
modelId: 'unknown.safetensors',
directory: '',
timestamp: 1
}
await vi.waitFor(() => {
expect(mockRefreshMissingModels).toHaveBeenCalled()
})
expect(mockRefreshModelFolder).not.toHaveBeenCalled()
})
it('still re-scans missing models when the folder refresh fails', async () => {
const consoleWarn = vi.spyOn(console, 'warn').mockImplementation(() => {})
mockRefreshModelFolder.mockRejectedValue(new Error('boom'))
useModelDownloadEffects()
mockLastCompletedDownload.value = {
downloadId: 'd1',
modelId: 'loras/x.safetensors',
directory: 'loras',
timestamp: 1
}
await vi.waitFor(() => {
expect(mockRefreshMissingModels).toHaveBeenCalled()
})
expect(consoleWarn).toHaveBeenCalled()
consoleWarn.mockRestore()
})
})

View File

@@ -1,45 +0,0 @@
import { watch } from 'vue'
import { useMissingModelStore } from '@/platform/missingModel/missingModelStore'
import { useModelStore } from '@/stores/modelStore'
import { useModelDownloadStore } from '../stores/modelDownloadStore'
/**
* Side effects that run when a server-side download finishes: refresh the
* affected model folder so the file is recognized, and re-scan missing models
* so node errors for the now-present model clear automatically.
*
* Mounted once at the app root so it runs regardless of whether the Model
* Manager panel is open.
*/
export function useModelDownloadEffects() {
const store = useModelDownloadStore()
watch(
() => store.lastCompletedDownload,
async (completed) => {
if (!completed) return
if (completed.directory) {
try {
await useModelStore().refreshModelFolder(completed.directory)
} catch (error) {
console.warn(
'[ModelManager] Failed to refresh model folder after download',
error
)
}
}
try {
await useMissingModelStore().refreshMissingModels()
} catch (error) {
console.warn(
'[ModelManager] Failed to re-scan missing models after download',
error
)
}
}
)
}

View File

@@ -1,44 +0,0 @@
import { describe, expect, it, vi } from 'vitest'
const mockActiveDownloadCount = { value: 0 }
vi.mock('../stores/modelDownloadStore', () => ({
useModelDownloadStore: () => ({
get activeDownloadCount() {
return mockActiveDownloadCount.value
}
})
}))
vi.mock('../components/ModelManagerSidebarTab.vue', () => ({
default: { name: 'ModelManagerSidebarTab' }
}))
import { useModelManagerSidebarTab } from './useModelManagerSidebarTab'
describe('useModelManagerSidebarTab', () => {
it('returns the expected sidebar tab extension shape', () => {
const tab = useModelManagerSidebarTab()
expect(tab.id).toBe('model-manager')
expect(tab.type).toBe('vue')
expect(tab.title).toBe('modelManager.title')
expect(tab.tooltip).toBe('modelManager.title')
expect(tab.label).toBe('modelManager.title')
})
it('shows no badge when there are no active downloads', () => {
mockActiveDownloadCount.value = 0
const tab = useModelManagerSidebarTab()
expect(typeof tab.iconBadge).toBe('function')
expect((tab.iconBadge as () => string | null)()).toBeNull()
})
it('shows the active download count as a badge', () => {
mockActiveDownloadCount.value = 3
const tab = useModelManagerSidebarTab()
expect((tab.iconBadge as () => string | null)()).toBe('3')
})
})

View File

@@ -1,22 +0,0 @@
import { markRaw } from 'vue'
import type { SidebarTabExtension } from '@/types/extensionTypes'
import ModelManagerSidebarTab from '../components/ModelManagerSidebarTab.vue'
import { useModelDownloadStore } from '../stores/modelDownloadStore'
export function useModelManagerSidebarTab(): SidebarTabExtension {
return {
id: 'model-manager',
icon: 'icon-[lucide--download]',
title: 'modelManager.title',
tooltip: 'modelManager.title',
label: 'modelManager.title',
component: markRaw(ModelManagerSidebarTab),
type: 'vue',
iconBadge: () => {
const count = useModelDownloadStore().activeDownloadCount
return count > 0 ? count.toString() : null
}
}
}

View File

@@ -1,191 +0,0 @@
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import * as downloadApi from '../api/modelDownloadApi'
import type { HostCredentialUpsert, HostCredentialView } from '../types'
import { useHostCredentialsStore } from './hostCredentialsStore'
vi.mock('../api/modelDownloadApi', () => ({
listCredentials: vi.fn(),
upsertCredential: vi.fn(),
deleteCredential: vi.fn()
}))
function createCredential(
overrides: Partial<HostCredentialView> = {}
): HostCredentialView {
return {
id: 'c1',
host: 'huggingface.co',
auth_scheme: 'bearer',
header_name: null,
query_param: null,
label: null,
match_subdomains: false,
enabled: true,
secret_last4: '1234',
created_at: 0,
updated_at: 0,
...overrides
}
}
describe('useHostCredentialsStore', () => {
beforeEach(() => {
setActivePinia(createTestingPinia({ stubActions: false }))
vi.resetAllMocks()
})
describe('fetchCredentials', () => {
it('toggles isLoading and stores the result', async () => {
const credential = createCredential()
vi.mocked(downloadApi.listCredentials).mockResolvedValue([credential])
const store = useHostCredentialsStore()
const promise = store.fetchCredentials()
expect(store.isLoading).toBe(true)
await promise
expect(store.isLoading).toBe(false)
expect(store.credentials).toEqual([credential])
})
it('clears isLoading even when the request fails', async () => {
vi.mocked(downloadApi.listCredentials).mockRejectedValue(
new Error('boom')
)
const store = useHostCredentialsStore()
await expect(store.fetchCredentials()).rejects.toThrow('boom')
expect(store.isLoading).toBe(false)
})
})
describe('upsert', () => {
it('normalizes the host and appends a new credential', async () => {
const created = createCredential({ id: 'new', host: 'huggingface.co' })
vi.mocked(downloadApi.upsertCredential).mockResolvedValue(created)
const store = useHostCredentialsStore()
const body: HostCredentialUpsert = {
host: ' HuggingFace.co ',
secret: 's3cret'
}
const result = await store.upsert(body)
expect(downloadApi.upsertCredential).toHaveBeenCalledWith(
expect.objectContaining({ host: 'huggingface.co', secret: 's3cret' })
)
expect(result).toEqual(created)
expect(store.credentials).toEqual([created])
})
it('replaces an existing credential by id', async () => {
const original = createCredential({ id: 'c1', label: 'Old' })
const updated = createCredential({ id: 'c1', label: 'New' })
vi.mocked(downloadApi.upsertCredential).mockResolvedValue(updated)
const store = useHostCredentialsStore()
store.credentials.push(original)
await store.upsert({ host: 'huggingface.co', secret: 's' })
expect(store.credentials).toEqual([updated])
})
})
describe('remove', () => {
it('deletes and filters out the credential by id', async () => {
vi.mocked(downloadApi.deleteCredential).mockResolvedValue(undefined)
const store = useHostCredentialsStore()
store.credentials.push(
createCredential({ id: 'c1' }),
createCredential({ id: 'c2' })
)
await store.remove('c1')
expect(downloadApi.deleteCredential).toHaveBeenCalledWith('c1')
expect(store.credentials.map((c) => c.id)).toEqual(['c2'])
})
})
describe('enabledCredentialForHost', () => {
it('matches an exact enabled host case-insensitively', () => {
const store = useHostCredentialsStore()
store.credentials.push(
createCredential({ host: 'HuggingFace.co', enabled: true })
)
expect(store.enabledCredentialForHost('huggingface.co')?.host).toBe(
'HuggingFace.co'
)
})
it('returns undefined for a disabled exact match', () => {
const store = useHostCredentialsStore()
store.credentials.push(
createCredential({ host: 'huggingface.co', enabled: false })
)
expect(store.enabledCredentialForHost('huggingface.co')).toBeUndefined()
})
it('falls back to a subdomain match only when match_subdomains is set', () => {
const store = useHostCredentialsStore()
store.credentials.push(
createCredential({
host: 'huggingface.co',
match_subdomains: true,
enabled: true
})
)
expect(store.enabledCredentialForHost('cdn.huggingface.co')?.host).toBe(
'huggingface.co'
)
})
it('skips a disabled subdomain credential and returns a later enabled match', () => {
const store = useHostCredentialsStore()
store.credentials.push(
createCredential({
id: 'disabled',
host: 'huggingface.co',
match_subdomains: true,
enabled: false
}),
createCredential({
id: 'enabled',
host: 'huggingface.co',
match_subdomains: true,
enabled: true
})
)
expect(store.enabledCredentialForHost('cdn.huggingface.co')?.id).toBe(
'enabled'
)
})
it('does not subdomain-match when match_subdomains is false', () => {
const store = useHostCredentialsStore()
store.credentials.push(
createCredential({
host: 'huggingface.co',
match_subdomains: false,
enabled: true
})
)
expect(
store.enabledCredentialForHost('cdn.huggingface.co')
).toBeUndefined()
})
it('returns undefined when no host matches', () => {
const store = useHostCredentialsStore()
expect(store.enabledCredentialForHost('example.com')).toBeUndefined()
})
})
})

View File

@@ -1,75 +0,0 @@
import { defineStore } from 'pinia'
import { computed, ref } from 'vue'
import {
deleteCredential,
listCredentials,
upsertCredential
} from '../api/modelDownloadApi'
import type { HostCredentialUpsert, HostCredentialView } from '../types'
function normalizeHost(host: string): string {
return host.trim().toLowerCase()
}
export const useHostCredentialsStore = defineStore('hostCredentials', () => {
const credentials = ref<HostCredentialView[]>([])
const isLoading = ref(false)
const credentialsByHost = computed(
() => new Map(credentials.value.map((c) => [normalizeHost(c.host), c]))
)
async function fetchCredentials() {
isLoading.value = true
try {
credentials.value = await listCredentials()
} finally {
isLoading.value = false
}
}
async function upsert(
body: HostCredentialUpsert
): Promise<HostCredentialView> {
const view = await upsertCredential({
...body,
host: normalizeHost(body.host)
})
const index = credentials.value.findIndex((c) => c.id === view.id)
credentials.value =
index === -1
? [...credentials.value, view]
: credentials.value.map((c) => (c.id === view.id ? view : c))
return view
}
async function remove(id: string) {
await deleteCredential(id)
credentials.value = credentials.value.filter((c) => c.id !== id)
}
function enabledCredentialForHost(
host: string
): HostCredentialView | undefined {
const normalized = normalizeHost(host)
const exact = credentialsByHost.value.get(normalized)
if (exact) return exact.enabled ? exact : undefined
return credentials.value.find(
(c) =>
c.enabled &&
c.match_subdomains &&
normalized.endsWith(`.${normalizeHost(c.host)}`)
)
}
return {
credentials,
isLoading,
fetchCredentials,
upsert,
remove,
enabledCredentialForHost
}
})

View File

@@ -1,388 +0,0 @@
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
const flushPromises = () => vi.advanceTimersByTimeAsync(0)
import * as downloadApi from '../api/modelDownloadApi'
import type { DownloadState, DownloadStatus } from '../types'
import {
downloadProgressFraction,
useModelDownloadStore
} from './modelDownloadStore'
type ProgressHandler = (e: CustomEvent<DownloadStatus>) => void
const eventHandler = vi.hoisted(() => {
const state: { current: ProgressHandler | null } = { current: null }
return state
})
vi.mock('@/scripts/api', () => ({
api: {
addEventListener: vi.fn((_event: string, handler: ProgressHandler) => {
eventHandler.current = handler
})
}
}))
vi.mock('../api/modelDownloadApi', () => ({
enqueueDownload: vi.fn(),
listDownloads: vi.fn(),
pauseDownload: vi.fn(),
resumeDownload: vi.fn(),
cancelDownload: vi.fn(),
setDownloadPriority: vi.fn(),
deleteDownload: vi.fn(),
clearDownloads: vi.fn()
}))
function createStatus(overrides: Partial<DownloadStatus> = {}): DownloadStatus {
return {
download_id: 'd1',
model_id: 'loras/x.safetensors',
url: 'https://huggingface.co/x.safetensors',
status: 'active',
priority: 0,
total_bytes: 1000,
bytes_done: 250,
progress: 0.25,
speed_bps: 100,
eta_seconds: 10,
segments: null,
error: null,
created_at: 1,
updated_at: 2,
...overrides
}
}
function dispatch(status: DownloadStatus) {
if (!eventHandler.current) throw new Error('handler not registered')
eventHandler.current(new CustomEvent('download_progress', { detail: status }))
}
describe('useModelDownloadStore', () => {
beforeEach(() => {
setActivePinia(createTestingPinia({ stubActions: false }))
vi.useFakeTimers({ shouldAdvanceTime: false })
vi.resetAllMocks()
eventHandler.current = null
})
afterEach(() => {
vi.useRealTimers()
})
it('upserts rows from download_progress events', () => {
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', bytes_done: 100 }))
dispatch(createStatus({ download_id: 'd1', bytes_done: 900 }))
dispatch(createStatus({ download_id: 'd2', status: 'queued' }))
expect(store.downloadList).toHaveLength(2)
expect(
store.downloadList.find((d) => d.download_id === 'd1')?.bytes_done
).toBe(900)
})
it('splits active and terminal downloads', () => {
const store = useModelDownloadStore()
const states: DownloadState[] = [
'queued',
'active',
'paused',
'verifying',
'completed',
'failed',
'cancelled'
]
states.forEach((status, idx) =>
dispatch(createStatus({ download_id: `d${idx}`, status }))
)
expect(store.activeDownloads.map((d) => d.status)).toEqual([
'queued',
'active',
'paused',
'verifying'
])
expect(store.historyDownloads.map((d) => d.status)).toEqual([
'completed',
'failed',
'cancelled'
])
expect(store.activeDownloadCount).toBe(4)
})
it('inserts an optimistic queued row on enqueue', async () => {
vi.mocked(downloadApi.enqueueDownload).mockResolvedValue({
download_id: 'new-id',
accepted: true
})
const store = useModelDownloadStore()
const result = await store.enqueue({
url: 'https://huggingface.co/x.safetensors',
model_id: 'loras/x.safetensors'
})
expect(result.download_id).toBe('new-id')
const row = store.downloadList.find((d) => d.download_id === 'new-id')
expect(row?.status).toBe('queued')
expect(row?.model_id).toBe('loras/x.safetensors')
})
it('optimistically updates status when pausing', async () => {
vi.mocked(downloadApi.pauseDownload).mockResolvedValue()
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
await store.pause('d1')
expect(store.downloadList.find((d) => d.download_id === 'd1')?.status).toBe(
'paused'
)
expect(downloadApi.pauseDownload).toHaveBeenCalledWith('d1')
})
it('optimistically updates status when resuming', async () => {
vi.mocked(downloadApi.resumeDownload).mockResolvedValue()
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'paused' }))
await store.resume('d1')
expect(store.downloadList.find((d) => d.download_id === 'd1')?.status).toBe(
'queued'
)
expect(downloadApi.resumeDownload).toHaveBeenCalledWith('d1')
})
it('marks a download cancelled after the API call resolves', async () => {
vi.mocked(downloadApi.cancelDownload).mockResolvedValue()
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
await store.cancel('d1')
expect(store.downloadList.find((d) => d.download_id === 'd1')?.status).toBe(
'cancelled'
)
expect(downloadApi.cancelDownload).toHaveBeenCalledWith('d1')
})
it('optimistically updates priority and calls the API', async () => {
vi.mocked(downloadApi.setDownloadPriority).mockResolvedValue()
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'queued', priority: 0 }))
await store.setPriority('d1', 5)
expect(
store.downloadList.find((d) => d.download_id === 'd1')?.priority
).toBe(5)
expect(downloadApi.setDownloadPriority).toHaveBeenCalledWith('d1', 5)
})
it('is a no-op when patching priority for an unknown id', async () => {
vi.mocked(downloadApi.setDownloadPriority).mockResolvedValue()
const store = useModelDownloadStore()
await store.setPriority('missing', 5)
expect(store.downloadList).toHaveLength(0)
expect(downloadApi.setDownloadPriority).toHaveBeenCalledWith('missing', 5)
})
it('finds a download by model id', () => {
const store = useModelDownloadStore()
dispatch(
createStatus({ download_id: 'd1', model_id: 'loras/x.safetensors' })
)
expect(store.findByModelId('loras/x.safetensors')?.download_id).toBe('d1')
expect(store.findByModelId('loras/missing.safetensors')).toBeUndefined()
})
describe('hydrate', () => {
it('replaces the download map with the fetched list', async () => {
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'stale', status: 'active' }))
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
createStatus({ download_id: 'fresh', status: 'active' })
])
await store.hydrate()
expect(store.downloadList.map((d) => d.download_id)).toEqual(['fresh'])
})
it('records a completion when a refreshed row transitions to completed', async () => {
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
createStatus({
download_id: 'd1',
model_id: 'loras/x.safetensors',
status: 'completed'
})
])
await store.hydrate()
expect(store.lastCompletedDownload).toMatchObject({
downloadId: 'd1',
modelId: 'loras/x.safetensors',
directory: 'loras'
})
})
it('does not re-record a completion for a row that was already completed', async () => {
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'completed' }))
const firstTimestamp = store.lastCompletedDownload?.timestamp
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
createStatus({ download_id: 'd1', status: 'completed' })
])
await store.hydrate()
expect(store.lastCompletedDownload?.timestamp).toBe(firstTimestamp)
})
})
it('records the last completed download once on the completing transition', () => {
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
expect(store.lastCompletedDownload).toBeNull()
dispatch(
createStatus({
download_id: 'd1',
model_id: 'loras/x.safetensors',
status: 'completed'
})
)
expect(store.lastCompletedDownload).toMatchObject({
downloadId: 'd1',
modelId: 'loras/x.safetensors',
directory: 'loras'
})
const firstTimestamp = store.lastCompletedDownload?.timestamp
dispatch(createStatus({ download_id: 'd1', status: 'completed' }))
expect(store.lastCompletedDownload?.timestamp).toBe(firstTimestamp)
})
it('refetches when a download fails without an error message', async () => {
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
createStatus({ download_id: 'd1', status: 'failed', error: 'gated' })
])
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
dispatch(createStatus({ download_id: 'd1', status: 'failed', error: null }))
await flushPromises()
expect(downloadApi.listDownloads).toHaveBeenCalledOnce()
expect(store.downloadList.find((d) => d.download_id === 'd1')?.error).toBe(
'gated'
)
})
it('does not refetch when the failure event already carries an error', async () => {
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
dispatch(
createStatus({ download_id: 'd1', status: 'failed', error: 'disk full' })
)
await flushPromises()
expect(downloadApi.listDownloads).not.toHaveBeenCalled()
expect(store.downloadList.find((d) => d.download_id === 'd1')?.error).toBe(
'disk full'
)
})
it('deletes a row through the backend so it stays gone', async () => {
vi.mocked(downloadApi.deleteDownload).mockResolvedValue()
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
await store.remove('d2')
expect(downloadApi.deleteDownload).toHaveBeenCalledWith('d2')
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d1'])
})
it('clears every history row in one backend call, leaving active downloads', async () => {
vi.mocked(downloadApi.clearDownloads).mockResolvedValue(2)
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
dispatch(createStatus({ download_id: 'd3', status: 'failed' }))
await store.clearHistory()
expect(downloadApi.clearDownloads).toHaveBeenCalledOnce()
expect(downloadApi.deleteDownload).not.toHaveBeenCalled()
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d1'])
})
it('keeps history rows locally when the bulk clear fails', async () => {
vi.mocked(downloadApi.clearDownloads).mockRejectedValue(new Error('boom'))
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
await expect(store.clearHistory()).rejects.toThrow('boom')
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d2'])
})
it('keeps the row locally when the backend delete fails', async () => {
vi.mocked(downloadApi.deleteDownload).mockRejectedValue(new Error('boom'))
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
await expect(store.remove('d2')).rejects.toThrow('boom')
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d2'])
})
it('polls the list when active downloads go stale', async () => {
vi.mocked(downloadApi.listDownloads).mockResolvedValue([])
useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
await vi.advanceTimersByTimeAsync(15_000)
expect(downloadApi.listDownloads).toHaveBeenCalled()
})
describe('downloadProgressFraction', () => {
it('uses live progress when present', () => {
expect(downloadProgressFraction(createStatus({ progress: 0.4 }))).toBe(
0.4
)
})
it('derives from bytes when progress is null', () => {
expect(
downloadProgressFraction(
createStatus({ progress: null, bytes_done: 500, total_bytes: 1000 })
)
).toBe(0.5)
})
it('returns null when total is unknown', () => {
expect(
downloadProgressFraction(
createStatus({ progress: null, total_bytes: null })
)
).toBeNull()
})
})
})

View File

@@ -1,229 +0,0 @@
import { useIntervalFn } from '@vueuse/core'
import { defineStore } from 'pinia'
import { computed, ref } from 'vue'
import { api } from '@/scripts/api'
import {
cancelDownload,
clearDownloads,
deleteDownload,
enqueueDownload,
listDownloads,
pauseDownload,
resumeDownload,
setDownloadPriority
} from '../api/modelDownloadApi'
import type {
DownloadState,
DownloadStatus,
EnqueueRequest,
EnqueueResponse
} from '../types'
import { directoryOf } from '../utils/modelId'
const ACTIVE_STATES: ReadonlySet<DownloadState> = new Set([
'queued',
'active',
'paused',
'verifying'
])
const POLL_INTERVAL_MS = 5_000
const STALE_THRESHOLD_MS = 8_000
/**
* Static completion fraction (0..1) for a download. Live values (`progress`)
* are only populated while a worker is running; for paused/terminal rows we
* derive it from `bytes_done / total_bytes`.
*/
export function downloadProgressFraction(
download: DownloadStatus
): number | null {
if (download.progress != null) return download.progress
if (download.total_bytes && download.total_bytes > 0) {
return download.bytes_done / download.total_bytes
}
return null
}
function optimisticRow(
downloadId: string,
request: EnqueueRequest
): DownloadStatus {
const now = Math.floor(Date.now() / 1000)
return {
download_id: downloadId,
model_id: request.model_id,
url: request.url,
status: 'queued',
priority: request.priority ?? 0,
total_bytes: null,
bytes_done: 0,
progress: null,
speed_bps: null,
eta_seconds: null,
segments: null,
error: null,
created_at: now,
updated_at: now
}
}
export interface CompletedDownload {
downloadId: string
modelId: string
directory: string
timestamp: number
}
export const useModelDownloadStore = defineStore('modelDownload', () => {
const downloads = ref<Map<string, DownloadStatus>>(new Map())
const lastWsUpdate = ref(0)
const lastCompletedDownload = ref<CompletedDownload | null>(null)
const downloadList = computed(() => Array.from(downloads.value.values()))
const activeDownloads = computed(() =>
downloadList.value.filter((d) => ACTIVE_STATES.has(d.status))
)
const historyDownloads = computed(() =>
downloadList.value.filter((d) => !ACTIVE_STATES.has(d.status))
)
const hasActiveDownloads = computed(() => activeDownloads.value.length > 0)
const activeDownloadCount = computed(() => activeDownloads.value.length)
function upsert(status: DownloadStatus) {
downloads.value.set(status.download_id, status)
}
function findByModelId(modelId: string): DownloadStatus | undefined {
return downloadList.value.find((d) => d.model_id === modelId)
}
function handleProgress(e: CustomEvent<DownloadStatus>) {
lastWsUpdate.value = Date.now()
const previous = downloads.value.get(e.detail.download_id)
upsert(e.detail)
if (e.detail.status === 'completed' && previous?.status !== 'completed') {
lastCompletedDownload.value = {
downloadId: e.detail.download_id,
modelId: e.detail.model_id,
directory: directoryOf(e.detail.model_id),
timestamp: Date.now()
}
}
// The live failure event carries the terminal status but not the error
// text, and polling stops once nothing is active. Refetch so the failure
// reason surfaces without the user reopening the panel.
if (
e.detail.status === 'failed' &&
previous?.status !== 'failed' &&
!e.detail.error
) {
void hydrate().catch(() => {})
}
}
async function hydrate() {
const previous = downloads.value
const list = await listDownloads()
downloads.value = new Map(list.map((d) => [d.download_id, d]))
for (const download of list) {
const prior = previous.get(download.download_id)
if (
download.status === 'completed' &&
prior &&
prior.status !== 'completed'
) {
lastCompletedDownload.value = {
downloadId: download.download_id,
modelId: download.model_id,
directory: directoryOf(download.model_id),
timestamp: Date.now()
}
}
}
}
async function enqueue(request: EnqueueRequest): Promise<EnqueueResponse> {
const response = await enqueueDownload(request)
upsert(optimisticRow(response.download_id, request))
return response
}
function patchStatus(id: string, status: DownloadState) {
const existing = downloads.value.get(id)
if (existing) {
downloads.value.set(id, { ...existing, status })
}
}
async function pause(id: string) {
patchStatus(id, 'paused')
await pauseDownload(id)
}
async function resume(id: string) {
patchStatus(id, 'queued')
await resumeDownload(id)
}
async function cancel(id: string) {
await cancelDownload(id)
patchStatus(id, 'cancelled')
}
async function setPriority(id: string, priority: number) {
const existing = downloads.value.get(id)
if (existing) {
downloads.value.set(id, { ...existing, priority })
}
await setDownloadPriority(id, priority)
}
async function remove(id: string) {
await deleteDownload(id)
downloads.value.delete(id)
}
async function clearHistory() {
const clearedIds = historyDownloads.value.map((d) => d.download_id)
await clearDownloads()
for (const id of clearedIds) {
downloads.value.delete(id)
}
}
async function pollIfStale() {
if (!hasActiveDownloads.value) return
if (Date.now() - lastWsUpdate.value < STALE_THRESHOLD_MS) return
try {
await hydrate()
} catch {
// Server unreachable; retry on next interval
}
}
useIntervalFn(() => void pollIfStale(), POLL_INTERVAL_MS)
api.addEventListener('download_progress', handleProgress)
return {
downloadList,
activeDownloads,
historyDownloads,
hasActiveDownloads,
activeDownloadCount,
lastCompletedDownload,
upsert,
findByModelId,
hydrate,
enqueue,
pause,
resume,
cancel,
setPriority,
remove,
clearHistory
}
})

View File

@@ -1,130 +0,0 @@
import { z } from 'zod'
import type { DownloadState, DownloadStatus } from '@/schemas/apiSchema'
export type { DownloadState, DownloadStatus }
/**
* Known model file extensions accepted by the backend without
* `allow_any_extension`. Mirrors the server's extension allowlist and is
* used for instant client-side validation only — the server stays the source
* of truth.
*/
export const MODEL_EXTENSIONS = [
'.safetensors',
'.sft',
'.ckpt',
'.pth',
'.pt',
'.gguf',
'.bin'
] as const
/**
* Hosts the backend allows out of the box. Admins can extend this
* server-side, so this list is only for optimistic client-side hints; rely on
* the server's `URL_NOT_ALLOWED` / `url_allowed` as the source of truth.
*/
export const DEFAULT_ALLOWED_HOSTS = [
'huggingface.co',
'civitai.com',
'localhost',
'127.0.0.1'
] as const
export interface EnqueueRequest {
url: string
model_id: string
priority?: number
expected_sha256?: string | null
allow_any_extension?: boolean
credential_id?: string | null
}
export interface EnqueueResponse {
download_id: string
accepted: boolean
}
export const AUTH_SCHEMES = ['bearer', 'header', 'query'] as const
export type AuthScheme = (typeof AUTH_SCHEMES)[number]
const zHostCredentialView = z.object({
id: z.string(),
host: z.string(),
auth_scheme: z.enum(AUTH_SCHEMES),
header_name: z.string().nullable(),
query_param: z.string().nullable(),
label: z.string().nullable(),
match_subdomains: z.boolean(),
enabled: z.boolean(),
secret_last4: z.string().nullable(),
created_at: z.number(),
updated_at: z.number()
})
export type HostCredentialView = z.infer<typeof zHostCredentialView>
export interface HostCredentialUpsert {
host: string
secret: string
auth_scheme?: AuthScheme
header_name?: string | null
query_param?: string | null
label?: string | null
match_subdomains?: boolean
enabled?: boolean
}
interface AvailabilityBase {
url_allowed: boolean
}
export type AvailabilityEntry =
| (AvailabilityBase & { state: 'available' })
| (AvailabilityBase & { state: 'missing' })
| (AvailabilityBase & {
state: 'downloading'
download_id: string
progress: number | null
bytes_done: number
total_bytes: number | null
speed_bps: number | null
})
export interface AvailabilityResponse {
models: Record<string, AvailabilityEntry>
}
const DOWNLOAD_ERROR_CODES = [
'INVALID_JSON',
'INVALID_BODY',
'URL_NOT_ALLOWED',
'INVALID_MODEL_ID',
'INVALID_CREDENTIAL',
'ALREADY_AVAILABLE',
'ALREADY_DOWNLOADING',
'DOWNLOAD_ACTIVE',
'NOT_FOUND'
] as const
export type DownloadErrorCode = (typeof DOWNLOAD_ERROR_CODES)[number]
/**
* Error envelope returned by every download-manager endpoint on failure.
* `code` is the stable machine-readable discriminator; `message` is
* user-facing; `details` is an open object (do not assume a shape).
*/
export class DownloadApiError extends Error {
constructor(
message: string,
public readonly code: string,
public readonly status: number,
public readonly details?: Record<string, unknown>
) {
super(message)
this.name = 'DownloadApiError'
}
is(code: DownloadErrorCode): boolean {
return this.code === code
}
}

View File

@@ -1,99 +0,0 @@
import { describe, expect, it } from 'vitest'
import {
buildModelId,
filenameFromUrl,
hasModelExtension,
hostFromUrl,
isLikelyAllowedHost,
isValidPathSegment
} from './modelId'
describe('modelId utils', () => {
describe('isValidPathSegment', () => {
it('accepts valid filenames', () => {
expect(isValidPathSegment('my_lora.safetensors')).toBe(true)
expect(isValidPathSegment('sdxl-base.ckpt')).toBe(true)
})
it('rejects traversal, leading dots, and slashes', () => {
expect(isValidPathSegment('..')).toBe(false)
expect(isValidPathSegment('.hidden')).toBe(false)
expect(isValidPathSegment('nested/path')).toBe(false)
expect(isValidPathSegment('')).toBe(false)
})
})
describe('hasModelExtension', () => {
it('matches known extensions case-insensitively', () => {
expect(hasModelExtension('model.SAFETENSORS')).toBe(true)
expect(hasModelExtension('weights.gguf')).toBe(true)
})
it('rejects unknown extensions', () => {
expect(hasModelExtension('readme.txt')).toBe(false)
expect(hasModelExtension('archive.zip')).toBe(false)
})
})
describe('buildModelId', () => {
it('joins directory and filename with a single slash', () => {
expect(buildModelId('loras', 'x.safetensors')).toBe('loras/x.safetensors')
})
it('collapses redundant slashes at the join boundary', () => {
expect(buildModelId('loras/', 'x.safetensors')).toBe(
'loras/x.safetensors'
)
expect(buildModelId('loras', '/x.safetensors')).toBe(
'loras/x.safetensors'
)
expect(buildModelId('loras//', '//x.safetensors')).toBe(
'loras/x.safetensors'
)
})
})
describe('hostFromUrl', () => {
it('extracts the lowercased host', () => {
expect(hostFromUrl('https://HuggingFace.co/a/b')).toBe('huggingface.co')
})
it('returns null for invalid URLs', () => {
expect(hostFromUrl('not a url')).toBeNull()
})
})
describe('isLikelyAllowedHost', () => {
it('allows built-in hosts and their subdomains', () => {
expect(isLikelyAllowedHost('https://huggingface.co/x')).toBe(true)
expect(isLikelyAllowedHost('https://cdn.huggingface.co/x')).toBe(true)
expect(isLikelyAllowedHost('https://civitai.com/api/download/1')).toBe(
true
)
})
it('flags unknown hosts', () => {
expect(isLikelyAllowedHost('https://example.com/x')).toBe(false)
expect(isLikelyAllowedHost('garbage')).toBe(false)
})
it('does not treat a fake subdomain of an allowed IP literal as allowed', () => {
expect(isLikelyAllowedHost('https://127.0.0.1/x')).toBe(true)
expect(isLikelyAllowedHost('https://evil.127.0.0.1/x')).toBe(false)
})
})
describe('filenameFromUrl', () => {
it('returns the decoded trailing path segment', () => {
expect(filenameFromUrl('https://h.co/a/my%20model.safetensors')).toBe(
'my model.safetensors'
)
})
it('returns empty string when no segment is present', () => {
expect(filenameFromUrl('https://h.co/')).toBe('')
expect(filenameFromUrl('bad')).toBe('')
})
})
})

View File

@@ -1,79 +0,0 @@
import { DEFAULT_ALLOWED_HOSTS, MODEL_EXTENSIONS } from '../types'
const SEGMENT_PATTERN = /^[A-Za-z0-9][A-Za-z0-9._-]*$/
export function isValidPathSegment(segment: string): boolean {
return SEGMENT_PATTERN.test(segment)
}
export function hasModelExtension(filename: string): boolean {
const lower = filename.toLowerCase()
return MODEL_EXTENSIONS.some((ext) => lower.endsWith(ext))
}
export function buildModelId(directory: string, filename: string): string {
return `${directory.replace(/\/+$/, '')}/${filename.replace(/^\/+/, '')}`
}
/**
* Directory portion of a `directory/filename` model id, or an empty string when
* the id has no directory prefix.
*/
export function directoryOf(modelId: string): string {
const slash = modelId.indexOf('/')
return slash === -1 ? '' : modelId.slice(0, slash)
}
/**
* Filename portion of a `directory/filename` model id, falling back to the full
* id when there is no directory prefix.
*/
export function filenameOf(modelId: string): string {
const slash = modelId.indexOf('/')
return slash === -1 ? modelId : modelId.slice(slash + 1)
}
/**
* Lowercased host of a URL, or `null` when the URL is unparseable.
*/
export function hostFromUrl(url: string): string | null {
try {
return new URL(url).hostname.toLowerCase()
} catch {
return null
}
}
const IPV4_PATTERN = /^\d{1,3}(?:\.\d{1,3}){3}$/
function isIpLiteral(host: string): boolean {
return IPV4_PATTERN.test(host) || host.includes(':')
}
/**
* Optimistic client-side allowlist hint. The server can extend the
* allowlist, so a `false` here is advisory — defer to `URL_NOT_ALLOWED`.
*/
export function isLikelyAllowedHost(url: string): boolean {
const host = hostFromUrl(url)
if (!host) return false
return DEFAULT_ALLOWED_HOSTS.some(
(allowed) =>
host === allowed ||
(!isIpLiteral(allowed) && host.endsWith(`.${allowed}`))
)
}
/**
* Best-effort filename guess from a URL path, for prefilling the model_id
* filename field. May be empty when the URL has no usable trailing segment.
*/
export function filenameFromUrl(url: string): string {
try {
const { pathname } = new URL(url)
const last = pathname.split('/').filter(Boolean).pop() ?? ''
return decodeURIComponent(last)
} catch {
return ''
}
}

View File

@@ -30,6 +30,39 @@ describe('TelemetryRegistry', () => {
expect(b.trackSearchQuery).toHaveBeenCalledExactlyOnceWith(payload)
})
it('dispatches trackBeginCheckout with intent metadata to every provider', () => {
const a: TelemetryProvider = { trackBeginCheckout: vi.fn() }
const b: TelemetryProvider = {}
const registry = new TelemetryRegistry()
registry.registerProvider(a)
registry.registerProvider(b)
const metadata = {
user_id: 'user-1',
tier: 'pro' as const,
cycle: 'monthly' as const,
checkout_type: 'new' as const,
payment_intent_source: 'subscribe_to_run' as const
}
registry.trackBeginCheckout(metadata)
expect(a.trackBeginCheckout).toHaveBeenCalledExactlyOnceWith(metadata)
})
it('dispatches trackAddApiCreditButtonClicked with its source', () => {
const provider: TelemetryProvider = {
trackAddApiCreditButtonClicked: vi.fn()
}
const registry = new TelemetryRegistry()
registry.registerProvider(provider)
registry.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
expect(
provider.trackAddApiCreditButtonClicked
).toHaveBeenCalledExactlyOnceWith({ source: 'credits_panel' })
})
it('skips providers that do not implement trackSearchQuery', () => {
const empty: TelemetryProvider = {}
const registry = new TelemetryRegistry()

View File

@@ -1,6 +1,7 @@
import type { AuditLog } from '@/services/customerEventsService'
import type {
AddCreditsClickMetadata,
AuthMetadata,
BeginCheckoutMetadata,
DefaultViewSetMetadata,
@@ -99,8 +100,10 @@ export class TelemetryRegistry implements TelemetryDispatcher {
this.dispatch((provider) => provider.trackMonthlySubscriptionCancelled?.())
}
trackAddApiCreditButtonClicked(): void {
this.dispatch((provider) => provider.trackAddApiCreditButtonClicked?.())
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
this.dispatch((provider) =>
provider.trackAddApiCreditButtonClicked?.(metadata)
)
}
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {

View File

@@ -313,6 +313,42 @@ describe('PostHogTelemetryProvider', () => {
)
})
it('captures begin_checkout with intent metadata', async () => {
const provider = createProvider()
await vi.dynamicImportSettled()
provider.trackBeginCheckout({
user_id: 'user-1',
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
})
expect(hoisted.mockCapture).toHaveBeenCalledWith(
TelemetryEvents.BEGIN_CHECKOUT,
{
user_id: 'user-1',
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
}
)
})
it('captures add-credit clicks with their source', async () => {
const provider = createProvider()
await vi.dynamicImportSettled()
provider.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
expect(hoisted.mockCapture).toHaveBeenCalledWith(
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
{ source: 'credits_panel' }
)
})
it('captures share attribution events', async () => {
const provider = createProvider()
await vi.dynamicImportSettled()

View File

@@ -10,7 +10,9 @@ import { remoteConfig } from '@/platform/remoteConfig/remoteConfig'
import type { RemoteConfig } from '@/platform/remoteConfig/types'
import type {
AddCreditsClickMetadata,
AuthMetadata,
BeginCheckoutMetadata,
DefaultViewSetMetadata,
EnterLinearMetadata,
ShareFlowMetadata,
@@ -350,8 +352,12 @@ export class PostHogTelemetryProvider implements TelemetryProvider {
this.trackEvent(eventName, metadata)
}
trackAddApiCreditButtonClicked(): void {
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
}
trackBeginCheckout(metadata: BeginCheckoutMetadata): void {
this.trackEvent(TelemetryEvents.BEGIN_CHECKOUT, metadata)
}
trackMonthlySubscriptionSucceeded(

View File

@@ -115,6 +115,17 @@ describe('HostTelemetrySink', () => {
)
})
it('forwards add-credit clicks with their source', () => {
new HostTelemetrySink().trackAddApiCreditButtonClicked({
source: 'avatar_menu'
})
expect(state.capture).toHaveBeenCalledExactlyOnceWith(
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
{ source: 'avatar_menu' }
)
})
it('does nothing when the host bridge is absent', () => {
delete window.__comfyDesktop2

View File

@@ -10,6 +10,7 @@ import {
import type { AuditLog } from '@/services/customerEventsService'
import type {
AddCreditsClickMetadata,
AuthMetadata,
BeginCheckoutMetadata,
DefaultViewSetMetadata,
@@ -126,8 +127,8 @@ export class HostTelemetrySink implements TelemetryProvider {
this.capture(TelemetryEvents.MONTHLY_SUBSCRIPTION_CANCELLED)
}
trackAddApiCreditButtonClicked(): void {
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
}
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {

View File

@@ -12,12 +12,29 @@
* 3. Check dist/assets/*.js files contain no tracking code
*/
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import type { AuditLog } from '@/services/customerEventsService'
import type { AppMode } from '@/utils/appMode'
export type PaymentIntentSource =
| 'subscription_required'
| 'out_of_credits'
| 'top_up_blocked'
| 'deep_link'
| 'subscribe_to_run'
| 'subscribe_now_button'
| 'upgrade_to_add_credits'
| 'settings_billing_panel'
| 'avatar_menu_plans'
| 'team_members_panel'
| 'invite_member_upsell'
| 'upload_model_upgrade'
| 'team_upgrade_resume'
export type SubscriptionCheckoutType = 'new' | 'change'
export type SubscriptionCheckoutTier = TierKey | 'team'
/**
* Authentication metadata for sign-up tracking
*/
@@ -426,16 +443,23 @@ export interface CheckoutAttributionMetadata {
export interface SubscriptionMetadata {
current_tier?: string
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
}
export interface AddCreditsClickMetadata {
source: 'credits_panel' | 'avatar_menu' | 'settings_billing_panel'
}
export interface BeginCheckoutMetadata
extends Record<string, unknown>, CheckoutAttributionMetadata {
user_id: string
tier: TierKey
tier: SubscriptionCheckoutTier
cycle: BillingCycle
checkout_type: 'new' | 'change'
checkout_type: SubscriptionCheckoutType
checkout_attempt_id?: string
billing_op_id?: string
previous_tier?: TierKey
payment_intent_source?: PaymentIntentSource
}
interface EcommerceItemMetadata {
@@ -457,8 +481,9 @@ export interface SubscriptionSuccessMetadata extends Record<string, unknown> {
checkout_attempt_id: string
tier: TierKey
cycle: BillingCycle
checkout_type: 'new' | 'change'
checkout_type: SubscriptionCheckoutType
previous_tier?: TierKey
payment_intent_source?: PaymentIntentSource
value: number
currency: string
ecommerce: EcommerceMetadata
@@ -489,7 +514,7 @@ export interface TelemetryProvider {
metadata?: SubscriptionSuccessMetadata
): void
trackMonthlySubscriptionCancelled?(): void
trackAddApiCreditButtonClicked?(): void
trackAddApiCreditButtonClicked?(metadata?: AddCreditsClickMetadata): void
trackApiCreditTopupButtonPurchaseClicked?(amount: number): void
trackApiCreditTopupSucceeded?(): void
trackWorkspaceInviteSent?(metadata: WorkspaceInviteMetadata): void

View File

@@ -321,7 +321,7 @@ const handleOpenWorkspaceSettings = () => {
}
const handleOpenPlansAndPricing = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
emit('close')
}
@@ -336,13 +336,12 @@ const handleOpenPlanAndCreditsSettings = () => {
}
const handleUpgradeToAddCredits = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
emit('close')
}
const handleTopUp = () => {
// Track purchase credits entry from avatar popover
useTelemetry()?.trackAddApiCreditButtonClicked()
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
dialogService.showTopUpCreditsDialog()
emit('close')
}

View File

@@ -391,12 +391,13 @@ const showZeroState = computed(
)
function handleSubscribeWorkspace() {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'settings_billing_panel' })
}
function handleUpgrade() {
if (isFreeTierPlan.value) showPricingTable()
else showSubscriptionDialog()
if (isFreeTierPlan.value)
showPricingTable({ reason: 'settings_billing_panel' })
else showSubscriptionDialog({ reason: 'settings_billing_panel' })
}
function handleViewMoreDetails() {

View File

@@ -113,7 +113,7 @@ import { cn } from '@comfyorg/tailwind-utils'
import { useEventListener } from '@vueuse/core'
import Button from '@/components/ui/button/Button.vue'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
import SubscriptionAddPaymentPreviewWorkspace from './SubscriptionAddPaymentPreviewWorkspace.vue'
@@ -123,7 +123,7 @@ import UnifiedPricingTable from './UnifiedPricingTable.vue'
const { onClose, reason, initialPlanMode } = defineProps<{
onClose: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
initialPlanMode?: 'personal' | 'team'
}>()
@@ -152,7 +152,7 @@ const {
handleConfirmTransition,
handleTeamSubscribe,
handleResubscribe
} = useSubscriptionCheckout(emit)
} = useSubscriptionCheckout(emit, reason)
// Backspace mirrors the back arrow on the confirm step, but never while an
// editable element is focused (let it delete text there).

View File

@@ -5,7 +5,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
import { ref } from 'vue'
import { createI18n } from 'vue-i18n'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import SubscriptionRequiredDialogContentWorkspace from './SubscriptionRequiredDialogContentWorkspace.vue'
@@ -17,25 +17,10 @@ const mockHandleResubscribe = vi.fn()
const mockHandleSuccessClose = vi.fn()
const mockCheckoutStep = ref<'pricing' | 'preview' | 'success'>('pricing')
const mockPreviewData = ref<{ transition_type: string } | null>(null)
const mockUseSubscriptionCheckout = vi.hoisted(() => vi.fn())
vi.mock('@/platform/workspace/composables/useSubscriptionCheckout', () => ({
useSubscriptionCheckout: () => ({
checkoutStep: mockCheckoutStep,
isLoadingPreview: ref(false),
loadingTier: ref(null),
isSubscribing: ref(false),
isResubscribing: ref(false),
previewData: mockPreviewData,
selectedTierKey: ref('standard'),
selectedBillingCycle: ref('yearly'),
isPolling: ref(false),
handleSubscribeClick: mockHandleSubscribeClick,
handleBackToPricing: mockHandleBackToPricing,
handleAddCreditCard: mockHandleAddCreditCard,
handleConfirmTransition: mockHandleConfirmTransition,
handleResubscribe: mockHandleResubscribe,
handleSuccessClose: mockHandleSuccessClose
})
useSubscriptionCheckout: mockUseSubscriptionCheckout
}))
const i18n = createI18n({
@@ -91,7 +76,7 @@ const SuccessStub = {
function renderComponent(
props: {
onClose?: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
isPersonal?: boolean
} = {}
) {
@@ -121,6 +106,23 @@ function renderComponent(
describe('SubscriptionRequiredDialogContentWorkspace', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseSubscriptionCheckout.mockReturnValue({
checkoutStep: mockCheckoutStep,
isLoadingPreview: ref(false),
loadingTier: ref(null),
isSubscribing: ref(false),
isResubscribing: ref(false),
previewData: mockPreviewData,
selectedTierKey: ref('standard'),
selectedBillingCycle: ref('yearly'),
isPolling: ref(false),
handleSubscribeClick: mockHandleSubscribeClick,
handleBackToPricing: mockHandleBackToPricing,
handleAddCreditCard: mockHandleAddCreditCard,
handleConfirmTransition: mockHandleConfirmTransition,
handleResubscribe: mockHandleResubscribe,
handleSuccessClose: mockHandleSuccessClose
})
mockCheckoutStep.value = 'pricing'
mockPreviewData.value = null
})
@@ -132,6 +134,15 @@ describe('SubscriptionRequiredDialogContentWorkspace', () => {
expect(screen.queryByTestId('transition-preview')).not.toBeInTheDocument()
})
it('passes the reason into subscription checkout', () => {
renderComponent({ reason: 'out_of_credits' })
expect(mockUseSubscriptionCheckout).toHaveBeenCalledWith(
expect.any(Function),
'out_of_credits'
)
})
it('shows the team workspace header by default', () => {
renderComponent()
expect(screen.getByText('Team Workspace')).toBeInTheDocument()

View File

@@ -116,7 +116,7 @@
import { cn } from '@comfyorg/tailwind-utils'
import Button from '@/components/ui/button/Button.vue'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
import PricingTableWorkspace from './PricingTableWorkspace.vue'
@@ -130,7 +130,7 @@ const {
isPersonal = false
} = defineProps<{
onClose: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
isPersonal?: boolean
}>()
@@ -154,7 +154,7 @@ const {
handleConfirmTransition,
handleResubscribe,
handleSuccessClose
} = useSubscriptionCheckout(emit)
} = useSubscriptionCheckout(emit, reason)
</script>
<style scoped>

View File

@@ -61,6 +61,9 @@ function onDismiss() {
function onUpgrade() {
dialogStore.closeDialog({ key: 'invite-member-upsell' })
subscriptionDialog.show({ planMode: 'team' })
subscriptionDialog.show({
planMode: 'team',
reason: 'invite_member_upsell'
})
}
</script>

View File

@@ -277,7 +277,7 @@ export function useMembersPanel() {
}
function showTeamPlans() {
subscriptionDialog.show({ planMode: 'team' })
subscriptionDialog.show({ planMode: 'team', reason: 'team_members_panel' })
}
return {

View File

@@ -1,8 +1,9 @@
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { computed } from 'vue'
import { computed, reactive } from 'vue'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import type { Plan } from '@/platform/workspace/api/workspaceApi'
import { findPlanSlug } from './useSubscriptionCheckout'
@@ -75,7 +76,9 @@ const {
mockPlans,
mockResubscribe,
mockToastAdd,
mockStartOperation
mockStartOperation,
mockTrackBeginCheckout,
mockUserId
} = vi.hoisted(() => ({
mockSubscribe: vi.fn(),
mockPreviewSubscribe: vi.fn(),
@@ -84,7 +87,9 @@ const {
mockPlans: { value: [] as Plan[] },
mockResubscribe: vi.fn(),
mockToastAdd: vi.fn(),
mockStartOperation: vi.fn()
mockStartOperation: vi.fn(),
mockTrackBeginCheckout: vi.fn(),
mockUserId: { value: 'user-1' as string | null }
}))
vi.mock('@/composables/billing/useBillingContext', () => ({
@@ -119,7 +124,14 @@ vi.mock('primevue/usetoast', () => ({
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackMonthlySubscriptionSucceeded: vi.fn() })
useTelemetry: () => ({
trackMonthlySubscriptionSucceeded: vi.fn(),
trackBeginCheckout: mockTrackBeginCheckout
})
}))
vi.mock('@/stores/authStore', () => ({
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
}))
vi.mock('vue-i18n', async (importOriginal) => {
@@ -135,10 +147,10 @@ vi.mock('vue-i18n', async (importOriginal) => {
describe('useSubscriptionCheckout', () => {
let emit: ReturnType<typeof vi.fn>
async function setup() {
async function setup(paymentIntentSource?: PaymentIntentSource) {
const { useSubscriptionCheckout } =
await import('./useSubscriptionCheckout')
return useSubscriptionCheckout(emit as never)
return useSubscriptionCheckout(emit as never, paymentIntentSource)
}
beforeEach(() => {
@@ -146,6 +158,7 @@ describe('useSubscriptionCheckout', () => {
vi.clearAllMocks()
mockPlans.value = allPlans()
mockStartOperation.mockResolvedValue({ status: 'succeeded' })
mockUserId.value = 'user-1'
emit = vi.fn()
})
@@ -459,6 +472,13 @@ describe('useSubscriptionCheckout', () => {
cancelUrl: 'https://platform.comfy.org/payment/failed'
})
expect(checkout.checkoutStep.value).toBe('success')
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({
tier: 'team',
checkout_type: 'new',
billing_op_id: 'op-team-1'
})
)
})
it('uses the annual plan slug for the yearly cycle', async () => {
@@ -553,6 +573,39 @@ describe('useSubscriptionCheckout', () => {
detail: 'Team payment failed'
})
)
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
it('keeps team checkout_type as change when the preview request fails', async () => {
const checkout = await setup()
mockPreviewSubscribe.mockRejectedValueOnce(new Error('not supported'))
await checkout.handleSubscribeTeamClick({
stop: {
id: 'team_1400',
usd: 1400,
credits: 295_400,
discountedUsd: 1295
},
billingCycle: 'monthly',
isChange: true
})
mockSubscribe.mockResolvedValueOnce({
status: 'subscribed',
billing_op_id: 'op-team-change'
})
mockFetchStatus.mockResolvedValueOnce(undefined)
mockFetchBalance.mockResolvedValueOnce(undefined)
await checkout.handleTeamSubscribe()
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({
tier: 'team',
cycle: 'monthly',
checkout_type: 'change',
billing_op_id: 'op-team-change'
})
)
})
})
@@ -603,6 +656,47 @@ describe('useSubscriptionCheckout', () => {
expect(checkout.checkoutStep.value).toBe('success')
})
it('skips begin_checkout when no user id is available', async () => {
mockUserId.value = null
const checkout = await setup('subscribe_to_run')
checkout.selectedTierKey.value = 'standard'
checkout.selectedBillingCycle.value = 'yearly'
mockSubscribe.mockResolvedValueOnce({
status: 'subscribed',
billing_op_id: 'op-1'
})
mockFetchStatus.mockResolvedValueOnce(undefined)
mockFetchBalance.mockResolvedValueOnce(undefined)
await checkout.handleAddCreditCard()
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
mockUserId.value = 'user-1'
})
it('fires begin_checkout carrying the payment intent source', async () => {
const checkout = await setup('subscribe_to_run')
checkout.selectedTierKey.value = 'standard'
checkout.selectedBillingCycle.value = 'yearly'
mockSubscribe.mockResolvedValueOnce({
status: 'subscribed',
billing_op_id: 'op-1'
})
mockFetchStatus.mockResolvedValueOnce(undefined)
mockFetchBalance.mockResolvedValueOnce(undefined)
await checkout.handleAddCreditCard()
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
user_id: 'user-1',
tier: 'standard',
cycle: 'yearly',
checkout_type: 'new',
billing_op_id: 'op-1',
payment_intent_source: 'subscribe_to_run'
})
})
it('opens payment URL when needs_payment_method', async () => {
const checkout = await setup()
checkout.selectedTierKey.value = 'standard'
@@ -720,6 +814,7 @@ describe('useSubscriptionCheckout', () => {
detail: 'Payment failed'
})
)
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
})

View File

@@ -9,16 +9,26 @@ import type { TeamPlanSelection } from '@/platform/cloud/subscription/constants/
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import { useTelemetry } from '@/platform/telemetry'
import type {
PaymentIntentSource,
SubscriptionCheckoutType
} from '@/platform/telemetry/types'
import type {
Plan,
PreviewSubscribeResponse,
SubscribeResponse
} from '@/platform/workspace/api/workspaceApi'
import { useBillingOperationStore } from '@/platform/workspace/stores/billingOperationStore'
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
type CheckoutStep = 'pricing' | 'preview' | 'success'
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
interface SelectedTeamCheckout {
stop: TeamPlanSelection
checkoutType: SubscriptionCheckoutType
}
/**
* Which screen the `preview` step shows. Only a change prorates: a team change
* carries `previewData` (handleSubscribeTeamClick sets it solely for an immediate
@@ -45,9 +55,12 @@ export function findPlanSlug(
return plan?.slug ?? null
}
export function useSubscriptionCheckout(emit: {
(e: 'close', subscribed: boolean): void
}) {
export function useSubscriptionCheckout(
emit: {
(e: 'close', subscribed: boolean): void
},
paymentIntentSource?: PaymentIntentSource
) {
const { t } = useI18n()
const toast = useToast()
const {
@@ -68,13 +81,16 @@ export function useSubscriptionCheckout(emit: {
const isResubscribing = ref(false)
const previewData = ref<PreviewSubscribeResponse | null>(null)
const selectedTierKey = ref<CheckoutTierKey | null>(null)
const selectedTeamStop = ref<TeamPlanSelection | null>(null)
const selectedTeamCheckout = ref<SelectedTeamCheckout | null>(null)
const selectedBillingCycle = ref<BillingCycle>('yearly')
const isPolling = computed(() => billingOperationStore.hasPendingOperations)
const isTeamCheckout = computed(() => selectedTeamStop.value !== null)
const selectedTeamStop = computed(
() => selectedTeamCheckout.value?.stop ?? null
)
const isTeamCheckout = computed(() => selectedTeamCheckout.value !== null)
const previewVariant = computed<PreviewVariant>(() => {
if (selectedTeamStop.value) {
if (selectedTeamCheckout.value) {
return previewData.value ? 'team-change' : 'team-new'
}
if (previewData.value) {
@@ -154,7 +170,10 @@ export function useSubscriptionCheckout(emit: {
billingCycle: BillingCycle
isChange?: boolean
}) {
selectedTeamStop.value = payload.stop
selectedTeamCheckout.value = {
stop: payload.stop,
checkoutType: payload.isChange ? 'change' : 'new'
}
selectedBillingCycle.value = payload.billingCycle
selectedTierKey.value = null
previewData.value = null
@@ -182,7 +201,7 @@ export function useSubscriptionCheckout(emit: {
function handleBackToPricing() {
checkoutStep.value = 'pricing'
previewData.value = null
selectedTeamStop.value = null
selectedTeamCheckout.value = null
}
function handleSuccessClose() {
@@ -190,20 +209,34 @@ export function useSubscriptionCheckout(emit: {
}
async function handleSubscription() {
if (!selectedTierKey.value) return
const tierKey = selectedTierKey.value
if (!tierKey) return
const billingCycle = selectedBillingCycle.value
const checkoutType =
previewData.value &&
previewData.value.transition_type !== 'new_subscription'
? 'change'
: 'new'
isSubscribing.value = true
try {
const planSlug = getApiPlanSlug(
selectedTierKey.value,
selectedBillingCycle.value
)
const planSlug = getApiPlanSlug(tierKey, billingCycle)
if (!planSlug) return
const response = await subscribe(planSlug, {
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
})
if (response) {
trackWorkspaceCheckoutStarted({
tier: tierKey,
cycle: billingCycle,
checkoutType,
billingOpId: response.billing_op_id,
paymentIntentSource
})
}
await handleSubscribeResponse(response)
} catch (error) {
showSubscribeError(error)
@@ -269,8 +302,8 @@ export function useSubscriptionCheckout(emit: {
}
async function handleTeamSubscription() {
const stop = selectedTeamStop.value
if (!stop?.id) {
const teamCheckout = selectedTeamCheckout.value
if (!teamCheckout?.stop.id) {
toast.add({
severity: 'error',
summary: t('subscription.teamPlan.name'),
@@ -279,16 +312,28 @@ export function useSubscriptionCheckout(emit: {
return
}
const { stop, checkoutType } = teamCheckout
const billingCycle = selectedBillingCycle.value
isSubscribing.value = true
try {
const planSlug = getTeamPlanSlug(selectedBillingCycle.value)
const planSlug = getTeamPlanSlug(billingCycle)
const response = await subscribe(planSlug, {
teamCreditStopId: stop.id,
billingCycle: selectedBillingCycle.value,
billingCycle,
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
})
if (response) {
trackWorkspaceCheckoutStarted({
tier: 'team',
cycle: billingCycle,
checkoutType,
billingOpId: response.billing_op_id,
paymentIntentSource
})
}
await handleSubscribeResponse(response)
} catch (error) {
showSubscribeError(error)

View File

@@ -2,6 +2,7 @@ import { computed, ref, shallowRef } from 'vue'
import { useBillingPlans } from '@/platform/cloud/subscription/composables/useBillingPlans'
import { useSubscriptionDialog } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type {
BillingBalanceResponse,
BillingStatusResponse,
@@ -275,12 +276,12 @@ export function useWorkspaceBilling(): BillingState & BillingActions {
async function requireActiveSubscription(): Promise<void> {
await fetchStatus()
if (!isActiveSubscription.value) {
subscriptionDialog.show()
subscriptionDialog.show({ reason: 'subscription_required' })
}
}
function showSubscriptionDialog(): void {
subscriptionDialog.show()
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
subscriptionDialog.show(options)
}
return {

View File

@@ -0,0 +1,38 @@
import { useTelemetry } from '@/platform/telemetry'
import type {
PaymentIntentSource,
SubscriptionCheckoutTier,
SubscriptionCheckoutType
} from '@/platform/telemetry/types'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import { useAuthStore } from '@/stores/authStore'
interface TrackWorkspaceCheckoutStartedOptions {
tier: SubscriptionCheckoutTier
cycle: BillingCycle
checkoutType: SubscriptionCheckoutType
billingOpId: string
paymentIntentSource?: PaymentIntentSource
}
export function trackWorkspaceCheckoutStarted({
tier,
cycle,
checkoutType,
billingOpId,
paymentIntentSource
}: TrackWorkspaceCheckoutStartedOptions) {
const { userId } = useAuthStore()
if (!userId) return
useTelemetry()?.trackBeginCheckout({
user_id: userId,
tier,
cycle,
checkout_type: checkoutType,
billing_op_id: billingOpId,
...(paymentIntentSource
? { payment_intent_source: paymentIntentSource }
: {})
})
}

View File

@@ -154,39 +154,6 @@ const zAssetDownloadWsMessage = z.object({
error: z.string().optional()
})
const DOWNLOAD_STATES = [
'queued',
'active',
'paused',
'verifying',
'completed',
'failed',
'cancelled'
] as const
const zDownloadSegment = z.object({
idx: z.number().int().nonnegative(),
bytes_done: z.number().int().nonnegative(),
length: z.number().int().nonnegative()
})
const zDownloadStatus = z.object({
download_id: z.string(),
model_id: z.string(),
url: z.string(),
status: z.enum(DOWNLOAD_STATES),
priority: z.number().int(),
total_bytes: z.number().int().nonnegative().nullable(),
bytes_done: z.number().int().nonnegative(),
progress: z.number().nullable(),
speed_bps: z.number().nonnegative().nullable(),
eta_seconds: z.number().nonnegative().nullable(),
segments: z.array(zDownloadSegment).nullable(),
error: z.string().nullable(),
created_at: z.number(),
updated_at: z.number()
})
const zAssetExportWsMessage = z.object({
task_id: z.string(),
export_name: z.string().optional(),
@@ -221,8 +188,6 @@ export type ProgressStateWsMessage = z.infer<typeof zProgressStateWsMessage>
export type FeatureFlagsWsMessage = z.infer<typeof zFeatureFlagsWsMessage>
export type AssetDownloadWsMessage = z.infer<typeof zAssetDownloadWsMessage>
export type AssetExportWsMessage = z.infer<typeof zAssetExportWsMessage>
export type DownloadState = (typeof DOWNLOAD_STATES)[number]
export type DownloadStatus = z.infer<typeof zDownloadStatus>
// End of ws messages
export type NotificationWsMessage = z.infer<typeof zNotificationWsMessage>

View File

@@ -34,7 +34,6 @@ import type {
AssetDownloadWsMessage,
AssetExportWsMessage,
CustomNodesI18n,
DownloadStatus,
EmbeddingsResponse,
ExecutedWsMessage,
ExecutingWsMessage,
@@ -187,7 +186,6 @@ interface BackendApiCalls {
feature_flags: FeatureFlagsWsMessage
asset_download: AssetDownloadWsMessage
asset_export: AssetExportWsMessage
download_progress: DownloadStatus
}
/** Dictionary of all api calls */

View File

@@ -18,7 +18,7 @@ import type {
} from '@/stores/dialogStore'
import type { ComponentAttrs } from 'vue-component-type-helpers'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { WorkspaceRole } from '@/platform/workspace/api/workspaceApi'
// Lazy loaders for dialogs - components are loaded on first use
@@ -442,9 +442,9 @@ export const useDialogService = () => {
})
}
async function showSubscriptionRequiredDialog(options?: {
reason?: SubscriptionDialogReason
}) {
async function showSubscriptionRequiredDialog(
options?: SubscriptionDialogOptions
) {
if (!isCloud || !window.__CONFIG__?.subscription_required) {
return
}

View File

@@ -113,27 +113,6 @@ describe('commandStore', () => {
})
})
describe('unregisterCommand', () => {
it('removes a registered command', () => {
const store = useCommandStore()
store.registerCommand({ id: 'to.remove', function: vi.fn() })
store.unregisterCommand('to.remove')
expect(store.isRegistered('to.remove')).toBe(false)
})
it('is a no-op for an unregistered id', () => {
const store = useCommandStore()
store.registerCommand({ id: 'keep.me', function: vi.fn() })
store.unregisterCommand('nonexistent')
expect(store.isRegistered('keep.me')).toBe(true)
expect(store.commands).toHaveLength(1)
})
})
describe('isRegistered', () => {
it('returns false for unregistered command', () => {
const store = useCommandStore()

View File

@@ -88,12 +88,6 @@ export const useCommandStore = defineStore('command', () => {
}
}
const unregisterCommand = (commandId: string) => {
if (!commandsById.value[commandId]) return
const { [commandId]: _removed, ...rest } = commandsById.value
commandsById.value = rest
}
const getCommand = (command: string) => {
return commandsById.value[command]
}
@@ -145,7 +139,6 @@ export const useCommandStore = defineStore('command', () => {
getCommand,
registerCommand,
registerCommands,
unregisterCommand,
isRegistered,
loadExtensionCommands,
formatKeySequence

View File

@@ -5,19 +5,12 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useSidebarTabStore } from '@/stores/workspace/sidebarTabStore'
const {
mockGetSetting,
mockRegisterCommand,
mockRegisterCommands,
mockUnregisterCommand,
mockServerSideModelDownloads
} = vi.hoisted(() => ({
mockGetSetting: vi.fn(),
mockRegisterCommand: vi.fn(),
mockRegisterCommands: vi.fn(),
mockUnregisterCommand: vi.fn(),
mockServerSideModelDownloads: { value: false }
}))
const { mockGetSetting, mockRegisterCommand, mockRegisterCommands } =
vi.hoisted(() => ({
mockGetSetting: vi.fn(),
mockRegisterCommand: vi.fn(),
mockRegisterCommands: vi.fn()
}))
vi.mock('@/platform/settings/settingStore', () => ({
useSettingStore: () => ({
@@ -28,43 +21,10 @@ vi.mock('@/platform/settings/settingStore', () => ({
vi.mock('@/stores/commandStore', () => ({
useCommandStore: () => ({
registerCommand: mockRegisterCommand,
unregisterCommand: mockUnregisterCommand,
commands: []
})
}))
vi.mock('@/composables/useFeatureFlags', async () => {
const { ref } = await import('vue')
const serverSideModelDownloadsRef = ref(mockServerSideModelDownloads.value)
Object.defineProperty(mockServerSideModelDownloads, 'value', {
get: () => serverSideModelDownloadsRef.value,
set: (value: boolean) => {
serverSideModelDownloadsRef.value = value
}
})
return {
useFeatureFlags: () => ({
flags: {
get serverSideModelDownloads() {
return mockServerSideModelDownloads.value
}
}
})
}
})
vi.mock(
'@/platform/modelManager/composables/useModelManagerSidebarTab',
() => ({
useModelManagerSidebarTab: () => ({
id: 'model-manager',
title: 'model-manager',
type: 'vue',
component: {}
})
})
)
vi.mock('@/stores/menuItemStore', () => ({
useMenuItemStore: () => ({
registerCommands: mockRegisterCommands
@@ -139,8 +99,6 @@ describe('useSidebarTabStore', () => {
mockGetSetting.mockReset()
mockRegisterCommand.mockClear()
mockRegisterCommands.mockClear()
mockUnregisterCommand.mockClear()
mockServerSideModelDownloads.value = false
})
it('registers the job history tab when QPO V2 is enabled', () => {
@@ -202,57 +160,4 @@ describe('useSidebarTabStore', () => {
])
expect(mockRegisterCommand).toHaveBeenCalledTimes(6)
})
it('registers the model-manager tab when the feature flag starts enabled', () => {
mockServerSideModelDownloads.value = true
const store = useSidebarTabStore()
store.registerCoreSidebarTabs()
expect(store.sidebarTabs.map((tab) => tab.id)).toContain('model-manager')
})
it('does not register the model-manager tab when the feature flag is disabled', () => {
const store = useSidebarTabStore()
store.registerCoreSidebarTabs()
expect(store.sidebarTabs.map((tab) => tab.id)).not.toContain(
'model-manager'
)
})
it('registers the model-manager tab when the feature flag turns on later', async () => {
const store = useSidebarTabStore()
store.registerCoreSidebarTabs()
expect(store.sidebarTabs.map((tab) => tab.id)).not.toContain(
'model-manager'
)
mockServerSideModelDownloads.value = true
await nextTick()
expect(store.sidebarTabs.map((tab) => tab.id)).toContain('model-manager')
expect(mockRegisterCommand).toHaveBeenCalledWith(
expect.objectContaining({
id: 'Workspace.ToggleSidebarTab.model-manager'
})
)
})
it('unregisters the model-manager tab and its command when the flag turns off', async () => {
mockServerSideModelDownloads.value = true
const store = useSidebarTabStore()
store.registerCoreSidebarTabs()
expect(store.sidebarTabs.map((tab) => tab.id)).toContain('model-manager')
mockServerSideModelDownloads.value = false
await nextTick()
expect(store.sidebarTabs.map((tab) => tab.id)).not.toContain(
'model-manager'
)
expect(mockUnregisterCommand).toHaveBeenCalledWith(
'Workspace.ToggleSidebarTab.model-manager'
)
})
})

View File

@@ -5,9 +5,7 @@ import { useAssetsSidebarTab } from '@/composables/sidebarTabs/useAssetsSidebarT
import { useJobHistorySidebarTab } from '@/composables/sidebarTabs/useJobHistorySidebarTab'
import { useModelLibrarySidebarTab } from '@/composables/sidebarTabs/useModelLibrarySidebarTab'
import { useNodeLibrarySidebarTab } from '@/composables/sidebarTabs/useNodeLibrarySidebarTab'
import { useFeatureFlags } from '@/composables/useFeatureFlags'
import { t, te } from '@/i18n'
import { useModelManagerSidebarTab } from '@/platform/modelManager/composables/useModelManagerSidebarTab'
import { useSettingStore } from '@/platform/settings/settingStore'
import { useAppsSidebarTab } from '@/platform/workflow/management/composables/useAppsSidebarTab'
import { useWorkflowsSidebarTab } from '@/platform/workflow/management/composables/useWorkflowsSidebarTab'
@@ -55,8 +53,7 @@ export const useSidebarTabStore = defineStore('sidebarTab', () => {
'model-library': 'sideToolbar.modelLibrary',
workflows: 'sideToolbar.workflows',
assets: 'sideToolbar.assets',
'job-history': 'queue.jobHistory',
'model-manager': 'modelManager.title'
'job-history': 'queue.jobHistory'
}
const key = menubarLabelKeys[tab.id]
@@ -135,27 +132,6 @@ export const useSidebarTabStore = defineStore('sidebarTab', () => {
(enabled) => syncJobHistoryTab(enabled)
)
const modelManagerTabId = 'model-manager'
const { flags } = useFeatureFlags()
const syncModelManagerTab = (enabled: boolean) => {
const hasTab = sidebarTabs.value.some(
(tab) => tab.id === modelManagerTabId
)
if (enabled && !hasTab) {
registerSidebarTab(useModelManagerSidebarTab())
} else if (!enabled && hasTab) {
unregisterSidebarTab(modelManagerTabId)
useCommandStore().unregisterCommand(
`Workspace.ToggleSidebarTab.${modelManagerTabId}`
)
}
}
watch(
() => flags.serverSideModelDownloads,
(enabled) => syncModelManagerTab(enabled),
{ immediate: true }
)
registerSidebarTab(useAssetsSidebarTab())
registerSidebarTab(useNodeLibrarySidebarTab())
registerSidebarTab(useModelLibrarySidebarTab())

View File

@@ -58,7 +58,6 @@ import { useQueuePolling } from '@/platform/remote/comfyui/useQueuePolling'
import { useErrorHandling } from '@/composables/useErrorHandling'
import { useReconnectQueueRefresh } from '@/composables/useReconnectQueueRefresh'
import { useReconnectingNotification } from '@/composables/useReconnectingNotification'
import { useModelDownloadEffects } from '@/platform/modelManager/composables/useModelDownloadEffects'
import { useProgressFavicon } from '@/composables/useProgressFavicon'
import { SERVER_CONFIG_ITEMS } from '@/constants/serverConfig'
import type { ServerConfig, ServerConfigValue } from '@/constants/serverConfig'
@@ -228,7 +227,6 @@ useMenuItemStore().registerCoreMenuCommands()
useKeybindingService().registerCoreKeybindings()
useSidebarTabStore().registerCoreSidebarTabs()
void useBottomPanelStore().registerCoreBottomPanelTabs()
useModelDownloadEffects()
useQueuePolling()
const queuePendingTaskCountStore = useQueuePendingTaskCountStore()