Compare commits

..

2 Commits

Author SHA1 Message Date
huang47
838734c67f test: cover asset and model stores 2026-07-02 09:01:08 -07:00
Benjamin Lu
2ec2a0e091 feat: attribute payment intent through paywall, checkout, and top-up telemetry (#13363)
## Summary

Answers "why did this user want to pay?" by capturing the triggering
product moment at every paywall/upsell entry point and carrying it
through checkout and success telemetry.

## Changes

- **What**:
- Widen `SubscriptionDialogReason` from 4 coarse values to 13 grounded
intent sources (`subscribe_to_run`, `upgrade_to_add_credits`,
`invite_member_upsell`, `settings_billing_panel`, etc.)
- Fire `app:subscription_required_modal_opened` from
`useSubscriptionDialog` (the choke point all dialog variants pass
through) — the workspace/unified path previously emitted nothing; remove
the now-duplicate emitters in `useSubscription` and
`usePricingTableUrlLoader`
- Add `payment_intent_source` to
`BeginCheckoutMetadata`/`SubscriptionSuccessMetadata`, threaded via the
existing `reason` prop: dialog → `PricingTable` →
`performSubscriptionCheckout` → pending-attempt record, so legacy
`app:monthly_subscription_succeeded` carries intent alongside
`checkout_attempt_id`
- Fire `begin_checkout` on the workspace checkout path
(`useSubscriptionCheckout`, personal + team confirm) and the team
deep-link util — both previously emitted nothing; `tier` widened to
`TierKey | 'team'`
- Implement `trackBeginCheckout` in `PostHogTelemetryProvider` (was
GTM/host-only, so `begin_checkout` never reached PostHog)
- Thread `showSubscriptionDialog(options)` through the billing-context
adapters and pass a reason at ~14 call sites; add `source` to
`app:add_api_credit_button_clicked`

## Review Focus

- `modal_opened` now fires once per dialog actually shown, so a
free-tier user clicking Upgrade emits two events (free-tier dialog, then
pricing table) where the legacy path emitted one
- Intent is threaded explicitly via props/params rather than shared
state; `useSubscriptionCheckout` gained an optional second parameter
2026-07-02 03:11:21 +00:00
109 changed files with 4754 additions and 5444 deletions

View File

@@ -1,304 +0,0 @@
import { expect, mergeTests } from '@playwright/test'
import type { Page } from '@playwright/test'
import type { ComfyPage } from '@e2e/fixtures/ComfyPage'
import { comfyPageFixture } from '@e2e/fixtures/ComfyPage'
import { ConfirmDialog } from '@e2e/fixtures/components/ConfirmDialog'
import { jsonRoute } from '@e2e/fixtures/utils/jsonRoute'
import { webSocketFixture } from '@e2e/fixtures/ws'
import type {
DownloadStatus,
EnqueueResponse,
HostCredentialUpsert,
HostCredentialView
} from '@/platform/modelManager/types'
const test = mergeTests(comfyPageFixture, webSocketFixture)
const DOWNLOADS_ROUTE = /\/api\/download$/
const ENQUEUE_ROUTE = /\/api\/download\/enqueue$/
const CREDENTIALS_ROUTE = /\/api\/download\/credentials$/
const CREDENTIAL_ROUTE = /\/api\/download\/credentials\/([^/?]+)$/
const CLEAR_ROUTE = /\/api\/download\/clear$/
const DOWNLOAD_ID = 'e2e-download-1'
const MODEL_URL =
'https://huggingface.co/e2e/test/resolve/main/model.safetensors'
const MODEL_ID = 'checkpoints/e2e-test-model.safetensors'
function makeDownloadStatus(
overrides: Partial<DownloadStatus> = {}
): DownloadStatus {
const now = Math.floor(Date.now() / 1000)
return {
download_id: DOWNLOAD_ID,
model_id: MODEL_ID,
url: MODEL_URL,
status: 'queued',
priority: 0,
total_bytes: null,
bytes_done: 0,
progress: null,
speed_bps: null,
eta_seconds: null,
segments: null,
error: null,
created_at: now,
updated_at: now,
...overrides
}
}
async function enableServerSideModelDownloads(comfyPage: ComfyPage) {
await comfyPage.page.evaluate(() => {
window.app!.api.serverFeatureFlags.value = {
...window.app!.api.serverFeatureFlags.value,
server_side_model_downloads: true
}
})
}
/** Keeps GET /download consistent so the 5s stale-poll can't wipe test state. */
async function mockDownloadsList(
page: Page,
getDownloads: () => DownloadStatus[]
) {
await page.route(DOWNLOADS_ROUTE, async (route) => {
if (route.request().method() !== 'GET') return route.fallback()
await route.fulfill(jsonRoute({ downloads: getDownloads() }))
})
}
function modelDownloaderTabButton(comfyPage: ComfyPage) {
return comfyPage.page.locator('.model-manager-tab-button')
}
function modelDownloaderBadge(comfyPage: ComfyPage) {
return modelDownloaderTabButton(comfyPage).locator('.sidebar-icon-badge')
}
function modelDownloaderPanel(comfyPage: ComfyPage) {
return comfyPage.page.locator('.sidebar-content-container')
}
async function openModelDownloaderTab(comfyPage: ComfyPage) {
await modelDownloaderTabButton(comfyPage).click()
const panel = modelDownloaderPanel(comfyPage)
await expect(panel.getByText('Downloads', { exact: true })).toBeVisible()
// Toolbar buttons are only interactive while the tab panel is hovered
// (`group-hover/sidebar-tab`), so keep the cursor over it for later clicks.
await panel.hover()
}
test.describe('Model Downloader sidebar', { tag: '@ui' }, () => {
test.beforeEach(async ({ comfyPage }) => {
await comfyPage.modelLibrary.mockFoldersWithFiles({ checkpoints: [] })
// Isolate from any real downloads already tracked by the backend.
await mockDownloadsList(comfyPage.page, () => [])
await enableServerSideModelDownloads(comfyPage)
})
test('adds a model by URL and reflects live progress over the websocket', async ({
comfyPage,
getWebSocket
}) => {
let downloads: DownloadStatus[] = []
await mockDownloadsList(comfyPage.page, () => downloads)
await comfyPage.page.route(ENQUEUE_ROUTE, async (route) => {
downloads = [makeDownloadStatus()]
const response: EnqueueResponse = {
download_id: DOWNLOAD_ID,
accepted: true
}
await route.fulfill({
status: 202,
contentType: 'application/json',
body: JSON.stringify(response)
})
})
await openModelDownloaderTab(comfyPage)
const panel = modelDownloaderPanel(comfyPage)
await panel.getByTitle('Add model').click()
const addDialog = comfyPage.page.getByRole('dialog')
await expect(addDialog.getByText('Add model')).toBeVisible()
await addDialog.getByLabel('URL').fill(MODEL_URL)
await expect(addDialog.getByLabel('Filename')).toHaveValue(
'model.safetensors'
)
await addDialog.getByLabel('Filename').fill('e2e-test-model.safetensors')
await addDialog.getByRole('combobox', { name: 'Select a folder' }).click()
await comfyPage.page
.getByRole('option', { name: 'Checkpoints', exact: true })
.click()
await addDialog
.getByRole('button', { name: 'Download', exact: true })
.click()
await expect(addDialog).toBeHidden()
await expect(panel.getByText('e2e-test-model.safetensors')).toBeVisible()
await expect(panel.getByText('Queued', { exact: true })).toBeVisible()
await expect(modelDownloaderBadge(comfyPage)).toHaveText('1')
const ws = await getWebSocket()
function sendProgress(overrides: Partial<DownloadStatus>) {
const payload = makeDownloadStatus(overrides)
downloads = [payload]
ws.send(JSON.stringify({ type: 'download_progress', data: payload }))
}
sendProgress({
status: 'active',
progress: 0.4,
bytes_done: 400,
total_bytes: 1000,
speed_bps: 500_000
})
await expect(panel.getByText('Downloading', { exact: true })).toBeVisible()
await expect(panel.getByText(/40%/)).toBeVisible()
await expect(modelDownloaderBadge(comfyPage)).toHaveText('1')
sendProgress({
status: 'completed',
progress: 1,
bytes_done: 1000,
total_bytes: 1000,
speed_bps: null,
eta_seconds: null
})
await expect(panel.getByText('Completed', { exact: true })).toBeVisible()
await expect(panel.getByText('History', { exact: true })).toBeVisible()
await expect(modelDownloaderBadge(comfyPage)).toHaveCount(0)
})
test('clears history and the cleared rows stay gone after reopening the tab', async ({
comfyPage
}) => {
let downloads: DownloadStatus[] = [
makeDownloadStatus({ status: 'completed', progress: 1, bytes_done: 1000 })
]
await mockDownloadsList(comfyPage.page, () => downloads)
await comfyPage.page.route(CLEAR_ROUTE, async (route) => {
const count = downloads.length
downloads = []
await route.fulfill(jsonRoute({ deleted: count }))
})
await openModelDownloaderTab(comfyPage)
const panel = modelDownloaderPanel(comfyPage)
await expect(panel.getByText('History', { exact: true })).toBeVisible()
await expect(panel.getByText('e2e-test-model.safetensors')).toBeVisible()
await panel.getByRole('button', { name: 'Clear history' }).click()
await expect(panel.getByText('No downloads yet')).toBeVisible()
// Reopening the tab re-runs hydrate() -> GET /api/download. The bug was
// that the cleared row reappeared here; the persisted delete must prevent
// the backend list from returning it again.
await modelDownloaderTabButton(comfyPage).click()
await openModelDownloaderTab(comfyPage)
await expect(
modelDownloaderPanel(comfyPage).getByText('No downloads yet')
).toBeVisible()
await expect(
modelDownloaderPanel(comfyPage).getByText('e2e-test-model.safetensors')
).toBeHidden()
})
test('manages download credentials: add, edit, and delete', async ({
comfyPage
}) => {
let credentials: HostCredentialView[] = []
let nextId = 1
await comfyPage.page.route(CREDENTIALS_ROUTE, async (route) => {
const method = route.request().method()
if (method === 'GET') {
await route.fulfill(jsonRoute({ credentials }))
return
}
if (method === 'POST') {
const body = route.request().postDataJSON() as HostCredentialUpsert
const existing = credentials.find((c) => c.host === body.host)
const now = Math.floor(Date.now() / 1000)
const view: HostCredentialView = {
id: existing?.id ?? `cred-${nextId++}`,
host: body.host,
auth_scheme: body.auth_scheme ?? 'bearer',
header_name: body.header_name ?? null,
query_param: body.query_param ?? null,
label: body.label ?? null,
match_subdomains: body.match_subdomains ?? false,
enabled: body.enabled ?? true,
secret_last4: body.secret.slice(-4),
created_at: existing?.created_at ?? now,
updated_at: now
}
credentials = existing
? credentials.map((c) => (c.id === view.id ? view : c))
: [...credentials, view]
await route.fulfill(jsonRoute(view))
return
}
await route.fallback()
})
await comfyPage.page.route(CREDENTIAL_ROUTE, async (route) => {
if (route.request().method() !== 'DELETE') return route.fallback()
const id = route.request().url().match(CREDENTIAL_ROUTE)?.[1]
credentials = credentials.filter((c) => c.id !== id)
await route.fulfill(jsonRoute({ deleted: true }))
})
const dialog = comfyPage.page
.getByRole('dialog')
.filter({ hasText: 'Credentials Manager' })
// The success toast shown after saving can momentarily steal focus and
// close the dialog, so re-opening is retried until it sticks.
async function ensureCredentialsDialogOpen() {
await expect(async () => {
const panel = modelDownloaderPanel(comfyPage)
await panel.hover()
await panel.getByTitle('Credentials Manager').click()
await expect(dialog).toBeVisible({ timeout: 1000 })
}).toPass({ timeout: 10_000 })
}
await openModelDownloaderTab(comfyPage)
await ensureCredentialsDialogOpen()
await dialog.getByLabel('Host').fill('huggingface.co')
await dialog.getByLabel('API key').fill('secret-key-1234')
await dialog.getByRole('button', { name: 'Save', exact: true }).click()
await ensureCredentialsDialogOpen()
await expect(
dialog.getByText('huggingface.co · Bearer token · ••••1234')
).toBeVisible()
await dialog.getByTitle('Edit').click()
await expect(dialog.getByText('Update credential')).toBeVisible()
await expect(dialog.getByLabel('API key')).toHaveValue('')
await dialog.getByLabel('Label').fill('My HF Key')
await dialog.getByLabel('API key').fill('updated-secret-5678')
await dialog.getByRole('button', { name: 'Save', exact: true }).click()
await ensureCredentialsDialogOpen()
await expect(dialog.getByText('My HF Key')).toBeVisible()
await expect(
dialog.getByText('huggingface.co · Bearer token · ••••5678')
).toBeVisible()
await dialog.getByTitle('Delete').click()
const confirm = new ConfirmDialog(comfyPage.page)
await confirm.click('delete')
await expect(dialog.getByText('My HF Key')).toBeHidden()
})
})

View File

@@ -243,7 +243,7 @@ onMounted(() => {
--sidebar-padding: 4px;
--sidebar-icon-size: 1rem;
--sidebar-default-floating-width: 50px;
--sidebar-default-floating-width: 48px;
--sidebar-default-connected-width: calc(
var(--sidebar-default-floating-width) + var(--sidebar-padding) * 2
);

View File

@@ -224,7 +224,7 @@ const handleOpenUserSettings = () => {
}
const handleOpenPlansAndPricing = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
emit('close')
}
@@ -239,8 +239,7 @@ const handleOpenPlanAndCreditsSettings = () => {
}
const handleTopUp = () => {
// Track purchase credits entry from avatar popover
useTelemetry()?.trackAddApiCreditButtonClicked()
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
dialogService.showTopUpCreditsDialog()
emit('close')
}
@@ -254,7 +253,7 @@ const handleOpenPartnerNodesInfo = () => {
}
const handleUpgradeToAddCredits = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
emit('close')
}

View File

@@ -21,6 +21,6 @@ const { isFreeTier } = useBillingContext()
const subscriptionDialog = useSubscriptionDialog()
function handleClick() {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'subscribe_now_button' })
}
</script>

View File

@@ -1,17 +1,12 @@
<script setup lang="ts">
import type { DialogContentEmits, DialogContentProps } from 'reka-ui'
import {
DialogContent,
injectDialogRootContext,
useForwardPropsEmits
} from 'reka-ui'
import { DialogContent, useForwardPropsEmits } from 'reka-ui'
import type { HTMLAttributes } from 'vue'
import { cn } from '@comfyorg/tailwind-utils'
import type { DialogContentSize } from './dialog.variants'
import { dialogContentVariants } from './dialog.variants'
import { useModalPointerLock } from './useModalPointerLock'
const {
size,
@@ -28,11 +23,6 @@ const {
const emits = defineEmits<DialogContentEmits>()
const forwarded = useForwardPropsEmits(restProps, emits)
const dialogRootContext = injectDialogRootContext(null)
if (dialogRootContext?.modal.value) {
useModalPointerLock(() => dialogRootContext.open.value)
}
</script>
<template>

View File

@@ -1,114 +0,0 @@
import { render } from '@testing-library/vue'
import { describe, expect, it } from 'vitest'
import {
SelectContent,
SelectPortal,
SelectRoot,
SelectTrigger,
SelectViewport
} from 'reka-ui'
import { defineComponent, h, nextTick, ref } from 'vue'
import Dialog from './Dialog.vue'
import DialogContent from './DialogContent.vue'
import DialogPortal from './DialogPortal.vue'
async function flush() {
await nextTick()
await nextTick()
await new Promise((resolve) => setTimeout(resolve, 0))
await nextTick()
}
function mountDialogWithSelect(modal: boolean) {
const dialogOpen = ref(true)
const selectOpen = ref(false)
const Parent = defineComponent({
setup() {
return () =>
h(
Dialog,
{
modal,
open: dialogOpen.value,
'onUpdate:open': (value: boolean) => (dialogOpen.value = value)
},
() =>
h(DialogPortal, null, () =>
h(DialogContent, null, () =>
h(
SelectRoot,
{
open: selectOpen.value,
'onUpdate:open': (value: boolean) =>
(selectOpen.value = value)
},
() => [
h(SelectTrigger, null, () => 'trigger'),
h(SelectPortal, null, () =>
h(SelectContent, { position: 'popper' }, () =>
h(SelectViewport, null, () => 'items')
)
)
]
)
)
)
)
}
})
return { ...render(Parent), dialogOpen, selectOpen }
}
describe('modal dialog pointer lock', () => {
it('keeps body inert after a nested combobox popover opens and closes', async () => {
const { selectOpen } = mountDialogWithSelect(true)
await flush()
expect(document.body.style.pointerEvents).toBe('none')
selectOpen.value = true
await flush()
expect(document.body.style.pointerEvents).toBe('none')
// Reka restores body pointer events when the popover layer unmounts; the
// lock must re-assert it so the canvas behind the dialog stays inert.
selectOpen.value = false
await flush()
expect(document.body.style.pointerEvents).toBe('none')
})
it('restores body pointer events once the dialog closes', async () => {
const { dialogOpen } = mountDialogWithSelect(true)
await flush()
expect(document.body.style.pointerEvents).toBe('none')
dialogOpen.value = false
await flush()
expect(document.body.style.pointerEvents).toBe('')
})
it('does not lock body for a non-modal dialog', async () => {
mountDialogWithSelect(false)
await flush()
expect(document.body.style.pointerEvents).toBe('')
})
it('holds the body lock until every open modal dialog closes', async () => {
const first = mountDialogWithSelect(true)
const second = mountDialogWithSelect(true)
await flush()
expect(document.body.style.pointerEvents).toBe('none')
// With one modal still open the shared lock must keep the body inert.
first.dialogOpen.value = false
await flush()
expect(document.body.style.pointerEvents).toBe('none')
// Once the last modal closes the lock releases and stops forcing inert.
second.dialogOpen.value = false
await flush()
expect(document.body.style.pointerEvents).not.toBe('none')
})
})

View File

@@ -1,54 +0,0 @@
import { onScopeDispose, toValue, watch } from 'vue'
import type { MaybeRefOrGetter } from 'vue'
/**
* Keeps the canvas behind a modal dialog inert by holding `document.body`'s
* pointer-events lock for as long as at least one modal dialog is open.
*
* Reka-UI locks body pointer events per modal layer, but a nested dismissable
* layer that is portalled to the body — e.g. a `Select` popover inside the
* dialog — restores the body's pointer events when it closes, even while the
* outer modal dialog is still open. That momentarily re-enables the canvas, so
* combobox clicks leak through to it and can select a node or dismiss the
* dialog. Reka still performs the initial lock and final restore; the
* `MutationObserver` only re-asserts `none` if the lock is cleared while a
* modal dialog is still open.
*/
let openModalCount = 0
let observer: MutationObserver | null = null
function enforceLock() {
if (openModalCount > 0 && document.body.style.pointerEvents !== 'none') {
document.body.style.pointerEvents = 'none'
}
}
function acquire() {
openModalCount += 1
if (observer) return
observer = new MutationObserver(enforceLock)
observer.observe(document.body, {
attributes: true,
attributeFilter: ['style']
})
}
function release() {
openModalCount = Math.max(0, openModalCount - 1)
if (openModalCount > 0) return
observer?.disconnect()
observer = null
document.body.style.pointerEvents = ''
}
export function useModalPointerLock(isOpen: MaybeRefOrGetter<boolean>) {
let holding = false
const sync = (open: boolean) => {
if (open === holding) return
holding = open
if (open) acquire()
else release()
}
watch(() => toValue(isOpen), sync, { immediate: true })
onScopeDispose(() => sync(false))
}

View File

@@ -1,5 +1,6 @@
import type { ComputedRef, Ref } from 'vue'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type {
BillingStatus,
@@ -75,9 +76,10 @@ export interface BillingActions {
*/
requireActiveSubscription: () => Promise<void>
/**
* Shows the subscription dialog.
* Shows the subscription dialog. Pass a reason so the paywall open and any
* downstream checkout stay attributed to the triggering product moment.
*/
showSubscriptionDialog: () => void
showSubscriptionDialog: (options?: SubscriptionDialogOptions) => void
}
export interface BillingState {

View File

@@ -7,6 +7,7 @@ import {
getTierFeatures
} from '@/platform/cloud/subscription/constants/tierPricing'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type {
PreviewSubscribeOptions,
SubscribeOptions
@@ -281,8 +282,8 @@ function useBillingContextInternal(): BillingContext {
return activeContext.value.requireActiveSubscription()
}
function showSubscriptionDialog() {
return activeContext.value.showSubscriptionDialog()
function showSubscriptionDialog(options?: SubscriptionDialogOptions) {
return activeContext.value.showSubscriptionDialog(options)
}
return {

View File

@@ -2,6 +2,7 @@ import { computed, ref } from 'vue'
import { useAuthActions } from '@/composables/auth/useAuthActions'
import { useSubscription } from '@/platform/cloud/subscription/composables/useSubscription'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type {
BillingStatus,
BillingSubscriptionStatus,
@@ -189,12 +190,12 @@ export function useLegacyBilling(): BillingState & BillingActions {
async function requireActiveSubscription(): Promise<void> {
await fetchStatus()
if (!isActiveSubscription.value) {
legacyShowSubscriptionDialog()
legacyShowSubscriptionDialog({ reason: 'subscription_required' })
}
}
function showSubscriptionDialog(): void {
legacyShowSubscriptionDialog()
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
legacyShowSubscriptionDialog(options)
}
return {

View File

@@ -503,7 +503,7 @@ export function useCoreCommands(): ComfyCommand[] {
}) => {
trackRunButton(metadata)
if (!isActiveSubscription.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
return
}
@@ -526,7 +526,7 @@ export function useCoreCommands(): ComfyCommand[] {
}) => {
trackRunButton(metadata)
if (!isActiveSubscription.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
return
}
@@ -548,7 +548,7 @@ export function useCoreCommands(): ComfyCommand[] {
}) => {
trackRunButton(metadata)
if (!isActiveSubscription.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
return
}

View File

@@ -30,8 +30,7 @@ export enum ServerFeatureFlag {
COMFYHUB_PROFILE_GATE_ENABLED = 'comfyhub_profile_gate_enabled',
SHOW_SIGNIN_BUTTON = 'show_signin_button',
UNIFIED_CLOUD_AUTH = 'unified_cloud_auth',
SIGNUP_TURNSTILE = 'signup_turnstile',
SERVER_SIDE_MODEL_DOWNLOADS = 'server_side_model_downloads'
SIGNUP_TURNSTILE = 'signup_turnstile'
}
/**
@@ -182,12 +181,6 @@ export function useFeatureFlags() {
remoteConfig.value.signup_turnstile,
'off'
)
},
get serverSideModelDownloads() {
return api.getServerFeature(
ServerFeatureFlag.SERVER_SIDE_MODEL_DOWNLOADS,
false
)
}
})

View File

@@ -1069,74 +1069,6 @@
"update": "Update",
"description": "Check out the latest improvements and features in this update."
},
"modelManager": {
"title": "Downloads",
"active": "Active",
"history": "History",
"empty": "No downloads yet",
"clearHistory": "Clear history",
"addModel": "Add model",
"addModelDescription": "Download a model directly to your ComfyUI models folder.",
"url": "URL",
"urlPlaceholder": "https://huggingface.co/.../model.safetensors",
"hostNotAllowedHint": "This host may not be on the download allowlist.",
"folder": "Model folder",
"selectFolder": "Select a folder",
"filename": "Filename",
"filenamePlaceholder": "model.safetensors",
"allowAnyExtension": "Allow any file extension (advanced)",
"download": "Download",
"downloadQueued": "Download queued",
"alreadyInstalled": "Model already installed",
"alreadyDownloading": "Model is already downloading",
"resume": "Resume",
"raisePriority": "Raise priority",
"removeFromList": "Remove from list",
"addCredentials": "Add credentials",
"authErrorHint": "{host} needs an API key. Add one in the Credentials Manager, then resume.",
"authErrorHintNoHost": "This host/model needs an API key. Add one in the Credentials Manager, then resume.",
"gatedModelHint": "This model is gated. Accept its license on the model's page, then add an API key and resume.",
"openModelPage": "Accept license",
"actionFailed": "Action failed",
"cancelConfirmTitle": "Cancel download?",
"cancelConfirmMessage": "This will stop the download and delete the partial file for \"{name}\".",
"cancelConfirm": "Cancel download",
"status": {
"queued": "Queued",
"active": "Downloading",
"paused": "Paused",
"verifying": "Verifying",
"completed": "Completed",
"failed": "Failed",
"cancelled": "Cancelled"
},
"credentials": {
"title": "Credentials Manager",
"description": "Store API keys per host (e.g. HuggingFace, Civitai). Secrets are write-only and never shown again.",
"add": "Add credential",
"update": "Update credential",
"edit": "Edit",
"save": "Save",
"saved": "Credential saved",
"host": "Host",
"secret": "API key",
"secretPlaceholder": "Paste API key",
"authScheme": "Auth scheme",
"headerName": "Header name",
"queryParam": "Query parameter",
"label": "Label",
"disabled": "Disabled",
"matchSubdomains": "Match subdomains",
"matchSubdomainsWarning": "Not recommended: hubs redirect to sibling CDN hosts that must not receive your key.",
"deleteTitle": "Delete credential?",
"deleteMessage": "Delete the stored credential for \"{host}\"?",
"scheme": {
"bearer": "Bearer token",
"header": "Custom header",
"query": "Query parameter"
}
}
},
"menu": {
"hideMenu": "Hide Menu",
"showMenu": "Show Menu",

View File

@@ -25,6 +25,6 @@ function handleClose() {
}
function handleSubscribe() {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'upload_model_upgrade' })
}
</script>

View File

@@ -0,0 +1,85 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import { useAssetsApi } from './useAssetsApi'
const mockAssetsStore = vi.hoisted(() => ({
inputAssets: [] as AssetItem[],
historyAssets: [] as AssetItem[],
inputLoading: false,
historyLoading: false,
inputError: null as string | null,
historyError: null as string | null,
hasMoreHistory: false,
isLoadingMore: false,
updateInputs: vi.fn(),
updateHistory: vi.fn(),
loadMoreHistory: vi.fn()
}))
vi.mock('@/stores/assetsStore', () => ({
useAssetsStore: () => mockAssetsStore
}))
function createAsset(id: string): AssetItem {
return {
id,
name: `${id}.png`,
size: 1,
created_at: '2026-01-01T00:00:00Z',
tags: ['input']
}
}
describe('useAssetsApi', () => {
beforeEach(() => {
vi.clearAllMocks()
mockAssetsStore.inputAssets = [createAsset('input-1')]
mockAssetsStore.historyAssets = [createAsset('history-1')]
mockAssetsStore.inputLoading = true
mockAssetsStore.historyLoading = false
mockAssetsStore.inputError = 'input-error'
mockAssetsStore.historyError = 'history-error'
mockAssetsStore.hasMoreHistory = true
mockAssetsStore.isLoadingMore = true
})
it('uses input assets and refreshes inputs', async () => {
const api = useAssetsApi('input')
expect(api.media.value).toEqual([createAsset('input-1')])
expect(api.loading.value).toBe(true)
expect(api.error.value).toBe('input-error')
expect(api.hasMore.value).toBe(false)
expect(api.isLoadingMore.value).toBe(false)
await expect(api.fetchMediaList()).resolves.toEqual([
createAsset('input-1')
])
await expect(api.refresh()).resolves.toEqual([createAsset('input-1')])
await api.loadMore()
expect(mockAssetsStore.updateInputs).toHaveBeenCalledTimes(2)
expect(mockAssetsStore.updateHistory).not.toHaveBeenCalled()
expect(mockAssetsStore.loadMoreHistory).not.toHaveBeenCalled()
})
it('uses output history and loads more history', async () => {
const api = useAssetsApi('output')
expect(api.media.value).toEqual([createAsset('history-1')])
expect(api.loading.value).toBe(false)
expect(api.error.value).toBe('history-error')
expect(api.hasMore.value).toBe(true)
expect(api.isLoadingMore.value).toBe(true)
await expect(api.fetchMediaList()).resolves.toEqual([
createAsset('history-1')
])
await api.loadMore()
expect(mockAssetsStore.updateHistory).toHaveBeenCalledOnce()
expect(mockAssetsStore.updateInputs).not.toHaveBeenCalled()
expect(mockAssetsStore.loadMoreHistory).toHaveBeenCalledOnce()
})
})

View File

@@ -8,6 +8,7 @@ import { useI18n } from 'vue-i18n'
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
import { MediaAssetKey } from '@/platform/assets/schemas/mediaAssetSchema'
import { api } from '@/scripts/api'
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import type { AssetMeta } from '@/platform/assets/schemas/mediaAssetSchema'
import type * as outputAssetUtilModule from '../utils/outputAssetUtil'
@@ -18,6 +19,12 @@ const mockIsCloud = vi.hoisted(() => ({ value: false }))
// Track the filename passed to createAnnotatedPath
const capturedFilenames = vi.hoisted(() => ({ values: [] as string[] }))
const capturedAnnotatedPaths = vi.hoisted(() => ({
values: [] as Array<{
item: { filename: string; subfolder?: string; type?: string }
options: { rootFolder?: string }
}>
}))
const mockDownloadFile = vi.hoisted(() => vi.fn())
vi.mock('@/base/common/downloadUtil', () => ({
@@ -73,9 +80,10 @@ vi.mock('@/stores/modelToNodeStore', () => ({
useModelToNodeStore: () => ({})
}))
const mockCopyToClipboard = vi.hoisted(() => vi.fn())
vi.mock('@/composables/useCopyToClipboard', () => ({
useCopyToClipboard: () => ({
copyToClipboard: vi.fn()
copyToClipboard: mockCopyToClipboard
})
}))
@@ -93,45 +101,50 @@ vi.mock('@/platform/workflow/utils/workflowExtractionUtil', () => ({
extractWorkflowFromAsset: mockExtractWorkflowFromAsset
}))
const mockAddNodeOnGraph = vi.hoisted(() => vi.fn())
const mockGetCanvasCenter = vi.hoisted(() => vi.fn())
vi.mock('@/services/litegraphService', () => ({
useLitegraphService: () => ({
addNodeOnGraph: vi.fn().mockReturnValue(
fromAny<LGraphNode, unknown>({
widgets: [{ name: 'image', value: '', callback: vi.fn() }],
graph: { setDirtyCanvas: vi.fn() }
})
),
getCanvasCenter: vi.fn().mockReturnValue([100, 100])
addNodeOnGraph: mockAddNodeOnGraph,
getCanvasCenter: mockGetCanvasCenter
})
}))
const mockNodeDefsByName = vi.hoisted(() => ({
value: {
LoadImage: {
name: 'LoadImage',
display_name: 'Load Image'
}
} as Record<string, unknown>
}))
vi.mock('@/stores/nodeDefStore', () => ({
useNodeDefStore: () => ({
nodeDefsByName: {
LoadImage: {
name: 'LoadImage',
display_name: 'Load Image'
}
}
nodeDefsByName: mockNodeDefsByName.value
})
}))
vi.mock('@/utils/createAnnotatedPath', () => ({
createAnnotatedPath: vi.fn((item: { filename: string }) => {
capturedFilenames.values.push(item.filename)
return item.filename
})
createAnnotatedPath: vi.fn(
(
item: { filename: string; subfolder?: string; type?: string },
options: { rootFolder?: string }
) => {
capturedAnnotatedPaths.values.push({ item, options })
capturedFilenames.values.push(item.filename)
return item.filename
}
)
}))
const mockDetectNodeTypeFromFilename = vi.hoisted(() => vi.fn())
vi.mock('@/utils/loaderNodeUtil', () => ({
detectNodeTypeFromFilename: vi.fn().mockReturnValue({
nodeType: 'LoadImage',
widgetName: 'image'
})
detectNodeTypeFromFilename: mockDetectNodeTypeFromFilename
}))
const mockIsResultItemType = vi.hoisted(() => vi.fn())
vi.mock('@/utils/typeGuardUtil', () => ({
isResultItemType: vi.fn().mockReturnValue(true)
isResultItemType: mockIsResultItemType
}))
const mockGetAssetType = vi.hoisted(() => vi.fn())
@@ -186,7 +199,9 @@ vi.mock('@/scripts/api', () => ({
}
}))
const mockAppGraph = vi.hoisted(() => ({ value: { _nodes: [] as unknown[] } }))
const mockAppGraph = vi.hoisted(() => ({
value: { _nodes: [] as unknown[] } as { _nodes: unknown[] } | null
}))
vi.mock('@/scripts/app', () => ({
app: {
get graph() {
@@ -291,7 +306,43 @@ describe('useMediaAssetActions', () => {
setActivePinia(createTestingPinia({ stubActions: false }))
vi.clearAllMocks()
capturedFilenames.values = []
capturedAnnotatedPaths.values = []
mockIsCloud.value = false
mockAppGraph.value = { _nodes: [] }
mockDownloadFile.mockReset()
mockCopyToClipboard.mockReset()
mockShowDialog.mockReset()
mockAddNodeOnGraph.mockReset()
mockAddNodeOnGraph.mockReturnValue(
fromAny<LGraphNode, unknown>({
widgets: [{ name: 'image', value: '', callback: vi.fn() }],
graph: { setDirtyCanvas: vi.fn() }
})
)
mockGetCanvasCenter.mockReset()
mockGetCanvasCenter.mockReturnValue([100, 100])
mockNodeDefsByName.value = {
LoadImage: {
name: 'LoadImage',
display_name: 'Load Image'
}
}
mockDetectNodeTypeFromFilename.mockReset()
mockDetectNodeTypeFromFilename.mockReturnValue({
nodeType: 'LoadImage',
widgetName: 'image'
})
mockIsResultItemType.mockReset()
mockIsResultItemType.mockReturnValue(true)
mockExtractWorkflowFromAsset.mockReset()
mockOpenWorkflowAction.mockReset()
mockExportWorkflowAction.mockReset()
mockCreateAssetExport.mockReset()
mockCreateAssetExport.mockResolvedValue({
task_id: 'test-task-id',
status: 'pending'
})
mockDeleteAsset.mockReset()
mockGetOutputAssetMetadata.mockReset()
mockGetOutputAssetMetadata.mockReturnValue(null)
mockGetAssetType.mockReset()
@@ -299,7 +350,139 @@ describe('useMediaAssetActions', () => {
mockResolveOutputAssetItems.mockResolvedValue([])
})
describe('copyJobId', () => {
it('does nothing when no asset is available', async () => {
const { actions, unmount } = mountMediaActions()
await actions.copyJobId()
expect(mockCopyToClipboard).not.toHaveBeenCalled()
expect(useToast().add).not.toHaveBeenCalled()
unmount()
})
it('warns when the asset has no job id', async () => {
mockGetAssetType.mockReturnValue('input')
const actions = useMediaAssetActions()
await actions.copyJobId(createMockAsset())
expect(mockCopyToClipboard).not.toHaveBeenCalled()
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'warn' })
)
})
it('copies the metadata job id when present', async () => {
mockGetOutputAssetMetadata.mockReturnValue({ jobId: 'job-from-meta' })
const actions = useMediaAssetActions()
await actions.copyJobId(createMockAsset())
expect(mockCopyToClipboard).toHaveBeenCalledWith('job-from-meta')
})
it('copies the output asset id when metadata omits the job id', async () => {
mockGetAssetType.mockReturnValue('output')
const actions = useMediaAssetActions()
await actions.copyJobId(createMockAsset({ id: 'history-id' }))
expect(mockCopyToClipboard).toHaveBeenCalledWith('history-id')
})
})
describe('addWorkflow', () => {
it('does nothing when no asset is available', async () => {
const { actions, unmount } = mountMediaActions()
await actions.addWorkflow()
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
expect(useToast().add).not.toHaveBeenCalled()
unmount()
})
it('uses the injected media asset when no explicit asset is provided', async () => {
const mediaAsset = createMockMediaAsset({ name: 'context-image.png' })
const { actions, unmount } = mountMediaActions(mediaAsset)
await actions.addWorkflow()
expect(capturedFilenames.values).toContain('context-image.png')
unmount()
})
it('warns when the filename has no compatible loader node', async () => {
mockDetectNodeTypeFromFilename.mockReturnValue({
nodeType: undefined,
widgetName: undefined
})
const actions = useMediaAssetActions()
await actions.addWorkflow(createMockAsset({ name: 'notes.txt' }))
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'warn' })
)
})
it('reports missing node definitions', async () => {
mockNodeDefsByName.value = {}
const actions = useMediaAssetActions()
await actions.addWorkflow(createMockAsset())
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'error' })
)
})
it('reports loader-node creation failure', async () => {
mockAddNodeOnGraph.mockReturnValue(undefined)
const actions = useMediaAssetActions()
await actions.addWorkflow(createMockAsset())
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'error' })
)
})
it('still adds the node when the expected widget is absent', async () => {
const setDirtyCanvas = vi.fn()
mockAddNodeOnGraph.mockReturnValue(
fromAny<LGraphNode, unknown>({
widgets: [{ name: 'other', value: '' }],
graph: { setDirtyCanvas }
})
)
mockGetOutputAssetMetadata.mockReturnValue({ subfolder: 'nested' })
mockGetAssetType.mockReturnValue('custom')
mockIsResultItemType.mockReturnValue(false)
const actions = useMediaAssetActions()
await actions.addWorkflow(createMockAsset({ name: 'asset.png' }))
expect(capturedAnnotatedPaths.values.at(-1)).toEqual({
item: {
filename: 'asset.png',
subfolder: 'nested',
type: undefined
},
options: { rootFolder: 'input' }
})
expect(setDirtyCanvas).toHaveBeenCalledWith(true, true)
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'success' })
)
})
describe('OSS mode (isCloud = false)', () => {
beforeEach(() => {
mockIsCloud.value = false
@@ -366,6 +549,83 @@ describe('useMediaAssetActions', () => {
})
describe('addMultipleToWorkflow', () => {
it('does nothing for an empty selection', async () => {
const actions = useMediaAssetActions()
await actions.addMultipleToWorkflow([])
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
expect(useToast().add).not.toHaveBeenCalled()
})
it('shows a failure toast when none of the selected assets can be added', async () => {
mockDetectNodeTypeFromFilename
.mockReturnValueOnce({ nodeType: undefined, widgetName: undefined })
.mockReturnValueOnce({ nodeType: 'MissingNode', widgetName: 'image' })
const actions = useMediaAssetActions()
await actions.addMultipleToWorkflow([
createMockAsset({ id: 'a', name: 'unsupported.txt' }),
createMockAsset({ id: 'b', name: 'missing.png' })
])
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'error' })
)
})
it('shows a partial warning when only some nodes are added', async () => {
mockAddNodeOnGraph
.mockReturnValueOnce(
fromAny<LGraphNode, unknown>({
widgets: [{ name: 'image', value: '', callback: vi.fn() }],
graph: { setDirtyCanvas: vi.fn() }
})
)
.mockReturnValueOnce(undefined)
const actions = useMediaAssetActions()
await actions.addMultipleToWorkflow([
createMockAsset({ id: 'a', name: 'a.png' }),
createMockAsset({ id: 'b', name: 'b.png' })
])
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'warn' })
)
})
it('adds assets without a matching widget using untyped paths', async () => {
const setDirtyCanvas = vi.fn()
mockAddNodeOnGraph.mockReturnValue(
fromAny<LGraphNode, unknown>({
widgets: [{ name: 'other', value: '' }],
graph: { setDirtyCanvas }
})
)
mockGetAssetType.mockReturnValue('custom')
mockIsResultItemType.mockReturnValue(false)
const actions = useMediaAssetActions()
await actions.addMultipleToWorkflow([
createMockAsset({ id: 'asset-1', name: 'asset-1.png' })
])
expect(capturedAnnotatedPaths.values.at(-1)).toEqual({
item: {
filename: 'asset-1.png',
subfolder: '',
type: undefined
},
options: { rootFolder: undefined }
})
expect(setDirtyCanvas).toHaveBeenCalledWith(true, true)
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'success' })
)
})
describe('Cloud mode (isCloud = true)', () => {
beforeEach(() => {
mockIsCloud.value = true
@@ -397,10 +657,56 @@ describe('useMediaAssetActions', () => {
})
})
describe('openWorkflow', () => {
beforeEach(() => {
mockExtractWorkflowFromAsset.mockResolvedValue({
workflow: { version: 0.4 },
filename: 'workflow.json'
})
})
it('does nothing when no asset is available', async () => {
const { actions, unmount } = mountMediaActions()
await actions.openWorkflow()
expect(mockExtractWorkflowFromAsset).not.toHaveBeenCalled()
expect(mockOpenWorkflowAction).not.toHaveBeenCalled()
unmount()
})
it('shows a success toast after opening the workflow', async () => {
mockOpenWorkflowAction.mockResolvedValue({ success: true })
const actions = useMediaAssetActions()
await actions.openWorkflow(createMockAsset())
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'success' })
)
})
it('uses the fallback warning when opening returns no error message', async () => {
mockOpenWorkflowAction.mockResolvedValue({ success: false })
const actions = useMediaAssetActions()
await actions.openWorkflow(createMockAsset())
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({
severity: 'warn',
detail: 'mediaAsset.noWorkflowDataFound'
})
)
})
})
describe('exportWorkflow', () => {
const successResult = { success: true } as const
const cancelledResult = { success: false, cancelled: true } as const
const failureResult = { success: false, error: 'boom' } as const
const failureWithoutError = { success: false } as const
const noWorkflowResult = {
success: false,
error: 'No workflow data available'
@@ -455,6 +761,31 @@ describe('useMediaAssetActions', () => {
)
})
it('does nothing when no asset is available', async () => {
const { actions, unmount } = mountMediaActions()
await actions.exportWorkflow()
expect(mockExtractWorkflowFromAsset).not.toHaveBeenCalled()
expect(mockExportWorkflowAction).not.toHaveBeenCalled()
unmount()
})
it('uses the fallback error when export fails without a message', async () => {
mockExportWorkflowAction.mockResolvedValue(failureWithoutError)
const actions = useMediaAssetActions()
await actions.exportWorkflow(createMockAsset())
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({
severity: 'error',
detail: 'mediaAsset.failedToExportWorkflow'
})
)
})
it('shows no toast when every asset in a bulk export is cancelled', async () => {
mockExportWorkflowAction.mockResolvedValue(cancelledResult)
const actions = useMediaAssetActions()
@@ -500,6 +831,118 @@ describe('useMediaAssetActions', () => {
})
})
describe('openMultipleWorkflows', () => {
beforeEach(() => {
mockExtractWorkflowFromAsset.mockResolvedValue({
workflow: { version: 0.4 },
filename: 'workflow.json'
})
})
it('does nothing for an empty selection', async () => {
const actions = useMediaAssetActions()
await actions.openMultipleWorkflows([])
expect(mockOpenWorkflowAction).not.toHaveBeenCalled()
expect(useToast().add).not.toHaveBeenCalled()
})
it('shows success when every workflow opens', async () => {
mockOpenWorkflowAction.mockResolvedValue({ success: true })
const actions = useMediaAssetActions()
await actions.openMultipleWorkflows([
createMockAsset({ id: 'a' }),
createMockAsset({ id: 'b' })
])
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'success' })
)
})
it('shows a missing-workflow warning when none open', async () => {
mockOpenWorkflowAction.mockResolvedValue({ success: false })
const actions = useMediaAssetActions()
await actions.openMultipleWorkflows([
createMockAsset({ id: 'a' }),
createMockAsset({ id: 'b' })
])
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'warn' })
)
})
it('shows a partial warning when extraction throws for one asset', async () => {
mockExtractWorkflowFromAsset
.mockResolvedValueOnce({
workflow: { version: 0.4 },
filename: 'ok.json'
})
.mockRejectedValueOnce(new Error('missing workflow'))
mockOpenWorkflowAction.mockResolvedValue({ success: true })
const actions = useMediaAssetActions()
await actions.openMultipleWorkflows([
createMockAsset({ id: 'a' }),
createMockAsset({ id: 'b' })
])
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'warn' })
)
})
})
describe('exportMultipleWorkflows', () => {
beforeEach(() => {
mockExtractWorkflowFromAsset.mockResolvedValue({
workflow: { version: 0.4 },
filename: 'workflow.json'
})
})
it('does nothing for an empty selection', async () => {
const actions = useMediaAssetActions()
await actions.exportMultipleWorkflows([])
expect(mockExportWorkflowAction).not.toHaveBeenCalled()
expect(useToast().add).not.toHaveBeenCalled()
})
it('shows no-workflows warning when every export fails', async () => {
mockExportWorkflowAction.mockResolvedValue({
success: false,
error: 'boom'
})
const actions = useMediaAssetActions()
await actions.exportMultipleWorkflows([
createMockAsset({ id: 'a' }),
createMockAsset({ id: 'b' })
])
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'warn' })
)
})
it('counts extraction failures as failed exports', async () => {
mockExtractWorkflowFromAsset.mockRejectedValue(new Error('missing'))
const actions = useMediaAssetActions()
await actions.exportMultipleWorkflows([createMockAsset()])
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'warn' })
)
})
})
describe('downloadAssets', () => {
it('downloads the injected media asset when called without explicit assets', () => {
const mediaAsset = createMockMediaAsset({
@@ -534,6 +977,36 @@ describe('useMediaAssetActions', () => {
unmount()
})
it('uses the asset URL when no preview URL is available', () => {
mockGetAssetType.mockReturnValue('input')
const asset = createMockAsset({
name: 'raw image.png',
preview_url: undefined,
user_metadata: { subfolder: 'uploads' }
})
const actions = useMediaAssetActions()
actions.downloadAssets([asset])
expect(mockDownloadFile).toHaveBeenCalledWith(
'http://localhost:8188/api/view?filename=raw+image.png&type=input&subfolder=uploads',
'raw image.png'
)
})
it('shows an error toast when a direct download throws', () => {
mockDownloadFile.mockImplementation(() => {
throw new Error('download failed')
})
const actions = useMediaAssetActions()
actions.downloadAssets([createMockAsset()])
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'error' })
)
})
it('keeps single explicit assets on the direct download path in cloud', () => {
mockIsCloud.value = true
mockGetOutputAssetMetadata.mockReturnValue({
@@ -943,6 +1416,82 @@ describe('useMediaAssetActions', () => {
})
expect(payload.naming_strategy).toBe('preserve')
})
it('should include asset ids for imported assets', async () => {
mockGetAssetType.mockImplementation((asset: AssetItem) =>
asset.tags?.includes('output') ? 'output' : 'input'
)
const asset1 = createMockAsset({ id: 'input-1', tags: ['input'] })
const asset2 = createMockAsset({ id: 'input-2', tags: ['input'] })
const actions = useMediaAssetActions()
actions.downloadAssets([asset1, asset2])
await vi.waitFor(() => {
expect(mockCreateAssetExport).toHaveBeenCalledTimes(1)
})
const payload = mockCreateAssetExport.mock.calls[0][0]
expect(payload.job_ids).toBeUndefined()
expect(payload.asset_ids).toEqual(['input-1', 'input-2'])
expect(payload.naming_strategy).toBe('preserve')
})
it('should mix output job ids and imported asset ids', async () => {
mockGetAssetType.mockImplementation((asset: AssetItem) =>
asset.tags?.includes('output') ? 'output' : 'input'
)
const output = createMockAsset({
id: 'history-id',
name: 'output.png',
tags: ['output']
})
const imported = createMockAsset({ id: 'input-id', tags: ['input'] })
const actions = useMediaAssetActions()
actions.downloadAssets([output, imported])
await vi.waitFor(() => {
expect(mockCreateAssetExport).toHaveBeenCalledTimes(1)
})
const payload = mockCreateAssetExport.mock.calls[0][0]
expect(payload.job_ids).toEqual(['history-id'])
expect(payload.asset_ids).toEqual(['input-id'])
})
it('should only include a filtered output name once', async () => {
const asset1 = createOutputAsset('a1', 'same.png', 'job1')
const asset2 = createOutputAsset('a2', 'same.png', 'job1')
const actions = useMediaAssetActions()
actions.downloadAssets([asset1, asset2])
await vi.waitFor(() => {
expect(mockCreateAssetExport).toHaveBeenCalledTimes(1)
})
const payload = mockCreateAssetExport.mock.calls[0][0]
expect(payload.job_asset_name_filters).toEqual({
job1: ['same.png']
})
})
it('should show an error toast when ZIP export creation fails', async () => {
mockCreateAssetExport.mockRejectedValueOnce(new Error('export failed'))
const asset1 = createOutputAsset('a1', 'img1.png', 'job1')
const asset2 = createOutputAsset('a2', 'img2.png', 'job2')
const actions = useMediaAssetActions()
actions.downloadAssets([asset1, asset2])
await vi.waitFor(() => {
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'error' })
)
})
expect(mockTrackExport).not.toHaveBeenCalled()
})
})
describe('downloadAssets - export toast file count', () => {
@@ -1033,6 +1582,200 @@ describe('useMediaAssetActions', () => {
})
})
describe('deleteAssets', () => {
it('returns false for an empty selection', async () => {
const actions = useMediaAssetActions()
const result = await actions.deleteAssets([])
expect(result).toBe(false)
expect(mockShowDialog).not.toHaveBeenCalled()
})
it('returns false when the user cancels', async () => {
mockShowDialog.mockImplementation(
({ props }: { props: { onCancel: () => void } }) => {
props.onCancel()
}
)
const actions = useMediaAssetActions()
const result = await actions.deleteAssets(createMockAsset())
expect(result).toBe(false)
expect(mockDeleteAsset).not.toHaveBeenCalled()
})
it('rejects imported asset deletion outside cloud mode', async () => {
mockIsCloud.value = false
mockGetAssetType.mockReturnValue('input')
mockShowDialog.mockImplementation(
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
void props.onConfirm()
}
)
const actions = useMediaAssetActions()
await actions.deleteAssets(createMockAsset({ tags: ['input'] }))
await vi.waitFor(() => {
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'error' })
)
})
expect(mockDeleteAsset).not.toHaveBeenCalled()
})
it('rejects output deletion when no job id can be resolved', async () => {
mockIsCloud.value = true
mockGetAssetType.mockReturnValue('output')
mockShowDialog.mockImplementation(
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
void props.onConfirm()
}
)
const actions = useMediaAssetActions()
await actions.deleteAssets(
createMockAsset({ id: '', name: 'orphan.png', tags: ['output'] })
)
await vi.waitFor(() => {
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'error' })
)
})
expect(api.deleteItem).not.toHaveBeenCalled()
})
it('updates output history and input listings for mixed successful deletion', async () => {
mockIsCloud.value = true
mockGetAssetType.mockImplementation((asset: AssetItem) =>
asset.tags?.includes('output') ? 'output' : 'input'
)
mockShowDialog.mockImplementation(
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
void props.onConfirm()
}
)
const actions = useMediaAssetActions()
await actions.deleteAssets([
createMockAsset({ id: 'history-1', tags: ['output'] }),
createMockAsset({ id: 'input-1', tags: ['input'] })
])
await vi.waitFor(() => {
expect(mockUpdateHistory).toHaveBeenCalled()
})
expect(mockUpdateInputs).toHaveBeenCalled()
expect(api.deleteItem).toHaveBeenCalledWith('history', 'history-1')
expect(mockDeleteAsset).toHaveBeenCalledWith('input-1')
})
it('skips graph cleanup when there is no root graph', async () => {
mockIsCloud.value = true
mockGetAssetType.mockReturnValue('input')
mockAppGraph.value = null
mockShowDialog.mockImplementation(
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
void props.onConfirm()
}
)
const actions = useMediaAssetActions()
await actions.deleteAssets(createMockAsset({ tags: ['input'] }))
await vi.waitFor(() => {
expect(mockDeleteAsset).toHaveBeenCalled()
})
expect(mockClearNodePreviewCache).not.toHaveBeenCalled()
expect(mockClearWidgetValues).not.toHaveBeenCalled()
expect(mockCaptureCanvasState).not.toHaveBeenCalled()
})
it('uses temp widget-value variants when deleting temp assets', async () => {
mockIsCloud.value = true
mockGetAssetType.mockReturnValue('temp')
mockShowDialog.mockImplementation(
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
void props.onConfirm()
}
)
const actions = useMediaAssetActions()
await actions.deleteAssets(
createMockAsset({
id: 'temp-1',
name: 'preview.png',
hash: 'preview-hash.png',
tags: ['temp']
})
)
await vi.waitFor(() => {
expect(mockClearNodePreviewCache).toHaveBeenCalled()
})
const [, valuesArg] = mockClearNodePreviewCache.mock.calls[0]
expect(valuesArg).toEqual(
new Set(['preview.png [temp]', 'preview-hash.png'])
)
})
it('uses hash-only cleanup values when the asset name is empty', async () => {
mockIsCloud.value = true
mockGetAssetType.mockReturnValue('input')
mockShowDialog.mockImplementation(
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
void props.onConfirm()
}
)
const actions = useMediaAssetActions()
await actions.deleteAssets(
createMockAsset({
id: 'hash-only',
name: '',
hash: 'only-hash.png',
tags: ['input']
})
)
await vi.waitFor(() => {
expect(mockClearNodePreviewCache).toHaveBeenCalled()
})
const [, valuesArg] = mockClearNodePreviewCache.mock.calls[0]
expect(valuesArg).toEqual(new Set(['only-hash.png']))
})
it('shows a partial warning and cleans up only successfully deleted assets', async () => {
mockIsCloud.value = true
mockGetAssetType.mockReturnValue('input')
mockDeleteAsset
.mockResolvedValueOnce(undefined)
.mockRejectedValueOnce(new Error('delete failed'))
mockShowDialog.mockImplementation(
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
void props.onConfirm()
}
)
const actions = useMediaAssetActions()
await actions.deleteAssets([
createMockAsset({ id: 'ok', name: 'ok.png', tags: ['input'] }),
createMockAsset({ id: 'bad', name: 'bad.png', tags: ['input'] })
])
await vi.waitFor(() => {
expect(useToast().add).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'warn' })
)
})
const [, valuesArg] = mockClearNodePreviewCache.mock.calls[0]
expect(valuesArg).toEqual(new Set(['ok.png', 'ok.png [input]']))
})
})
describe('deleteAssets - model cache invalidation', () => {
beforeEach(() => {
mockIsCloud.value = true

View File

@@ -1,4 +1,5 @@
import { createTestingPinia } from '@pinia/testing'
import { fromPartial } from '@total-typescript/shoehorn'
import { setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { createApp, nextTick, ref } from 'vue'
@@ -14,6 +15,7 @@ vi.mock('@/platform/assets/services/assetService', () => ({
assetService: {
getAssetMetadata: vi.fn(),
uploadAssetAsync: vi.fn(),
uploadAssetFromBase64: vi.fn(),
uploadAssetPreviewImage: vi.fn()
}
}))
@@ -248,6 +250,81 @@ describe('useUploadModelWizard', () => {
expect(wizard.selectedModelType.value).toBe('checkpoints')
})
it('does not fetch metadata until the URL matches a supported source', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
const wizard = setupUploadModelWizard(modelTypes)
expect(wizard.canFetchMetadata.value).toBe(false)
await wizard.fetchMetadata()
expect(assetService.getAssetMetadata).not.toHaveBeenCalled()
expect(wizard.currentStep.value).toBe(1)
})
it('decodes metadata filenames and selects a matching model type tag', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
vi.mocked(assetService.getAssetMetadata).mockResolvedValue({
content_length: 100,
final_url: 'https://huggingface.co/org/model',
filename: '%E6%A8%A1%E5%9E%8B.safetensors',
name: '%E5%90%8D%E7%A8%B1',
tags: ['checkpoints'],
preview_image: 'data:image/png;base64,abc'
})
const wizard = setupUploadModelWizard(modelTypes)
wizard.wizardData.value.url = ' https://huggingface.co/org/model '
await wizard.fetchMetadata()
expect(wizard.currentStep.value).toBe(2)
expect(wizard.wizardData.value.url).toBe('https://huggingface.co/org/model')
expect(wizard.wizardData.value.name).toBe('模型.safetensors')
expect(wizard.wizardData.value.previewImage).toBe(
'data:image/png;base64,abc'
)
expect(wizard.selectedModelType.value).toBe('checkpoints')
})
it('keeps metadata text when percent decoding fails', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
vi.mocked(assetService.getAssetMetadata).mockResolvedValue({
content_length: 100,
final_url: 'https://civitai.com/models/12345',
filename: '%E0%A4%A',
name: '%E0%A4%A',
tags: []
})
const wizard = setupUploadModelWizard(modelTypes)
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
await wizard.fetchMetadata()
expect(wizard.currentStep.value).toBe(2)
expect(wizard.wizardData.value.name).toBe('%E0%A4%A')
expect(wizard.selectedModelType.value).toBeUndefined()
})
it('uses the fallback metadata error for non-error rejections', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
vi.mocked(assetService.getAssetMetadata).mockRejectedValue('no metadata')
const wizard = setupUploadModelWizard(modelTypes)
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
await wizard.fetchMetadata()
expect(wizard.currentStep.value).toBe(1)
expect(wizard.uploadError.value).toBe(
'Failed to retrieve metadata. Please check the link and try again.'
)
})
it('uploads with the required model type even if selection changes', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
@@ -279,6 +356,382 @@ describe('useUploadModelWizard', () => {
expect(result?.modelType).toBe('checkpoints')
})
it('clears upload errors and type mismatches when the URL changes', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
type: 'sync',
asset: {
id: 'asset-lora',
name: 'model.safetensors',
tags: ['models', 'loras']
}
})
const wizard = setupUploadModelWizard(
ref([
{ name: 'Checkpoint', value: 'checkpoints' },
{ name: 'LoRA', value: 'loras' }
]),
{ requiredModelType: 'checkpoints' }
)
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
await wizard.uploadModel()
expect(wizard.uploadTypeMismatch.value).not.toBeNull()
wizard.wizardData.value.url = 'https://civitai.com/models/54321'
await nextTick()
expect(wizard.uploadError.value).toBe('')
expect(wizard.uploadTypeMismatch.value).toBeNull()
})
it('returns null while another upload is in progress', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
type UploadResult = Awaited<
ReturnType<typeof assetService.uploadAssetAsync>
>
let resolveUpload!: (value: UploadResult) => void
vi.mocked(assetService.uploadAssetAsync).mockReturnValue(
new Promise<UploadResult>((resolve) => {
resolveUpload = resolve
})
)
const wizard = setupUploadModelWizard(modelTypes)
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
wizard.selectedModelType.value = 'checkpoints'
const firstUpload = wizard.uploadModel()
await nextTick()
await expect(wizard.uploadModel()).resolves.toBeNull()
resolveUpload({
type: 'sync',
asset: {
id: 'asset-1',
name: 'model.safetensors',
tags: ['models', 'checkpoints']
}
})
await expect(firstUpload).resolves.toEqual(
expect.objectContaining({ status: 'success' })
)
})
it('returns null when no model type is selected', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
const wizard = setupUploadModelWizard(modelTypes)
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
const result = await wizard.uploadModel()
expect(result).toBeNull()
expect(assetService.uploadAssetAsync).not.toHaveBeenCalled()
})
it('reports an upload error when no valid source is detected', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
const wizard = setupUploadModelWizard(modelTypes)
wizard.wizardData.value.url = 'https://example.com/model'
wizard.selectedModelType.value = 'checkpoints'
const result = await wizard.uploadModel()
expect(result).toBeNull()
expect(assetService.uploadAssetAsync).not.toHaveBeenCalled()
})
it('uploads preview images and passes the preview id to the model upload', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
vi.mocked(assetService.uploadAssetFromBase64).mockResolvedValue(
fromPartial({ id: 'preview-1' })
)
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
type: 'sync',
asset: {
id: 'asset-1',
name: 'model.safetensors',
tags: ['models', 'checkpoints']
}
})
const wizard = setupUploadModelWizard(modelTypes)
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
wizard.wizardData.value.metadata = {
content_length: 100,
final_url: 'https://civitai.com/models/12345',
filename: 'model.safetensors'
}
wizard.wizardData.value.previewImage = 'data:image/jpeg;base64,abc'
wizard.selectedModelType.value = 'checkpoints'
await wizard.uploadModel()
expect(assetService.uploadAssetFromBase64).toHaveBeenCalledWith({
data: 'data:image/jpeg;base64,abc',
name: 'model_preview.jpg',
tags: ['preview']
})
expect(assetService.uploadAssetAsync).toHaveBeenCalledWith(
expect.objectContaining({ preview_id: 'preview-1' })
)
})
it('continues model upload when preview upload fails', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
vi.mocked(assetService.uploadAssetFromBase64).mockRejectedValue(
new Error('preview failed')
)
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
type: 'sync',
asset: {
id: 'asset-1',
name: 'model.safetensors',
tags: ['models', 'checkpoints']
}
})
const wizard = setupUploadModelWizard(modelTypes)
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
wizard.wizardData.value.metadata = {
content_length: 100,
final_url: 'https://civitai.com/models/12345',
name: 'model'
}
wizard.wizardData.value.previewImage = 'data:image/webp;base64,abc'
wizard.selectedModelType.value = 'checkpoints'
await wizard.uploadModel()
expect(assetService.uploadAssetAsync).toHaveBeenCalledWith(
expect.objectContaining({ preview_id: undefined })
)
expect(wizard.uploadStatus.value).toBe('success')
})
it('treats an already completed async upload as success', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
type: 'async',
task: {
task_id: 'task-complete',
status: 'completed',
message: 'Download complete'
}
})
const wizard = setupUploadModelWizard(modelTypes)
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
wizard.wizardData.value.metadata = {
content_length: 100,
final_url: 'https://civitai.com/models/12345',
filename: 'queued.safetensors'
}
wizard.selectedModelType.value = 'checkpoints'
const result = await wizard.uploadModel()
expect(result).toEqual({
filename: 'queued.safetensors',
modelType: 'checkpoints',
status: 'success'
})
expect(wizard.uploadStatus.value).toBe('success')
})
it('cleans up an immediately resolved async watcher', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
const { useAssetDownloadStore } =
await import('@/stores/assetDownloadStore')
const assetDownloadStore = useAssetDownloadStore()
assetDownloadStore.trackDownload(
'task-ready',
'checkpoints',
'ready.safetensors'
)
const { api } = await import('@/scripts/api')
const handler = vi
.mocked(api.addEventListener)
.mock.calls.find((c) => c[0] === 'asset_download')?.[1] as
| ((e: CustomEvent) => void)
| undefined
expect(handler).toBeDefined()
handler!(
new CustomEvent('asset_download', {
detail: {
task_id: 'task-ready',
asset_id: 'asset-ready',
asset_name: 'ready.safetensors',
bytes_total: 100,
bytes_downloaded: 100,
progress: 100,
status: 'completed' as const
}
})
)
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
type: 'async',
task: {
task_id: 'task-ready',
status: 'created',
message: 'Download queued'
}
})
const wizard = setupUploadModelWizard(modelTypes)
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
wizard.selectedModelType.value = 'checkpoints'
await wizard.uploadModel()
await nextTick()
expect(wizard.uploadStatus.value).toBe('success')
})
it('uses the default failed-download message when no error is available', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
type: 'async',
task: {
task_id: 'task-fallback-fail',
status: 'created',
message: 'Download queued'
}
})
const wizard = setupUploadModelWizard(modelTypes)
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
wizard.selectedModelType.value = 'checkpoints'
await wizard.uploadModel()
const { api } = await import('@/scripts/api')
const handler = vi
.mocked(api.addEventListener)
.mock.calls.find((c) => c[0] === 'asset_download')?.[1] as
| ((e: CustomEvent) => void)
| undefined
expect(handler).toBeDefined()
handler!(
new CustomEvent('asset_download', {
detail: {
task_id: 'task-fallback-fail',
asset_id: '',
asset_name: '',
bytes_total: 1000,
bytes_downloaded: 500,
progress: 50,
status: 'failed' as const
}
})
)
await nextTick()
expect(wizard.uploadStatus.value).toBe('error')
expect(wizard.uploadError.value).toBe('assetBrowser.downloadFailed')
})
it('uses fallback labels for unknown mismatch types', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
type: 'sync',
asset: {
id: 'asset-unknown',
name: 'model.safetensors',
tags: ['models']
}
})
const wizard = setupUploadModelWizard(modelTypes, {
requiredModelType: 'unknown-required'
})
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
const result = await wizard.uploadModel()
expect(result).toBeNull()
expect(wizard.uploadTypeMismatch.value).toEqual({
importedModelType: undefined,
importedModelTypeLabel: undefined,
requiredModelType: 'unknown-required',
requiredModelTypeLabel: 'unknown-required'
})
})
it('uses a generic upload error for non-error upload failures', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
vi.mocked(assetService.uploadAssetAsync).mockRejectedValue('failed')
const wizard = setupUploadModelWizard(modelTypes)
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
wizard.selectedModelType.value = 'checkpoints'
const result = await wizard.uploadModel()
expect(result).toBeNull()
expect(wizard.uploadStatus.value).toBe('error')
expect(wizard.uploadError.value).toBe('Failed to upload model')
})
it('navigates backward only after the first step', () => {
const wizard = setupUploadModelWizard(modelTypes)
wizard.goToPreviousStep()
expect(wizard.currentStep.value).toBe(1)
wizard.currentStep.value = 3
wizard.goToPreviousStep()
expect(wizard.currentStep.value).toBe(2)
})
it('resets wizard state and cancels pending async status watching', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')
vi.mocked(assetService.uploadAssetAsync).mockResolvedValue({
type: 'async',
task: {
task_id: 'task-reset',
status: 'created',
message: 'Download queued'
}
})
const wizard = setupUploadModelWizard(modelTypes)
wizard.wizardData.value.url = 'https://civitai.com/models/12345'
wizard.wizardData.value.name = 'Model'
wizard.wizardData.value.tags = ['checkpoints']
wizard.selectedModelType.value = 'checkpoints'
await wizard.uploadModel()
wizard.resetWizard()
expect(wizard.currentStep.value).toBe(1)
expect(wizard.uploadStatus.value).toBeUndefined()
expect(wizard.uploadError.value).toBe('')
expect(wizard.wizardData.value).toEqual({
url: '',
name: '',
tags: []
})
expect(wizard.selectedModelType.value).toBeUndefined()
})
it('returns the synced asset filename for sync imports', async () => {
const { assetService } =
await import('@/platform/assets/services/assetService')

View File

@@ -12,6 +12,7 @@ import { api } from '@/scripts/api'
const mockDistributionState = vi.hoisted(() => ({ isCloud: false }))
const mockSettingStoreGet = vi.hoisted(() => vi.fn(() => false))
const mockGetCategoryForNodeType = vi.hoisted(() => vi.fn())
vi.mock('@/platform/distribution/types', () => ({
get isCloud() {
@@ -33,7 +34,7 @@ vi.mock('@/stores/modelToNodeStore', () => {
return {
useModelToNodeStore: vi.fn(() => ({
getRegisteredNodeTypes: () => registeredNodeTypes,
getCategoryForNodeType: vi.fn()
getCategoryForNodeType: mockGetCategoryForNodeType
}))
}
})
@@ -172,6 +173,28 @@ describe(assetService.getAssetMetadata, () => {
).rejects.toThrow('File too large')
})
it('falls back to the unknown localized message for unrecognized error codes', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse({ code: 'NOT_A_REAL_CODE' }, { ok: false, status: 400 })
)
await expect(
assetService.getAssetMetadata('https://example.com/model.safetensors')
).rejects.toThrow('Unknown error')
})
it('falls back to unknown when error JSON cannot be parsed', async () => {
fetchApiMock.mockResolvedValueOnce({
ok: false,
status: 400,
json: vi.fn().mockRejectedValue(new Error('bad json'))
} as unknown as Response)
await expect(
assetService.getAssetMetadata('https://example.com/model.safetensors')
).rejects.toThrow('Unknown error')
})
it('throws a localized message when validation reports is_valid=false', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse({
@@ -189,6 +212,20 @@ describe(assetService.getAssetMetadata, () => {
).rejects.toThrow('Unsafe virus scan')
})
it('falls back to unknown when validation errors are absent', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse({
content_length: 100,
final_url: 'https://example.com/model.safetensors',
validation: { is_valid: false }
})
)
await expect(
assetService.getAssetMetadata('https://example.com/model.safetensors')
).rejects.toThrow('Unknown error')
})
it('encodes the URL in the query string', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse({
@@ -208,12 +245,115 @@ describe(assetService.getAssetMetadata, () => {
})
})
describe(assetService.getAssetsForNodeType, () => {
beforeEach(() => {
vi.clearAllMocks()
mockGetCategoryForNodeType.mockReset()
})
it('returns an empty list for invalid node types without fetching', async () => {
await expect(assetService.getAssetsForNodeType('')).resolves.toEqual([])
expect(fetchApiMock).not.toHaveBeenCalled()
})
it('returns an empty list when the node type has no asset category', async () => {
mockGetCategoryForNodeType.mockReturnValue(undefined)
await expect(
assetService.getAssetsForNodeType('UnknownNode')
).resolves.toEqual([])
expect(fetchApiMock).not.toHaveBeenCalled()
})
it('fetches category assets with default pagination', async () => {
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
const assets = [
validAsset({ id: 'ckpt-1', tags: ['models', 'checkpoints'] })
]
fetchApiMock.mockResolvedValueOnce(buildAssetListResponse(assets))
await expect(
assetService.getAssetsForNodeType('CheckpointLoaderSimple')
).resolves.toEqual(assets)
const requestedUrl = fetchApiMock.mock.calls[0]?.[0] as string
const params = new URL(requestedUrl, 'http://localhost').searchParams
expect(params.get('include_tags')).toBe('models,checkpoints')
expect(params.get('limit')).toBe('500')
expect(params.has('offset')).toBe(false)
})
it('passes positive offsets for category asset pagination', async () => {
mockGetCategoryForNodeType.mockReturnValue('loras')
fetchApiMock.mockResolvedValueOnce(buildAssetListResponse([]))
await assetService.getAssetsForNodeType('LoraLoader', {
limit: 25,
offset: 50
})
const requestedUrl = fetchApiMock.mock.calls[0]?.[0] as string
const params = new URL(requestedUrl, 'http://localhost').searchParams
expect(params.get('include_tags')).toBe('models,loras')
expect(params.get('limit')).toBe('25')
expect(params.get('offset')).toBe('50')
})
})
describe(assetService.getAssetDetails, () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('throws when the details response is not ok', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse({}, { ok: false, status: 404 })
)
await expect(assetService.getAssetDetails('missing')).rejects.toThrow(
'Unable to load asset details for missing: Server returned 404'
)
})
it('throws when the details response is invalid', async () => {
fetchApiMock.mockResolvedValueOnce(buildResponse({ id: 'asset-1' }))
await expect(assetService.getAssetDetails('asset-1')).rejects.toThrow(
/Invalid asset response/
)
})
it('returns validated asset details', async () => {
const asset = validAsset({ id: 'asset-details' })
fetchApiMock.mockResolvedValueOnce(buildResponse(asset))
await expect(
assetService.getAssetDetails('asset-details')
).resolves.toEqual(asset)
})
})
describe(assetService.uploadAssetFromUrl, () => {
beforeEach(() => {
vi.clearAllMocks()
assetService.invalidateInputAssetsIncludingPublic()
})
it('throws when URL upload returns a non-ok response', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse(null, { ok: false, status: 500 })
)
await expect(
assetService.uploadAssetFromUrl({
url: 'https://example.com/input.png',
name: 'input.png'
})
).rejects.toThrow('Failed to upload asset')
})
it('does not invalidate cached input assets when the upload response is invalid', async () => {
const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })]
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
@@ -294,6 +434,61 @@ describe(assetService.uploadAssetFromBase64, () => {
expect(fetchApiMock).not.toHaveBeenCalled()
})
it('throws when base64 upload returns a non-ok response', async () => {
const fetchSpy = vi
.spyOn(globalThis, 'fetch')
.mockResolvedValueOnce(new Response('hello'))
try {
fetchApiMock.mockResolvedValueOnce(
buildResponse(null, { ok: false, status: 507 })
)
await expect(
assetService.uploadAssetFromBase64({
data: 'data:text/plain;base64,aGVsbG8=',
name: 'input.txt'
})
).rejects.toThrow('Failed to upload asset from base64: 507')
} finally {
fetchSpy.mockRestore()
}
})
it('posts base64 uploads with tags and user metadata', async () => {
const uploadedAsset = {
...validAsset({ id: 'uploaded-input', tags: ['input'] }),
created_new: false
}
const fetchSpy = vi
.spyOn(globalThis, 'fetch')
.mockResolvedValueOnce(new Response('hello'))
try {
fetchApiMock.mockResolvedValueOnce(buildResponse(uploadedAsset))
const result = await assetService.uploadAssetFromBase64({
data: 'data:text/plain;base64,aGVsbG8=',
name: 'input.txt',
tags: ['input', 'mask'],
user_metadata: { source: 'paste' }
})
expect(result).toEqual(uploadedAsset)
const request = fetchApiMock.mock.calls[0]?.[1]
expect(request).toEqual(expect.objectContaining({ method: 'POST' }))
expect(request?.body).toBeInstanceOf(FormData)
const formData = request?.body
if (!(formData instanceof FormData)) {
throw new Error('Expected base64 upload body to be FormData')
}
expect(formData.get('tags')).toBe(JSON.stringify(['input', 'mask']))
expect(formData.get('user_metadata')).toBe(
JSON.stringify({ source: 'paste' })
)
} finally {
fetchSpy.mockRestore()
}
})
it('does not invalidate cached input assets when the upload response is invalid', async () => {
const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })]
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
@@ -355,6 +550,7 @@ describe(assetService.uploadAssetFromBase64, () => {
describe(assetService.uploadAssetAsync, () => {
beforeEach(() => {
vi.clearAllMocks()
assetService.invalidateInputAssetsIncludingPublic()
})
it('returns an async result when the server responds 202', async () => {
@@ -389,6 +585,64 @@ describe(assetService.uploadAssetAsync, () => {
asset: expect.objectContaining({ id: 'asset-2' })
})
})
it('throws when the async upload response is not ok', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse(null, { ok: false, status: 502 })
)
await expect(
assetService.uploadAssetAsync({
source_url: 'https://example.com/model.safetensors'
})
).rejects.toThrow('Failed to upload asset')
})
it('throws when an async upload task response is invalid', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse({ task_id: 'task-1', status: 'waiting' }, { status: 202 })
)
await expect(
assetService.uploadAssetAsync({
source_url: 'https://example.com/model.safetensors'
})
).rejects.toThrow('Failed to parse async upload response')
})
it('throws when a sync upload asset response is invalid', async () => {
fetchApiMock.mockResolvedValueOnce(buildResponse({ id: 'asset-2' }))
await expect(
assetService.uploadAssetAsync({
source_url: 'https://example.com/model.safetensors'
})
).rejects.toThrow('Failed to parse sync upload response')
})
it('invalidates cached input assets for completed async input uploads', async () => {
const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })]
const freshAssets = [validAsset({ id: 'fresh-input', tags: ['input'] })]
fetchApiMock
.mockResolvedValueOnce(buildAssetListResponse(staleAssets))
.mockResolvedValueOnce(
buildResponse(
{ task_id: 'task-1', status: 'completed' },
{ ok: true, status: 202 }
)
)
.mockResolvedValueOnce(buildAssetListResponse(freshAssets))
await assetService.getInputAssetsIncludingPublic()
await assetService.uploadAssetAsync({
source_url: 'https://example.com/input.png',
tags: ['input']
})
const refreshed = await assetService.getInputAssetsIncludingPublic()
expect(refreshed).toEqual(freshAssets)
expect(fetchApiMock).toHaveBeenCalledTimes(3)
})
})
describe(assetService.deleteAsset, () => {
@@ -416,6 +670,94 @@ describe(assetService.deleteAsset, () => {
})
})
describe(assetService.addAssetTags, () => {
beforeEach(() => {
vi.clearAllMocks()
assetService.invalidateInputAssetsIncludingPublic()
})
it('posts tags and returns the parsed tag operation result', async () => {
const result = { total_tags: ['input', 'mask'], added: ['mask'] }
fetchApiMock.mockResolvedValueOnce(buildResponse(result))
await expect(
assetService.addAssetTags('asset-1', ['mask'])
).resolves.toEqual(result)
expect(fetchApiMock).toHaveBeenCalledWith(
'/assets/asset-1/tags',
expect.objectContaining({
method: 'POST',
body: JSON.stringify({ tags: ['mask'] })
})
)
})
it('throws when adding tags fails', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse(null, { ok: false, status: 403 })
)
await expect(
assetService.addAssetTags('asset-1', ['mask'])
).rejects.toThrow(
'Unable to add tags to asset asset-1: Server returned 403'
)
})
it('throws when the add-tags response is invalid', async () => {
fetchApiMock.mockResolvedValueOnce(buildResponse({ added: ['mask'] }))
await expect(
assetService.addAssetTags('asset-1', ['mask'])
).rejects.toThrow()
})
})
describe(assetService.removeAssetTags, () => {
beforeEach(() => {
vi.clearAllMocks()
assetService.invalidateInputAssetsIncludingPublic()
})
it('deletes tags and returns the parsed tag operation result', async () => {
const result = { total_tags: ['input'], removed: ['mask'] }
fetchApiMock.mockResolvedValueOnce(buildResponse(result))
await expect(
assetService.removeAssetTags('asset-1', ['mask'])
).resolves.toEqual(result)
expect(fetchApiMock).toHaveBeenCalledWith(
'/assets/asset-1/tags',
expect.objectContaining({
method: 'DELETE',
body: JSON.stringify({ tags: ['mask'] })
})
)
})
it('throws when removing tags fails', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse(null, { ok: false, status: 404 })
)
await expect(
assetService.removeAssetTags('asset-1', ['mask'])
).rejects.toThrow(
'Unable to remove tags from asset asset-1: Server returned 404'
)
})
it('throws when the remove-tags response is invalid', async () => {
fetchApiMock.mockResolvedValueOnce(buildResponse({ removed: ['mask'] }))
await expect(
assetService.removeAssetTags('asset-1', ['mask'])
).rejects.toThrow()
})
})
describe(assetService.getAssetModelFolders, () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -481,6 +823,16 @@ describe(assetService.updateAsset, () => {
})
)
})
it('throws when the update response is not ok', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse(null, { ok: false, status: 409 })
)
await expect(
assetService.updateAsset('asset-1', { name: 'renamed.safetensors' })
).rejects.toThrow('Unable to update asset asset-1: Server returned 409')
})
})
describe(assetService.getAssetsByTag, () => {
@@ -515,6 +867,21 @@ describe(assetService.getAssetsByTag, () => {
expect(params.get('include_tags')).toBe('input')
expect(params.get('exclude_tags')).toBe(MISSING_TAG)
})
it('forwards explicit public filtering and offset pagination', async () => {
fetchApiMock.mockResolvedValueOnce(buildAssetListResponse([]))
await assetService.getAssetsByTag('input', false, {
limit: 30,
offset: 60
})
const requestedUrl = fetchApiMock.mock.calls[0]?.[0] as string
const params = new URL(requestedUrl, 'http://localhost').searchParams
expect(params.get('include_public')).toBe('false')
expect(params.get('limit')).toBe('30')
expect(params.get('offset')).toBe('60')
})
})
describe(assetService.getAllAssetsByTag, () => {
@@ -562,6 +929,31 @@ describe(assetService.getAllAssetsByTag, () => {
expect(secondParams.has('offset')).toBe(false)
})
it('uses the default page size when limit is not positive', async () => {
fetchApiMock.mockResolvedValueOnce(buildAssetListResponse([]))
await expect(
assetService.getAllAssetsByTag('input', true, { limit: 0 })
).resolves.toEqual([])
const requestedUrl = fetchApiMock.mock.calls[0]?.[0] as string
const params = new URL(requestedUrl, 'http://localhost').searchParams
expect(params.get('limit')).toBe('500')
})
it('throws before fetching when the pagination signal is already aborted', async () => {
const controller = new AbortController()
controller.abort()
await expect(
assetService.getAllAssetsByTag('input', true, {
signal: controller.signal
})
).rejects.toMatchObject({ name: 'AbortError' })
expect(fetchApiMock).not.toHaveBeenCalled()
})
it('honors has_more when walking tagged asset pages', async () => {
fetchApiMock
.mockResolvedValueOnce(
@@ -703,6 +1095,75 @@ describe(assetService.getAllAssetsByTag, () => {
})
})
describe(assetService.createAssetExport, () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('posts export options and returns the export task', async () => {
const task = { task_id: 'export-1', status: 'created', message: 'queued' }
fetchApiMock.mockResolvedValueOnce(buildResponse(task))
await expect(
assetService.createAssetExport({
asset_ids: ['asset-1'],
include_previews: true
})
).resolves.toEqual(task)
expect(fetchApiMock).toHaveBeenCalledWith(
'/assets/export',
expect.objectContaining({
method: 'POST',
body: JSON.stringify({
asset_ids: ['asset-1'],
include_previews: true
})
})
)
})
it('throws when creating an export fails', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse(null, { ok: false, status: 503 })
)
await expect(
assetService.createAssetExport({ asset_ids: ['asset-1'] })
).rejects.toThrow('Failed to create asset export: 503')
})
})
describe(assetService.getExportDownloadUrl, () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('returns the export download URL', async () => {
const download = {
url: 'https://example.com/export.zip',
expires_at: '2026-07-01T00:00:00Z'
}
fetchApiMock.mockResolvedValueOnce(buildResponse(download))
await expect(
assetService.getExportDownloadUrl('export.zip')
).resolves.toEqual(download)
expect(fetchApiMock).toHaveBeenCalledWith('/assets/exports/export.zip')
})
it('throws when export download URL lookup fails', async () => {
fetchApiMock.mockResolvedValueOnce(
buildResponse(null, { ok: false, status: 404 })
)
await expect(
assetService.getExportDownloadUrl('missing.zip')
).rejects.toThrow('Failed to get export download URL: 404')
})
})
describe(assetService.getInputAssetsIncludingPublic, () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -729,6 +1190,17 @@ describe(assetService.getInputAssetsIncludingPublic, () => {
expect(params.get('limit')).toBe('500')
})
it('throws before starting a shared request when the caller signal is already aborted', async () => {
const controller = new AbortController()
controller.abort()
await expect(
assetService.getInputAssetsIncludingPublic(controller.signal)
).rejects.toMatchObject({ name: 'AbortError' })
expect(fetchApiMock).not.toHaveBeenCalled()
})
it('fetches fresh input assets after explicit invalidation', async () => {
const staleAssets = [validAsset({ id: 'stale-input', tags: ['input'] })]
const freshAssets = [validAsset({ id: 'fresh-input', tags: ['input'] })]

View File

@@ -0,0 +1,37 @@
import { describe, expect, it } from 'vitest'
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import { getAssetType } from '@/platform/assets/utils/assetTypeUtil'
function asset(overrides: Partial<AssetItem> = {}): AssetItem {
return {
id: 'asset-1',
name: 'image.png',
preview_url: '',
tags: [],
created_at: '',
updated_at: '',
size: 0,
mime_type: 'image/png',
user_metadata: {},
...overrides
} as AssetItem
}
describe('getAssetType', () => {
it('prefers the preview URL type over tags', () => {
expect(
getAssetType(
asset({
preview_url: '/api/view?filename=image.png&type=temp',
tags: ['output']
})
)
).toBe('temp')
})
it('falls back to tags and then the supplied default type', () => {
expect(getAssetType(asset({ tags: ['input'] }))).toBe('input')
expect(getAssetType(asset(), 'input')).toBe('input')
})
})

View File

@@ -0,0 +1,62 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import { getAssetUrl } from '@/platform/assets/utils/assetUrlUtil'
const { apiURL } = vi.hoisted(() => ({
apiURL: vi.fn((path: string) => `https://comfy.local${path}`)
}))
vi.mock('@/scripts/api', () => ({
api: { apiURL }
}))
function asset(overrides: Partial<AssetItem> = {}): AssetItem {
return {
id: 'asset-1',
name: 'folder image.png',
preview_url: '',
tags: ['output'],
created_at: '',
updated_at: '',
size: 0,
mime_type: 'image/png',
user_metadata: {},
...overrides
} as AssetItem
}
beforeEach(() => {
apiURL.mockClear()
})
describe('getAssetUrl', () => {
it('builds encoded view URLs with type and subfolder', () => {
const url = getAssetUrl(
asset({
user_metadata: { subfolder: 'nested/path' }
})
)
expect(apiURL).toHaveBeenCalledWith(
'/view?filename=folder+image.png&type=output&subfolder=nested%2Fpath'
)
expect(url).toBe(
'https://comfy.local/view?filename=folder+image.png&type=output&subfolder=nested%2Fpath'
)
})
it('uses preview URL type and omits empty subfolders', () => {
getAssetUrl(
asset({
preview_url: '/api/view?filename=image.png&type=temp',
tags: ['output'],
user_metadata: { subfolder: '' }
})
)
expect(apiURL).toHaveBeenCalledWith(
'/view?filename=folder+image.png&type=temp'
)
})
})

View File

@@ -1,4 +1,5 @@
import { createTestingPinia } from '@pinia/testing'
import { fromPartial } from '@total-typescript/shoehorn'
import { setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it, vi } from 'vitest'
@@ -28,6 +29,8 @@ interface HostAssetWidget extends IBaseWidget<
node: LGraphNode
}
type AssetWidget = IBaseWidget<string | undefined, 'asset', IWidgetAssetOptions>
type OnWidgetChanged = NonNullable<LGraphNode['onWidgetChanged']>
function checkpointAsset(name: string): AssetItem {
@@ -166,4 +169,118 @@ describe('createAssetWidget', () => {
)
expect(captureCanvasState).toHaveBeenCalledOnce()
})
it('falls back to widget name and empty current value for cloned widgets', async () => {
const { node } = createAssetWidgetNode()
const sourceWidget = createAssetWidget({
node,
widgetName: 'lora_name',
nodeTypeForBrowser: 'LoraLoader'
})
assertAssetOptions(sourceWidget.options)
const clonedWidget: AssetWidget = {
type: 'asset',
name: 'lora_name',
value: undefined,
options: sourceWidget.options,
y: 0
}
await sourceWidget.options.openModal(clonedWidget)
expect(firstShowOptions()).toMatchObject({
nodeType: 'LoraLoader',
inputName: 'lora_name',
currentValue: ''
})
})
it('rejects malformed asset selections', async () => {
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
const { node } = createAssetWidgetNode()
const widget = createAssetWidget({
node,
widgetName: 'ckpt_name',
nodeTypeForBrowser: 'CheckpointLoaderSimple',
defaultValue: 'fake_model.safetensors'
})
assertAssetOptions(widget.options)
await widget.options.openModal(widget)
firstShowOptions().onAssetSelected?.(
fromPartial({ id: 'asset-without-name' })
)
expect(widget.value).toBe('fake_model.safetensors')
expect(captureCanvasState).not.toHaveBeenCalled()
consoleError.mockRestore()
})
it('rejects invalid asset filenames', async () => {
const consoleError = vi.spyOn(console, 'error').mockImplementation(() => {})
const { node } = createAssetWidgetNode()
const widget = createAssetWidget({
node,
widgetName: 'ckpt_name',
nodeTypeForBrowser: 'CheckpointLoaderSimple',
defaultValue: 'fake_model.safetensors'
})
assertAssetOptions(widget.options)
await widget.options.openModal(widget)
firstShowOptions().onAssetSelected?.(checkpointAsset('../bad.safetensors'))
expect(widget.value).toBe('fake_model.safetensors')
expect(captureCanvasState).not.toHaveBeenCalled()
consoleError.mockRestore()
})
it('updates ownerless cloned widgets without node callbacks', async () => {
const { node, onWidgetChanged } = createAssetWidgetNode()
const sourceWidget = createAssetWidget({
node,
widgetName: 'ckpt_name',
nodeTypeForBrowser: 'CheckpointLoaderSimple',
defaultValue: 'fake_model.safetensors'
})
assertAssetOptions(sourceWidget.options)
const callback = vi.fn<NonNullable<IBaseWidget['callback']>>()
const clonedWidget: AssetWidget = {
type: 'asset',
name: 'ckpt_name',
value: 'fake_model.safetensors',
callback,
options: sourceWidget.options,
y: 0
}
await sourceWidget.options.openModal(clonedWidget)
firstShowOptions().onAssetSelected?.(
checkpointAsset('real_model.safetensors')
)
expect(clonedWidget.value).toBe('real_model.safetensors')
expect(callback).toHaveBeenCalledWith('real_model.safetensors')
expect(onWidgetChanged).not.toHaveBeenCalled()
expect(captureCanvasState).toHaveBeenCalledOnce()
})
it('does not capture canvas state when the selection is unchanged', async () => {
const { node } = createAssetWidgetNode()
const widget = createAssetWidget({
node,
widgetName: 'ckpt_name',
nodeTypeForBrowser: 'CheckpointLoaderSimple',
defaultValue: 'fake_model.safetensors'
})
assertAssetOptions(widget.options)
await widget.options.openModal(widget)
firstShowOptions().onAssetSelected?.(
checkpointAsset('fake_model.safetensors')
)
expect(widget.value).toBe('fake_model.safetensors')
expect(captureCanvasState).not.toHaveBeenCalled()
})
})

View File

@@ -140,7 +140,10 @@ describe('CloudSubscriptionRedirectView', () => {
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
'creator',
'monthly',
false
{
openInNewTab: false,
paymentIntentSource: 'deep_link'
}
)
// Shows loading affordances
@@ -169,7 +172,10 @@ describe('CloudSubscriptionRedirectView', () => {
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
'creator',
'monthly',
false
{
openInNewTab: false,
paymentIntentSource: 'deep_link'
}
)
})
@@ -180,7 +186,8 @@ describe('CloudSubscriptionRedirectView', () => {
expect(screen.getByText('Subscribe to Team Plan')).toBeInTheDocument()
expect(mockPerformTeamSubscriptionCheckout).toHaveBeenCalledWith(
'team_700',
'yearly'
'yearly',
{ paymentIntentSource: 'deep_link' }
)
// Team never goes through the personal checkout path
expect(mockPerformSubscriptionCheckout).not.toHaveBeenCalled()

View File

@@ -94,7 +94,9 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
return
}
isTeamCheckout.value = true
await performTeamSubscriptionCheckout(stopId, billingCycle)
await performTeamSubscriptionCheckout(stopId, billingCycle, {
paymentIntentSource: 'deep_link'
})
return
}
@@ -112,7 +114,10 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
if (isActiveSubscription.value) {
await accessBillingPortal(undefined, false)
} else {
await performSubscriptionCheckout(tierKeyParam, billingCycle, false)
await performSubscriptionCheckout(tierKeyParam, billingCycle, {
openInNewTab: false,
paymentIntentSource: 'deep_link'
})
}
}, reportError)

View File

@@ -351,12 +351,12 @@ const handleRefresh = wrapWithErrorHandlingAsync(async () => {
})
function handleAddCredits() {
telemetry?.trackAddApiCreditButtonClicked()
telemetry?.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
void dialogService.showTopUpCreditsDialog()
}
function handleUpgradeToAddCredits() {
showPricingTable()
showPricingTable({ reason: 'upgrade_to_add_credits' })
}
async function handleWindowFocus() {

View File

@@ -5,6 +5,8 @@ import { render, screen } from '@testing-library/vue'
import enMessages from '@/locales/en/main.json' with { type: 'json' }
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import FreeTierDialogContent from './FreeTierDialogContent.vue'
const mockRenewalDate = vi.hoisted(() => ({ value: null as string | null }))
@@ -15,7 +17,7 @@ vi.mock('@/composables/billing/useBillingContext', () => ({
}))
}))
function renderComponent() {
function renderComponent(props?: { reason?: PaymentIntentSource }) {
const i18n = createI18n({
legacy: false,
locale: 'en',
@@ -23,6 +25,7 @@ function renderComponent() {
})
return render(FreeTierDialogContent, {
props,
global: {
plugins: [i18n]
}
@@ -43,4 +46,18 @@ describe('FreeTierDialogContent', () => {
renderComponent()
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
})
it('keeps the generic copy for intent reasons outside the credits variants', () => {
mockRenewalDate.value = '2026-07-15T10:00:00Z'
renderComponent({ reason: 'subscribe_to_run' })
expect(
screen.getByText('Your credits refresh on Jul 15, 2026.')
).toBeInTheDocument()
})
it('swaps to the out-of-credits copy without the refresh line', () => {
mockRenewalDate.value = '2026-07-15T10:00:00Z'
renderComponent({ reason: 'out_of_credits' })
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
})
})

View File

@@ -52,7 +52,7 @@
</p>
<p
v-if="!reason || reason === 'subscription_required'"
v-if="!isCreditsBlockedVariant"
class="m-0 text-sm text-text-secondary"
>
{{
@@ -65,10 +65,7 @@
</p>
<p
v-if="
(!reason || reason === 'subscription_required') &&
formattedRenewalDate
"
v-if="!isCreditsBlockedVariant && formattedRenewalDate"
class="m-0 text-sm text-text-secondary"
>
{{
@@ -88,7 +85,7 @@
@click="$emit('upgrade')"
>
{{
reason === 'out_of_credits' || reason === 'top_up_blocked'
isCreditsBlockedVariant
? $t('subscription.freeTier.upgradeCta')
: $t('subscription.freeTier.subscribeCta')
}}
@@ -103,12 +100,12 @@ import { computed } from 'vue'
import Button from '@/components/ui/button/Button.vue'
import { useBillingContext } from '@/composables/billing/useBillingContext'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import SubscriptionBenefits from '@/platform/cloud/subscription/components/SubscriptionBenefits.vue'
import { getTierCredits } from '@/platform/cloud/subscription/constants/tierPricing'
defineProps<{
reason?: SubscriptionDialogReason
const { reason } = defineProps<{
reason?: PaymentIntentSource
}>()
defineEmits<{
@@ -129,4 +126,10 @@ const formattedRenewalDate = computed(() => {
})
const freeTierCredits = computed(() => getTierCredits('free'))
// Only these two variants replace the generic free-tier copy; any other
// intent reason (subscribe_to_run, deep_link, ...) keeps the default pitch.
const isCreditsBlockedVariant = computed(
() => reason === 'out_of_credits' || reason === 'top_up_blocked'
)
</script>

View File

@@ -261,6 +261,7 @@ describe('PricingTable', () => {
tier: 'creator',
cycle: 'yearly',
checkout_type: 'change',
checkout_attempt_id: expect.any(String),
previous_tier: 'standard'
})
expect(mockAccessBillingPortal).toHaveBeenCalledWith('creator-yearly')
@@ -341,6 +342,7 @@ describe('PricingTable', () => {
expect(
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
).toBeNull()
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
it('should use the latest userId value when it changes after mount', async () => {
@@ -366,6 +368,7 @@ describe('PricingTable', () => {
tier: 'creator',
cycle: 'yearly',
checkout_type: 'change',
checkout_attempt_id: expect.any(String),
previous_tier: 'standard'
})
})

View File

@@ -277,13 +277,19 @@ import type {
TierKey,
TierPricing
} from '@/platform/cloud/subscription/constants/tierPricing'
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import {
recordPendingSubscriptionCheckoutAttempt,
withPendingCheckoutAttemptId
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import { performSubscriptionCheckout } from '@/platform/cloud/subscription/utils/subscriptionCheckoutUtil'
import { isPlanDowngrade } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
import type {
CheckoutAttributionMetadata,
PaymentIntentSource
} from '@/platform/telemetry/types'
import { useAuthStore } from '@/stores/authStore'
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
@@ -321,6 +327,10 @@ interface PricingTierConfig {
isPopular?: boolean
}
const { reason } = defineProps<{
reason?: PaymentIntentSource
}>()
const emit = defineEmits<{
chooseTeamWorkspace: []
}>()
@@ -463,16 +473,17 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
} as const
const previousPlan = currentPlanDescriptor.value
const checkoutAttribution = await getCheckoutAttributionForCloud()
if (userId.value) {
telemetry?.trackBeginCheckout({
user_id: userId.value,
tier: targetPlan.tierKey,
cycle: targetPlan.billingCycle,
checkout_type: 'change',
...checkoutAttribution,
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
})
}
const beginCheckoutMetadata = userId.value
? {
user_id: userId.value,
tier: targetPlan.tierKey,
cycle: targetPlan.billingCycle,
checkout_type: 'change' as const,
...(reason ? { payment_intent_source: reason } : {}),
...checkoutAttribution,
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
}
: null
// Pass the target tier to create a deep link to subscription update confirmation
const checkoutTier = getCheckoutTier(
targetPlan.tierKey,
@@ -487,29 +498,39 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
if (downgrade) {
// TODO(COMFY-StripeProration): Remove once backend checkout creation mirrors portal proration ("change at billing end")
await accessBillingPortal()
const didOpenPortal = await accessBillingPortal()
if (didOpenPortal && beginCheckoutMetadata) {
telemetry?.trackBeginCheckout(beginCheckoutMetadata)
}
} else {
const didOpenPortal = await accessBillingPortal(checkoutTier)
if (!didOpenPortal) {
return
}
recordPendingSubscriptionCheckoutAttempt({
const pendingAttempt = recordPendingSubscriptionCheckoutAttempt({
tier: targetPlan.tierKey,
cycle: targetPlan.billingCycle,
checkout_type: 'change',
payment_intent_source: reason,
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {}),
...(previousPlan
? { previous_cycle: previousPlan.billingCycle }
: {})
})
if (beginCheckoutMetadata) {
telemetry?.trackBeginCheckout(
withPendingCheckoutAttemptId(
beginCheckoutMetadata,
pendingAttempt
)
)
}
}
} else {
await performSubscriptionCheckout(
tierKey,
currentBillingCycle.value,
true
)
await performSubscriptionCheckout(tierKey, currentBillingCycle.value, {
paymentIntentSource: reason
})
}
} finally {
isLoading.value = false

View File

@@ -56,7 +56,7 @@ const handleSubscribe = () => {
current_tier: tier.value?.toLowerCase()
})
isAwaitingStripeSubscription.value = true
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_now_button' })
}
onBeforeUnmount(() => {

View File

@@ -54,6 +54,6 @@ function handleSubscribeToRun() {
trackRunButton({ subscribe_to_run: true })
}
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscribe_to_run' })
}
</script>

View File

@@ -48,7 +48,9 @@
v-if="isActiveSubscription"
variant="primary"
class="rounded-lg px-4 py-2 text-sm font-normal text-text-primary"
@click="showSubscriptionDialog"
@click="
showSubscriptionDialog({ reason: 'settings_billing_panel' })
"
>
{{ $t('subscription.upgradePlan') }}
</Button>

View File

@@ -33,7 +33,11 @@
</i18n-t>
</div>
<PricingTable class="flex-1" @choose-team-workspace="handleChooseTeam" />
<PricingTable
:reason
class="flex-1"
@choose-team-workspace="handleChooseTeam"
/>
<!-- Contact and Enterprise Links -->
<div class="flex flex-col items-center gap-2">
@@ -157,11 +161,11 @@ import { useBillingContext } from '@/composables/billing/useBillingContext'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import { useCommandStore } from '@/stores/commandStore'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
const { onClose, reason, onChooseTeam } = defineProps<{
onClose: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
onChooseTeam?: () => void
}>()

View File

@@ -24,7 +24,9 @@ export function useAccountPreconditionDialog() {
)
return
case 'subscription':
void dialogService.showSubscriptionRequiredDialog()
void dialogService.showSubscriptionRequiredDialog({
reason: 'subscription_required'
})
return
case 'credits':
void dialogService.showTopUpCreditsDialog({

View File

@@ -55,12 +55,6 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
})
}))
const mockTrackSubscription = vi.hoisted(() => vi.fn())
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
}))
describe('usePricingTableUrlLoader', () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -96,9 +90,6 @@ describe('usePricingTableUrlLoader', () => {
reason: 'deep_link',
planMode: undefined
})
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
reason: 'deep_link'
})
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
})
@@ -150,7 +141,6 @@ describe('usePricingTableUrlLoader', () => {
await loadPricingTableFromUrl()
expect(mockShowPricingTable).not.toHaveBeenCalled()
expect(mockTrackSubscription).not.toHaveBeenCalled()
})
it('denies, strips, and clears together when the user is not eligible', async () => {
@@ -161,7 +151,6 @@ describe('usePricingTableUrlLoader', () => {
await loadPricingTableFromUrl()
expect(mockShowPricingTable).not.toHaveBeenCalled()
expect(mockTrackSubscription).not.toHaveBeenCalled()
expect(mockRouterReplace).toHaveBeenCalledWith({
query: { other: 'param' }
})
@@ -230,7 +219,6 @@ describe('usePricingTableUrlLoader', () => {
)
expect(mockShowPricingTable).not.toHaveBeenCalled()
expect(mockTrackSubscription).not.toHaveBeenCalled()
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
expect(preservedQueryMocks.clearPreservedQuery).toHaveBeenCalledWith(
'pricing'

View File

@@ -7,7 +7,6 @@ import {
mergePreservedQueryIntoQuery
} from '@/platform/navigation/preservedQueryManager'
import { PRESERVED_QUERY_NAMESPACES } from '@/platform/navigation/preservedQueryNamespaces'
import { useTelemetry } from '@/platform/telemetry'
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
@@ -62,7 +61,6 @@ export function usePricingTableUrlLoader() {
const planMode =
param === 'team' || param === 'personal' ? param : undefined
useTelemetry()?.trackSubscription('modal_opened', { reason: 'deep_link' })
subscriptionDialog.showPricingTable({ reason: 'deep_link', planMode })
}

View File

@@ -15,7 +15,7 @@ import { t } from '@/i18n'
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
import { useDialogService } from '@/services/dialogService'
@@ -237,14 +237,7 @@ function useSubscriptionInternal() {
})
}, reportError)
const showSubscriptionDialog = (options?: {
reason?: SubscriptionDialogReason
}) => {
useTelemetry()?.trackSubscription('modal_opened', {
current_tier: subscriptionTier.value?.toLowerCase(),
reason: options?.reason
})
const showSubscriptionDialog = (options?: SubscriptionDialogOptions) => {
void showSubscriptionRequiredDialog(options)
}
@@ -277,7 +270,7 @@ function useSubscriptionInternal() {
await fetchSubscriptionStatus()
if (!isSubscribedOrIsNotCloud.value) {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'subscription_required' })
}
}

View File

@@ -39,15 +39,23 @@ vi.mock('@/stores/commandStore', () => ({
}))
// useTelemetry() returns null in OSS, a dispatcher in cloud — toggle via mockIsCloud.
const { mockIsCloud, mockTrackHelpResourceClicked } = vi.hoisted(() => ({
const {
mockIsCloud,
mockTrackHelpResourceClicked,
mockTrackAddApiCreditButtonClicked
} = vi.hoisted(() => ({
mockIsCloud: { value: true },
mockTrackHelpResourceClicked: vi.fn()
mockTrackHelpResourceClicked: vi.fn(),
mockTrackAddApiCreditButtonClicked: vi.fn()
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () =>
mockIsCloud.value
? { trackHelpResourceClicked: mockTrackHelpResourceClicked }
? {
trackHelpResourceClicked: mockTrackHelpResourceClicked,
trackAddApiCreditButtonClicked: mockTrackAddApiCreditButtonClicked
}
: null
}))
@@ -69,6 +77,9 @@ describe('useSubscriptionActions', () => {
const { handleAddApiCredits } = useSubscriptionActions()
handleAddApiCredits()
expect(mockShowTopUpCreditsDialog).toHaveBeenCalledOnce()
expect(mockTrackAddApiCreditButtonClicked).toHaveBeenCalledWith({
source: 'settings_billing_panel'
})
})
})

View File

@@ -21,6 +21,9 @@ export function useSubscriptionActions() {
})
const handleAddApiCredits = () => {
telemetry?.trackAddApiCreditButtonClicked({
source: 'settings_billing_panel'
})
void dialogService.showTopUpCreditsDialog()
}

View File

@@ -5,8 +5,10 @@ import { useSubscriptionDialog } from './useSubscriptionDialog'
const mockCloseDialog = vi.fn()
const mockShowLayoutDialog = vi.fn()
const mockShowTeamWorkspacesDialog = vi.fn()
const mockTrackSubscription = vi.hoisted(() => vi.fn())
const mockIsInPersonalWorkspace = vi.hoisted(() => ({ value: true }))
const mockIsFreeTier = vi.hoisted(() => ({ value: false }))
const mockTier = vi.hoisted(() => ({ value: 'FREE' as string | null }))
const mockTeamWorkspacesEnabled = vi.hoisted(() => ({ value: false }))
const mockIsCloud = vi.hoisted(() => ({ value: true }))
const mockIsLegacyTeamPlan = vi.hoisted(() => ({ value: false }))
@@ -60,10 +62,15 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
vi.mock('@/composables/billing/useBillingContext', () => ({
useBillingContext: () => ({
isFreeTier: mockIsFreeTier,
isLegacyTeamPlan: mockIsLegacyTeamPlan
isLegacyTeamPlan: mockIsLegacyTeamPlan,
tier: mockTier
})
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
}))
vi.mock('@/platform/workspace/composables/useWorkspaceUI', () => ({
useWorkspaceUI: () => ({
permissions: {
@@ -80,6 +87,7 @@ describe('useSubscriptionDialog', () => {
mockIsCloud.value = true
mockIsInPersonalWorkspace.value = true
mockIsFreeTier.value = false
mockTier.value = 'FREE'
mockTeamWorkspacesEnabled.value = false
mockIsLegacyTeamPlan.value = false
mockCanManageSubscription.value = true
@@ -198,6 +206,51 @@ describe('useSubscriptionDialog', () => {
const props = mockShowLayoutDialog.mock.calls[0][0].props
expect(props.initialPlanMode).toBe('team')
})
it('tracks modal_opened with the caller reason and current tier', () => {
mockTier.value = 'STANDARD'
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'upgrade_to_add_credits' })
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
current_tier: 'standard',
reason: 'upgrade_to_add_credits'
})
})
it('tracks modal_opened on the workspace (unified) path too', () => {
mockTeamWorkspacesEnabled.value = true
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'subscribe_to_run' })
expect(mockTrackSubscription).toHaveBeenCalledWith(
'modal_opened',
expect.objectContaining({ reason: 'subscribe_to_run' })
)
})
it('does not track modal_opened for the inactive member dialog', () => {
mockTeamWorkspacesEnabled.value = true
mockIsInPersonalWorkspace.value = false
mockCanManageSubscription.value = false
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'subscribe_to_run' })
expect(mockShowLayoutDialog).toHaveBeenCalledTimes(1)
expect(mockTrackSubscription).not.toHaveBeenCalled()
})
it('does not track on non-cloud', () => {
mockIsCloud.value = false
const { showPricingTable } = useSubscriptionDialog()
showPricingTable({ reason: 'subscribe_to_run' })
expect(mockTrackSubscription).not.toHaveBeenCalled()
})
})
describe('show', () => {
@@ -235,6 +288,20 @@ describe('useSubscriptionDialog', () => {
expect.objectContaining({ key: 'subscription-required' })
)
})
it('tracks modal_opened with the reason for the free-tier dialog', () => {
mockIsFreeTier.value = true
mockIsInPersonalWorkspace.value = true
const { show } = useSubscriptionDialog()
show({ reason: 'out_of_credits' })
expect(mockTrackSubscription).toHaveBeenCalledTimes(1)
expect(mockTrackSubscription).toHaveBeenCalledWith(
'modal_opened',
expect.objectContaining({ reason: 'out_of_credits' })
)
})
})
describe('startTeamWorkspaceUpgradeFlow', () => {

View File

@@ -4,6 +4,8 @@ import { useDialogStore } from '@/stores/dialogStore'
import { useBillingContext } from '@/composables/billing/useBillingContext'
import { useFeatureFlags } from '@/composables/useFeatureFlags'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
@@ -11,14 +13,8 @@ const DIALOG_KEY = 'subscription-required'
const FREE_TIER_DIALOG_KEY = 'free-tier-info'
const RESUME_PRICING_KEY = 'comfy:resume-team-pricing'
export type SubscriptionDialogReason =
| 'subscription_required'
| 'out_of_credits'
| 'top_up_blocked'
| 'deep_link'
interface SubscriptionDialogOptions {
reason?: SubscriptionDialogReason
export interface SubscriptionDialogOptions {
reason?: PaymentIntentSource
/**
* Forces the unified pricing dialog to open on a specific plan tab,
* overriding the workspace-derived default (e.g. an "Upgrade to Team" CTA
@@ -38,6 +34,17 @@ export const useSubscriptionDialog = () => {
dialogStore.closeDialog({ key: FREE_TIER_DIALOG_KEY })
}
// Fired here — the choke point every paywall/pricing dialog variant passes
// through — so both the legacy and workspace billing paths emit it.
function trackModalOpened(reason?: PaymentIntentSource) {
// Resolved lazily to avoid the useBillingContext import cycle (see below).
const { tier } = useBillingContext()
useTelemetry()?.trackSubscription('modal_opened', {
current_tier: tier.value?.toLowerCase(),
reason
})
}
function showPricingTable(options?: SubscriptionDialogOptions) {
if (!isCloud) return
@@ -71,6 +78,8 @@ export const useSubscriptionDialog = () => {
return
}
trackModalOpened(options?.reason)
// Shared dialog shell styling for both variants.
const dialogComponentProps = {
style: 'width: min(1328px, 95vw); max-height: 958px;',
@@ -167,6 +176,8 @@ export const useSubscriptionDialog = () => {
// (not at composable setup) to avoid the useBillingContext import cycle.
const { isFreeTier } = useBillingContext()
if (isFreeTier.value && workspaceStore.isInPersonalWorkspace) {
trackModalOpened(options?.reason)
const component = defineAsyncComponent(
() =>
import('@/platform/cloud/subscription/components/FreeTierDialogContent.vue')
@@ -236,7 +247,7 @@ export const useSubscriptionDialog = () => {
sessionStorage.removeItem(RESUME_PRICING_KEY)
if (!workspaceStore.isInPersonalWorkspace) {
showPricingTable()
showPricingTable({ reason: 'team_upgrade_resume' })
}
} catch {
// sessionStorage may be unavailable

View File

@@ -0,0 +1,49 @@
import { beforeEach, describe, expect, it } from 'vitest'
import {
clearPendingSubscriptionCheckoutAttempt,
consumePendingSubscriptionCheckoutSuccess,
recordPendingSubscriptionCheckoutAttempt
} from './subscriptionCheckoutTracker'
const activeProStatus = {
is_active: true,
subscription_tier: 'PRO',
subscription_duration: 'MONTHLY'
} as const
describe('subscriptionCheckoutTracker', () => {
beforeEach(() => {
clearPendingSubscriptionCheckoutAttempt()
})
it('round-trips payment_intent_source from attempt to success metadata', () => {
recordPendingSubscriptionCheckoutAttempt({
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
})
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
expect(metadata).toMatchObject({
tier: 'pro',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
})
})
it('omits payment_intent_source when the attempt had none', () => {
recordPendingSubscriptionCheckoutAttempt({
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new'
})
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
expect(metadata).not.toBeNull()
expect(metadata).not.toHaveProperty('payment_intent_source')
})
})

View File

@@ -7,7 +7,12 @@ import type {
TierKey
} from '@/platform/cloud/subscription/constants/tierPricing'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import type { SubscriptionSuccessMetadata } from '@/platform/telemetry/types'
import type {
BeginCheckoutMetadata,
PaymentIntentSource,
SubscriptionCheckoutType,
SubscriptionSuccessMetadata
} from '@/platform/telemetry/types'
const PENDING_SUBSCRIPTION_CHECKOUT_MAX_AGE_MS = 6 * 60 * 60 * 1000
const VALID_TIER_KEYS = new Set<TierKey>([
@@ -23,7 +28,6 @@ export const PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY =
export const PENDING_SUBSCRIPTION_CHECKOUT_EVENT =
'comfy:subscription-checkout-attempt-changed'
type CheckoutType = 'new' | 'change'
type SubscriptionDuration = 'MONTHLY' | 'ANNUAL'
interface SubscriptionStatusSnapshot {
@@ -32,22 +36,24 @@ interface SubscriptionStatusSnapshot {
subscription_duration?: SubscriptionDuration | null
}
interface PendingSubscriptionCheckoutAttempt {
export interface PendingSubscriptionCheckoutAttempt {
attempt_id: string
started_at_ms: number
tier: TierKey
cycle: BillingCycle
checkout_type: CheckoutType
checkout_type: SubscriptionCheckoutType
previous_tier?: TierKey
previous_cycle?: BillingCycle
payment_intent_source?: PaymentIntentSource
}
interface RecordPendingSubscriptionCheckoutAttemptInput {
interface PendingSubscriptionCheckoutAttemptInput {
tier: TierKey
cycle: BillingCycle
checkout_type: CheckoutType
checkout_type: SubscriptionCheckoutType
previous_tier?: TierKey
previous_cycle?: BillingCycle
payment_intent_source?: PaymentIntentSource
}
const dispatchPendingCheckoutChangeEvent = () => {
@@ -168,6 +174,9 @@ const normalizeAttempt = (
...(candidate.previous_cycle === 'monthly' ||
candidate.previous_cycle === 'yearly'
? { previous_cycle: candidate.previous_cycle }
: {}),
...(typeof candidate.payment_intent_source === 'string'
? { payment_intent_source: candidate.payment_intent_source }
: {})
}
}
@@ -224,20 +233,27 @@ const getPendingSubscriptionCheckoutAttempt =
export const hasPendingSubscriptionCheckoutAttempt = (): boolean =>
getPendingSubscriptionCheckoutAttempt() !== null
export const recordPendingSubscriptionCheckoutAttempt = (
input: RecordPendingSubscriptionCheckoutAttemptInput
export const createPendingSubscriptionCheckoutAttempt = (
input: PendingSubscriptionCheckoutAttemptInput
): PendingSubscriptionCheckoutAttempt => {
const storage = getStorage()
const attempt: PendingSubscriptionCheckoutAttempt = {
return {
attempt_id: createAttemptId(),
started_at_ms: Date.now(),
tier: input.tier,
cycle: input.cycle,
checkout_type: input.checkout_type,
...(input.previous_tier ? { previous_tier: input.previous_tier } : {}),
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {})
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {}),
...(input.payment_intent_source
? { payment_intent_source: input.payment_intent_source }
: {})
}
}
export const persistPendingSubscriptionCheckoutAttempt = (
attempt: PendingSubscriptionCheckoutAttempt
): PendingSubscriptionCheckoutAttempt => {
const storage = getStorage()
if (!storage) {
return attempt
}
@@ -255,6 +271,21 @@ export const recordPendingSubscriptionCheckoutAttempt = (
return attempt
}
export const recordPendingSubscriptionCheckoutAttempt = (
input: PendingSubscriptionCheckoutAttemptInput
): PendingSubscriptionCheckoutAttempt =>
persistPendingSubscriptionCheckoutAttempt(
createPendingSubscriptionCheckoutAttempt(input)
)
export const withPendingCheckoutAttemptId = (
metadata: BeginCheckoutMetadata,
attempt: PendingSubscriptionCheckoutAttempt
): BeginCheckoutMetadata => ({
...metadata,
checkout_attempt_id: attempt.attempt_id
})
const didAttemptSucceed = (
attempt: PendingSubscriptionCheckoutAttempt,
status: SubscriptionStatusSnapshot
@@ -287,6 +318,9 @@ export const consumePendingSubscriptionCheckoutSuccess = (
cycle: attempt.cycle,
checkout_type: attempt.checkout_type,
...(attempt.previous_tier ? { previous_tier: attempt.previous_tier } : {}),
...(attempt.payment_intent_source
? { payment_intent_source: attempt.payment_intent_source }
: {}),
value,
currency: 'USD',
ecommerce: {

View File

@@ -132,13 +132,14 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'yearly', true)
await performSubscriptionCheckout('pro', 'yearly')
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith({
user_id: 'user-123',
tier: 'pro',
cycle: 'yearly',
checkout_type: 'new',
checkout_attempt_id: expect.any(String),
ga_client_id: 'ga-client-id',
ga_session_id: 'ga-session-id',
ga_session_number: 'ga-session-number',
@@ -150,6 +151,12 @@ describe('performSubscriptionCheckout', () => {
gbraid: 'gbraid-456',
wbraid: 'wbraid-789'
})
const beginCheckoutMetadata =
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
JSON.parse(storedAttempt).attempt_id
)
expect(global.fetch).toHaveBeenCalledWith(
expect.stringContaining(
'/customers/cloud-subscription-checkout/pro-yearly'
@@ -186,7 +193,7 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'monthly', true)
await performSubscriptionCheckout('pro', 'monthly')
expect(warnSpy).toHaveBeenCalledWith(
'[SubscriptionCheckout] Failed to collect checkout attribution',
@@ -203,11 +210,43 @@ describe('performSubscriptionCheckout', () => {
user_id: 'user-123',
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new'
checkout_type: 'new',
checkout_attempt_id: expect.any(String)
})
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
})
it('carries the payment intent source into begin_checkout and the pending attempt', async () => {
const checkoutUrl = 'https://checkout.stripe.com/test'
const openSpy = vi
.spyOn(window, 'open')
.mockImplementation(() => window as unknown as Window)
vi.mocked(global.fetch).mockResolvedValue({
ok: true,
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'monthly', {
paymentIntentSource: 'out_of_credits'
})
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({ payment_intent_source: 'out_of_credits' })
)
const beginCheckoutMetadata =
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
const pendingAttempt = JSON.parse(storedAttempt)
expect(pendingAttempt).toMatchObject({
payment_intent_source: 'out_of_credits'
})
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
pendingAttempt.attempt_id
)
openSpy.mockRestore()
})
it('uses the latest userId when it changes after checkout starts', async () => {
const checkoutUrl = 'https://checkout.stripe.com/test'
const openSpy = vi
@@ -222,7 +261,7 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly', true)
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly')
mockUserId.value = 'user-late'
authHeader.resolve({ Authorization: 'Bearer test-token' })
@@ -235,13 +274,14 @@ describe('performSubscriptionCheckout', () => {
user_id: 'user-late',
tier: 'pro',
cycle: 'yearly',
checkout_type: 'new'
checkout_type: 'new',
checkout_attempt_id: expect.any(String)
})
)
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
})
it('does not persist a pending attempt when the checkout popup is blocked', async () => {
it('does not persist the pending attempt when the checkout popup is blocked', async () => {
const checkoutUrl = 'https://checkout.stripe.com/test'
const openSpy = vi.spyOn(window, 'open').mockImplementation(() => null)
@@ -250,11 +290,18 @@ describe('performSubscriptionCheckout', () => {
json: async () => ({ checkout_url: checkoutUrl })
} as Response)
await performSubscriptionCheckout('pro', 'monthly', true)
await performSubscriptionCheckout('pro', 'monthly')
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
expect(
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
).toBeNull()
const storedAttempt = window.localStorage.getItem(
PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY
)
expect(storedAttempt).toBeNull()
expect(mockLocalStorage.setItem).not.toHaveBeenCalled()
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({
checkout_attempt_id: expect.any(String)
})
)
})
})

View File

@@ -4,12 +4,19 @@ import { useFeatureFlags } from '@/composables/useFeatureFlags'
import { getComfyApiBaseUrl } from '@/config/comfyApi'
import { t } from '@/i18n'
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import {
createPendingSubscriptionCheckoutAttempt,
persistPendingSubscriptionCheckoutAttempt,
withPendingCheckoutAttemptId
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import { isCloud } from '@/platform/distribution/types'
import { useTelemetry } from '@/platform/telemetry'
import type {
CheckoutAttributionMetadata,
PaymentIntentSource
} from '@/platform/telemetry/types'
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
import type { BillingCycle } from './subscriptionTierRank'
type CheckoutTier = TierKey | `${TierKey}-yearly`
@@ -31,6 +38,11 @@ const getCheckoutAttributionForCloud =
return getCheckoutAttribution()
}
interface PerformSubscriptionCheckoutOptions {
openInNewTab?: boolean
paymentIntentSource?: PaymentIntentSource
}
/**
* Core subscription checkout logic shared between PricingTable and
* SubscriptionRedirectView. Handles:
@@ -47,10 +59,12 @@ const getCheckoutAttributionForCloud =
export async function performSubscriptionCheckout(
tierKey: TierKey,
currentBillingCycle: BillingCycle,
openInNewTab: boolean = true
options: PerformSubscriptionCheckoutOptions = {}
): Promise<void> {
if (!isCloud) return
const { openInNewTab = true, paymentIntentSource } = options
const authStore = useAuthStore()
const { userId } = storeToRefs(authStore)
const telemetry = useTelemetry()
@@ -108,14 +122,29 @@ export async function performSubscriptionCheckout(
const data = await response.json()
if (data.checkout_url) {
const pendingAttempt = createPendingSubscriptionCheckoutAttempt({
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new',
payment_intent_source: paymentIntentSource
})
if (userId.value) {
telemetry?.trackBeginCheckout({
user_id: userId.value,
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new',
...checkoutAttribution
})
telemetry?.trackBeginCheckout(
withPendingCheckoutAttemptId(
{
user_id: userId.value,
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new',
...(paymentIntentSource
? { payment_intent_source: paymentIntentSource }
: {}),
...checkoutAttribution
},
pendingAttempt
)
)
}
if (openInNewTab) {
@@ -123,18 +152,9 @@ export async function performSubscriptionCheckout(
if (!checkoutWindow) {
return
}
recordPendingSubscriptionCheckoutAttempt({
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new'
})
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
} else {
recordPendingSubscriptionCheckoutAttempt({
tier: tierKey,
cycle: currentBillingCycle,
checkout_type: 'new'
})
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
globalThis.location.href = data.checkout_url
}
}

View File

@@ -1,9 +1,13 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { computed, reactive } from 'vue'
const { mockIsCloud, mockSubscribe } = vi.hoisted(() => ({
mockIsCloud: { value: true },
mockSubscribe: vi.fn()
}))
const { mockIsCloud, mockSubscribe, mockTrackBeginCheckout, mockUserId } =
vi.hoisted(() => ({
mockIsCloud: { value: true },
mockSubscribe: vi.fn(),
mockTrackBeginCheckout: vi.fn(),
mockUserId: { value: 'user-1' as string | null }
}))
vi.mock('@/platform/distribution/types', () => ({
get isCloud() {
@@ -16,6 +20,12 @@ vi.mock('@/config/comfyApi', () => ({
vi.mock('@/platform/workspace/api/workspaceApi', () => ({
workspaceApi: { subscribe: mockSubscribe }
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackBeginCheckout: mockTrackBeginCheckout })
}))
vi.mock('@/stores/authStore', () => ({
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
}))
import { performTeamSubscriptionCheckout } from './teamSubscriptionCheckoutUtil'
@@ -43,7 +53,9 @@ describe('performTeamSubscriptionCheckout', () => {
billing_op_id: 'op_1'
})
await performTeamSubscriptionCheckout('team_700', 'yearly')
await performTeamSubscriptionCheckout('team_700', 'yearly', {
paymentIntentSource: 'deep_link'
})
expect(mockSubscribe).toHaveBeenCalledWith('team_per_credit_annual', {
returnUrl: 'https://app.test/payment/success',
@@ -51,6 +63,14 @@ describe('performTeamSubscriptionCheckout', () => {
teamCreditStopId: 'team_700'
})
expect(assignedHref).toBe('https://stripe.test/pay')
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
user_id: 'user-1',
tier: 'team',
cycle: 'yearly',
checkout_type: 'new',
billing_op_id: 'op_1',
payment_intent_source: 'deep_link'
})
})
it('uses the monthly slug and lands in the app when no Stripe step is needed', async () => {
@@ -82,6 +102,16 @@ describe('performTeamSubscriptionCheckout', () => {
expect(assignedHref).toBeUndefined()
})
it('does not track begin_checkout when subscribe fails', async () => {
mockSubscribe.mockRejectedValueOnce(new Error('subscribe failed'))
await expect(
performTeamSubscriptionCheckout('team_700', 'yearly')
).rejects.toThrow('subscribe failed')
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
it('does nothing off cloud', async () => {
mockIsCloud.value = false

View File

@@ -1,10 +1,16 @@
import { getComfyPlatformBaseUrl } from '@/config/comfyApi'
import { getTeamPlanSlug } from '@/platform/cloud/subscription/constants/teamPlanCreditStops'
import { isCloud } from '@/platform/distribution/types'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { workspaceApi } from '@/platform/workspace/api/workspaceApi'
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
import type { BillingCycle } from './subscriptionTierRank'
interface PerformTeamSubscriptionCheckoutOptions {
paymentIntentSource?: PaymentIntentSource
}
/**
* Direct team-plan checkout for the marketing `/cloud/subscribe?tier=team` deep
* link: subscribes to the per-credit Team plan at the chosen slider stop and
@@ -22,7 +28,8 @@ import type { BillingCycle } from './subscriptionTierRank'
*/
export async function performTeamSubscriptionCheckout(
teamCreditStopId: string,
billingCycle: BillingCycle
billingCycle: BillingCycle,
options: PerformTeamSubscriptionCheckoutOptions = {}
): Promise<void> {
if (!isCloud) return
@@ -33,6 +40,14 @@ export async function performTeamSubscriptionCheckout(
teamCreditStopId
})
trackWorkspaceCheckoutStarted({
tier: 'team',
cycle: billingCycle,
checkoutType: 'new',
billingOpId: response.billing_op_id,
paymentIntentSource: options.paymentIntentSource
})
if (response.status === 'needs_payment_method') {
// A needs_payment_method response without a URL is unusable: surface it to
// the caller's error handling rather than silently dropping the user home

View File

@@ -1,7 +1,5 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { DownloadApiError } from '@/platform/modelManager/types'
import {
downloadModel,
fetchModelMetadata,
@@ -9,23 +7,13 @@ import {
toBrowsableUrl
} from './missingModelDownload'
const {
fetchMock,
mockIsDesktop,
mockSidebarTabStore,
mockStartDownload,
mockEnqueue,
mockToastAdd,
mockFlags
} = vi.hoisted(() => ({
fetchMock: vi.fn(),
mockIsDesktop: { value: false },
mockSidebarTabStore: { activeSidebarTabId: null as string | null },
mockStartDownload: vi.fn(),
mockEnqueue: vi.fn(),
mockToastAdd: vi.fn(),
mockFlags: { serverSideModelDownloads: false }
}))
const { fetchMock, mockIsDesktop, mockSidebarTabStore, mockStartDownload } =
vi.hoisted(() => ({
fetchMock: vi.fn(),
mockIsDesktop: { value: false },
mockSidebarTabStore: { activeSidebarTabId: null as string | null },
mockStartDownload: vi.fn()
}))
vi.stubGlobal('fetch', fetchMock)
@@ -45,38 +33,12 @@ vi.mock('@/stores/workspace/sidebarTabStore', () => ({
useSidebarTabStore: () => mockSidebarTabStore
}))
vi.mock('@/composables/useFeatureFlags', () => ({
useFeatureFlags: () => ({ flags: mockFlags })
}))
vi.mock('@/platform/modelManager/stores/modelDownloadStore', () => ({
useModelDownloadStore: () => ({ enqueue: mockEnqueue })
}))
vi.mock('@/platform/updates/common/toastStore', () => ({
useToastStore: () => ({ add: mockToastAdd })
}))
const mockRefreshModelFolder = vi.fn()
const mockRefreshMissingModels = vi.fn()
vi.mock('@/stores/modelStore', () => ({
useModelStore: () => ({ refreshModelFolder: mockRefreshModelFolder })
}))
vi.mock('@/platform/missingModel/missingModelStore', () => ({
useMissingModelStore: () => ({
refreshMissingModels: mockRefreshMissingModels
})
}))
let testId = 0
beforeEach(() => {
vi.restoreAllMocks()
vi.resetAllMocks()
delete window.__comfyDesktop2
mockFlags.serverSideModelDownloads = false
})
describe('fetchModelMetadata', () => {
@@ -155,26 +117,6 @@ describe('fetchModelMetadata', () => {
expect(fetchMock).not.toHaveBeenCalled()
})
it('returns null metadata when the Civitai request throws', async () => {
fetchMock.mockRejectedValueOnce(new Error('network down'))
const metadata = await fetchModelMetadata(
`https://civitai.com/api/download/models/${testId}`
)
expect(metadata).toEqual({ fileSize: null, gatedRepoUrl: null })
})
it('returns null metadata when the HEAD request throws', async () => {
fetchMock.mockRejectedValueOnce(new Error('network down'))
const metadata = await fetchModelMetadata(
`https://huggingface.co/org/model/resolve/main/throws-${testId}.safetensors`
)
expect(metadata).toEqual({ fileSize: null, gatedRepoUrl: null })
})
it('returns cached metadata on second call', async () => {
const url = `https://huggingface.co/org/model/resolve/main/cached-${testId}.safetensors`
@@ -298,16 +240,6 @@ describe('isModelDownloadable', () => {
})
).toBe(false)
})
it('allows explicitly whitelisted URLs from an otherwise disallowed host', () => {
expect(
isModelDownloadable({
name: 'RealESRGAN_x4plus.pth',
url: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
directory: 'upscale_models'
})
).toBe(true)
})
})
describe('downloadModel', () => {
@@ -447,152 +379,6 @@ describe('downloadModel', () => {
expect(anchorClick).toHaveBeenCalledTimes(1)
})
it('uses the browser fallback on web when server-side downloads are disabled', () => {
const anchorClick = vi
.spyOn(HTMLAnchorElement.prototype, 'click')
.mockImplementation(() => {})
downloadModel(
{
name: 'model.safetensors',
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
directory: 'checkpoints'
},
{ checkpoints: ['/models/checkpoints'] }
)
expect(anchorClick).toHaveBeenCalledTimes(1)
expect(mockEnqueue).not.toHaveBeenCalled()
})
it('enqueues a server-side download and reveals the manager when enabled', async () => {
mockFlags.serverSideModelDownloads = true
mockEnqueue.mockResolvedValue({ download_id: 'd1', accepted: true })
const anchorClick = vi
.spyOn(HTMLAnchorElement.prototype, 'click')
.mockImplementation(() => {})
downloadModel(
{
name: 'model.safetensors',
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
directory: 'checkpoints'
},
{ checkpoints: ['/models/checkpoints'] }
)
await vi.waitFor(() => {
expect(mockSidebarTabStore.activeSidebarTabId).toBe('model-manager')
})
expect(mockEnqueue).toHaveBeenCalledWith({
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
model_id: 'checkpoints/model.safetensors'
})
expect(anchorClick).not.toHaveBeenCalled()
})
it('shows a toast when a server-side enqueue fails', async () => {
mockFlags.serverSideModelDownloads = true
mockEnqueue.mockRejectedValue(new Error('boom'))
downloadModel(
{
name: 'model.safetensors',
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
directory: 'checkpoints'
},
{ checkpoints: ['/models/checkpoints'] }
)
await vi.waitFor(() => {
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'error', detail: 'boom' })
)
})
expect(mockSidebarTabStore.activeSidebarTabId).toBeNull()
})
it('reveals the download manager and shows an info toast for an in-progress download', async () => {
mockFlags.serverSideModelDownloads = true
mockEnqueue.mockRejectedValue(
new DownloadApiError('exists', 'ALREADY_DOWNLOADING', 409)
)
downloadModel(
{
name: 'model.safetensors',
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
directory: 'checkpoints'
},
{ checkpoints: ['/models/checkpoints'] }
)
await vi.waitFor(() => {
expect(mockSidebarTabStore.activeSidebarTabId).toBe('model-manager')
})
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({
severity: 'info',
detail: 'model.safetensors'
})
)
})
it('refreshes the model folder and re-scans missing models when already available', async () => {
mockFlags.serverSideModelDownloads = true
mockEnqueue.mockRejectedValue(
new DownloadApiError('already there', 'ALREADY_AVAILABLE', 409)
)
mockRefreshModelFolder.mockResolvedValue(undefined)
downloadModel(
{
name: 'model.safetensors',
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
directory: 'checkpoints'
},
{ checkpoints: ['/models/checkpoints'] }
)
await vi.waitFor(() => {
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({
severity: 'info',
detail: 'model.safetensors'
})
)
})
await vi.waitFor(() => {
expect(mockRefreshModelFolder).toHaveBeenCalledWith('checkpoints')
})
expect(mockRefreshMissingModels).toHaveBeenCalled()
expect(mockSidebarTabStore.activeSidebarTabId).toBeNull()
})
it('still re-scans missing models when the post-available folder refresh fails', async () => {
mockFlags.serverSideModelDownloads = true
mockEnqueue.mockRejectedValue(
new DownloadApiError('already there', 'ALREADY_AVAILABLE', 409)
)
const consoleWarn = vi.spyOn(console, 'warn').mockImplementation(() => {})
mockRefreshModelFolder.mockRejectedValue(new Error('boom'))
downloadModel(
{
name: 'model.safetensors',
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
directory: 'checkpoints'
},
{ checkpoints: ['/models/checkpoints'] }
)
await vi.waitFor(() => {
expect(mockRefreshMissingModels).toHaveBeenCalled()
})
expect(consoleWarn).toHaveBeenCalled()
consoleWarn.mockRestore()
})
it('opens the model library sidebar before starting a desktop download', () => {
mockIsDesktop.value = true

View File

@@ -1,11 +1,5 @@
import { downloadUrlToHfRepoUrl, isCivitaiModelUrl } from '@/utils/formatUtil'
import { useFeatureFlags } from '@/composables/useFeatureFlags'
import { t } from '@/i18n'
import { isDesktop } from '@/platform/distribution/types'
import { useModelDownloadStore } from '@/platform/modelManager/stores/modelDownloadStore'
import { DownloadApiError } from '@/platform/modelManager/types'
import { buildModelId } from '@/platform/modelManager/utils/modelId'
import { useToastStore } from '@/platform/updates/common/toastStore'
import { useElectronDownloadStore } from '@/stores/electronDownloadStore'
import { useSidebarTabStore } from '@/stores/workspace/sidebarTabStore'
import type { ComfyDesktop2Bridge } from '@/types'
@@ -35,7 +29,6 @@ const WHITE_LISTED_URLS: ReadonlySet<string> = new Set([
])
const MODEL_LIBRARY_TAB_ID = 'model-library'
const MODEL_MANAGER_TAB_ID = 'model-manager'
export interface ModelWithUrl {
name: string
@@ -54,96 +47,6 @@ async function startDesktop2ModelDownload(
}
}
function revealDownloadManager(): void {
useSidebarTabStore().activeSidebarTabId = MODEL_MANAGER_TAB_ID
}
/**
* Already on disk: surface a confirmation, refresh the model folder, and
* re-scan missing models so any node error for this model clears. Loaded
* lazily to keep this module's import graph (and its unit tests) light.
*/
async function refreshAfterModelAvailable(model: ModelWithUrl): Promise<void> {
try {
const [{ useModelStore }, { useMissingModelStore }] = await Promise.all([
import('@/stores/modelStore'),
import('@/platform/missingModel/missingModelStore')
])
if (model.directory) {
try {
await useModelStore().refreshModelFolder(model.directory)
} catch (error) {
console.warn(
'[MissingModel] Failed to refresh model folder after model available',
error
)
}
}
void useMissingModelStore().refreshMissingModels()
} catch (error) {
console.warn(
'[MissingModel] Failed to refresh after model available',
error
)
}
}
/**
* Enqueues a server-side download and reveals the Model Manager panel so the
* user can watch live progress, status, and completion. The two benign `409`
* cases get an info toast: `ALREADY_DOWNLOADING` links to the existing job,
* `ALREADY_AVAILABLE` confirms it's installed and clears the node error. Any
* other failure is reported via an error toast.
*/
async function startServerSideModelDownload(
model: ModelWithUrl
): Promise<void> {
const toast = useToastStore()
try {
await useModelDownloadStore().enqueue({
url: model.url,
model_id: buildModelId(model.directory, model.name)
})
revealDownloadManager()
} catch (error: unknown) {
if (error instanceof DownloadApiError && error.is('ALREADY_DOWNLOADING')) {
revealDownloadManager()
toast.add({
severity: 'info',
summary: t('modelManager.alreadyDownloading'),
detail: model.name,
life: 4000
})
return
}
if (error instanceof DownloadApiError && error.is('ALREADY_AVAILABLE')) {
toast.add({
severity: 'info',
summary: t('modelManager.alreadyInstalled'),
detail: model.name,
life: 4000
})
void refreshAfterModelAvailable(model)
return
}
toast.add({
severity: 'error',
summary: t('modelManager.actionFailed'),
detail: error instanceof Error ? error.message : String(error),
life: 5000
})
}
}
function startBrowserModelDownload(model: ModelWithUrl): void {
const link = document.createElement('a')
link.href = model.url
link.download = model.name
link.target = '_blank'
link.rel = 'noopener noreferrer'
link.click()
}
/**
* Converts a model download URL to a browsable page URL.
* - HuggingFace: `/resolve/` → `/blob/` (file page with model info)
@@ -179,11 +82,12 @@ export function downloadModel(
}
if (!isDesktop) {
if (useFeatureFlags().flags.serverSideModelDownloads) {
void startServerSideModelDownload(model)
} else {
startBrowserModelDownload(model)
}
const link = document.createElement('a')
link.href = model.url
link.download = model.name
link.target = '_blank'
link.rel = 'noopener noreferrer'
link.click()
return
}

View File

@@ -1,240 +0,0 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { api } from '@/scripts/api'
import { DownloadApiError } from '../types'
import {
cancelDownload,
checkAvailability,
clearDownloads,
deleteCredential,
deleteDownload,
enqueueDownload,
listCredentials,
listDownloads,
pauseDownload,
resumeDownload,
setDownloadPriority,
upsertCredential
} from './modelDownloadApi'
vi.mock('@/scripts/api', () => ({
api: {
fetchApi: vi.fn()
}
}))
const fetchApi = vi.mocked(api.fetchApi)
function jsonResponse(status: number, body: unknown): Response {
return {
ok: status >= 200 && status < 300,
status,
statusText: `status ${status}`,
json: () => Promise.resolve(body)
} as unknown as Response
}
function nonJsonErrorResponse(status: number, statusText: string): Response {
return {
ok: false,
status,
statusText,
json: () => Promise.reject(new Error('not json'))
} as unknown as Response
}
describe('modelDownloadApi', () => {
beforeEach(() => {
vi.resetAllMocks()
})
describe('enqueueDownload', () => {
it('returns the enqueue response on 202', async () => {
fetchApi.mockResolvedValue(
jsonResponse(202, { download_id: 'd1', accepted: true })
)
const result = await enqueueDownload({
url: 'https://huggingface.co/x.safetensors',
model_id: 'loras/x.safetensors'
})
expect(result).toEqual({ download_id: 'd1', accepted: true })
expect(fetchApi).toHaveBeenCalledWith(
'/download/enqueue',
expect.objectContaining({ method: 'POST' })
)
})
it('throws a DownloadApiError carrying the error code on 409', async () => {
fetchApi.mockResolvedValue(
jsonResponse(409, {
error: { code: 'ALREADY_DOWNLOADING', message: 'exists' }
})
)
await expect(
enqueueDownload({ url: 'u', model_id: 'loras/x.safetensors' })
).rejects.toMatchObject({
code: 'ALREADY_DOWNLOADING',
status: 409,
message: 'exists'
})
})
it('falls back to statusText and an UNKNOWN code for a non-JSON error body', async () => {
fetchApi.mockResolvedValue(nonJsonErrorResponse(502, 'Bad Gateway'))
await expect(
enqueueDownload({ url: 'u', model_id: 'loras/x.safetensors' })
).rejects.toMatchObject({
message: 'Bad Gateway',
code: 'UNKNOWN',
status: 502
})
})
it('exposes the code through DownloadApiError.is()', async () => {
fetchApi.mockResolvedValue(
jsonResponse(400, {
error: { code: 'URL_NOT_ALLOWED', message: 'nope' }
})
)
const error = await enqueueDownload({
url: 'u',
model_id: 'loras/x.safetensors'
}).catch((e) => e)
expect(error).toBeInstanceOf(DownloadApiError)
expect(error.is('URL_NOT_ALLOWED')).toBe(true)
})
})
describe('listDownloads', () => {
it('unwraps the downloads array', async () => {
fetchApi.mockResolvedValue(
jsonResponse(200, { downloads: [{ download_id: 'd1' }] })
)
const result = await listDownloads()
expect(result).toEqual([{ download_id: 'd1' }])
})
})
describe('actions', () => {
it('posts to the pause route', async () => {
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
await pauseDownload('d1')
expect(fetchApi).toHaveBeenCalledWith(
'/download/d1/pause',
expect.objectContaining({ method: 'POST' })
)
})
it('posts to the resume route', async () => {
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
await resumeDownload('d1')
expect(fetchApi).toHaveBeenCalledWith(
'/download/d1/resume',
expect.objectContaining({ method: 'POST' })
)
})
it('posts to the cancel route', async () => {
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
await cancelDownload('d1')
expect(fetchApi).toHaveBeenCalledWith(
'/download/d1/cancel',
expect.objectContaining({ method: 'POST' })
)
})
it('sends the priority in the body', async () => {
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
await setDownloadPriority('d1', 5)
expect(fetchApi).toHaveBeenCalledWith(
'/download/d1/priority',
expect.objectContaining({ body: JSON.stringify({ priority: 5 }) })
)
})
it('sends a DELETE to the download route', async () => {
fetchApi.mockResolvedValue(jsonResponse(200, { deleted: true }))
await deleteDownload('d1')
expect(fetchApi).toHaveBeenCalledWith(
'/download/d1',
expect.objectContaining({ method: 'DELETE' })
)
})
it('posts to the clear route and returns the deleted count', async () => {
fetchApi.mockResolvedValue(jsonResponse(200, { deleted: 3 }))
const result = await clearDownloads()
expect(result).toBe(3)
expect(fetchApi).toHaveBeenCalledWith(
'/download/clear',
expect.objectContaining({ method: 'POST' })
)
})
})
describe('checkAvailability', () => {
it('wraps the models map in the request body', async () => {
fetchApi.mockResolvedValue(jsonResponse(200, { models: {} }))
await checkAvailability({ 'loras/x.safetensors': 'https://h.co/x' })
expect(fetchApi).toHaveBeenCalledWith(
'/download/availability',
expect.objectContaining({
body: JSON.stringify({
models: { 'loras/x.safetensors': 'https://h.co/x' }
})
})
)
})
})
describe('credentials', () => {
it('unwraps the credentials list', async () => {
fetchApi.mockResolvedValue(
jsonResponse(200, { credentials: [{ id: 'c1' }] })
)
expect(await listCredentials()).toEqual([{ id: 'c1' }])
})
it('returns the created credential view', async () => {
fetchApi.mockResolvedValue(jsonResponse(201, { id: 'c1', host: 'h' }))
const result = await upsertCredential({ host: 'h', secret: 's' })
expect(result).toEqual({ id: 'c1', host: 'h' })
})
it('throws on a failed delete', async () => {
fetchApi.mockResolvedValue(
jsonResponse(404, { error: { code: 'NOT_FOUND', message: 'gone' } })
)
await expect(deleteCredential('c1')).rejects.toMatchObject({
code: 'NOT_FOUND'
})
})
})
})

View File

@@ -1,121 +0,0 @@
import { api } from '@/scripts/api'
import type {
AvailabilityResponse,
DownloadStatus,
EnqueueRequest,
EnqueueResponse,
HostCredentialUpsert,
HostCredentialView
} from '../types'
import { DownloadApiError } from '../types'
const BASE = '/download'
interface ErrorEnvelope {
error?: {
code?: string
message?: string
details?: Record<string, unknown>
}
}
async function throwFromResponse(response: Response): Promise<never> {
let body: ErrorEnvelope = {}
try {
body = await response.json()
} catch {
// Non-JSON error body
}
const error = body.error
throw new DownloadApiError(
error?.message ?? response.statusText,
error?.code ?? 'UNKNOWN',
response.status,
error?.details
)
}
async function parseJson<T>(response: Response): Promise<T> {
if (!response.ok) return throwFromResponse(response)
return response.json() as Promise<T>
}
function postJson(route: string, body: unknown): Promise<Response> {
return api.fetchApi(route, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(body)
})
}
export async function enqueueDownload(
body: EnqueueRequest
): Promise<EnqueueResponse> {
const response = await postJson(`${BASE}/enqueue`, body)
if (response.status === 202) {
return response.json() as Promise<EnqueueResponse>
}
return throwFromResponse(response)
}
export async function listDownloads(): Promise<DownloadStatus[]> {
const response = await api.fetchApi(BASE)
const data = await parseJson<{ downloads: DownloadStatus[] }>(response)
return data.downloads
}
async function postAction(id: string, action: string): Promise<void> {
const response = await postJson(`${BASE}/${id}/${action}`, undefined)
await parseJson<{ ok: boolean }>(response)
}
export const pauseDownload = (id: string) => postAction(id, 'pause')
export const resumeDownload = (id: string) => postAction(id, 'resume')
export const cancelDownload = (id: string) => postAction(id, 'cancel')
export async function deleteDownload(id: string): Promise<void> {
const response = await api.fetchApi(`${BASE}/${id}`, { method: 'DELETE' })
await parseJson<{ deleted: boolean }>(response)
}
export async function clearDownloads(): Promise<number> {
const response = await postJson(`${BASE}/clear`, undefined)
const data = await parseJson<{ deleted: number }>(response)
return data.deleted
}
export async function setDownloadPriority(
id: string,
priority: number
): Promise<void> {
const response = await postJson(`${BASE}/${id}/priority`, { priority })
await parseJson<{ ok: boolean }>(response)
}
export async function checkAvailability(
models: Record<string, string>
): Promise<AvailabilityResponse> {
const response = await postJson(`${BASE}/availability`, { models })
return parseJson<AvailabilityResponse>(response)
}
export async function listCredentials(): Promise<HostCredentialView[]> {
const response = await api.fetchApi(`${BASE}/credentials`)
const data = await parseJson<{ credentials: HostCredentialView[] }>(response)
return data.credentials
}
export async function upsertCredential(
body: HostCredentialUpsert
): Promise<HostCredentialView> {
const response = await postJson(`${BASE}/credentials`, body)
return parseJson<HostCredentialView>(response)
}
export async function deleteCredential(id: string): Promise<void> {
const response = await api.fetchApi(`${BASE}/credentials/${id}`, {
method: 'DELETE'
})
await parseJson<{ deleted: boolean }>(response)
}

View File

@@ -1,253 +0,0 @@
import { render, screen } from '@testing-library/vue'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { defineComponent, ref } from 'vue'
import { createI18n } from 'vue-i18n'
import enMessages from '@/locales/en/main.json' with { type: 'json' }
import { DownloadApiError } from '../types'
import AddModelByUrlDialog from './AddModelByUrlDialog.vue'
const mockEnqueue = vi.fn()
const mockToastAdd = vi.fn()
const mockFetchModelTypes = vi.fn()
const mockModelTypes = ref([
{ name: 'Checkpoints', value: 'checkpoints' },
{ name: 'LoRA', value: 'loras' }
])
vi.mock('../stores/modelDownloadStore', () => ({
useModelDownloadStore: () => ({ enqueue: mockEnqueue })
}))
vi.mock('@/platform/assets/composables/useModelTypes', () => ({
useModelTypes: () => ({
modelTypes: mockModelTypes,
isLoading: ref(false),
fetchModelTypes: mockFetchModelTypes
})
}))
vi.mock('@/platform/updates/common/toastStore', () => ({
useToastStore: () => ({ add: mockToastAdd })
}))
const i18n = createI18n({
legacy: false,
locale: 'en',
messages: { en: enMessages },
missingWarn: false,
fallbackWarn: false
})
const stubs = {
Dialog: { template: '<div><slot /></div>' },
DialogPortal: { template: '<div><slot /></div>' },
DialogOverlay: { template: '<div />' },
DialogContent: defineComponent({
emits: ['open-auto-focus'],
mounted() {
this.$emit('open-auto-focus')
},
template: '<div><slot /></div>'
}),
DialogHeader: { template: '<div><slot /></div>' },
DialogTitle: { template: '<div><slot /></div>' },
DialogDescription: { template: '<div><slot /></div>' },
DialogFooter: { template: '<div><slot /></div>' },
SingleSelect: {
props: ['modelValue', 'options', 'label', 'loading'],
emits: ['update:modelValue'],
template: `
<select
data-testid="folder-select"
:value="modelValue ?? ''"
@change="$emit('update:modelValue', $event.target.value)"
>
<option value="" disabled>{{ label }}</option>
<option v-for="opt in options" :key="opt.value" :value="opt.value">
{{ opt.name }}
</option>
</select>
`
}
}
function mountDialog(open = true) {
return render(AddModelByUrlDialog, {
props: { open },
global: { plugins: [i18n], stubs }
})
}
async function fillValidForm() {
await userEvent.type(
screen.getByLabelText('URL'),
'https://huggingface.co/org/model/resolve/main/model.safetensors'
)
await userEvent.selectOptions(
screen.getByTestId('folder-select'),
'checkpoints'
)
}
describe('AddModelByUrlDialog', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('auto-fills the filename from the url until edited manually', async () => {
mountDialog()
const urlInput = screen.getByLabelText('URL')
const filenameInput = screen.getByLabelText('Filename')
await userEvent.type(
urlInput,
'https://huggingface.co/org/model/resolve/main/model.safetensors'
)
expect(filenameInput).toHaveValue('model.safetensors')
await userEvent.clear(filenameInput)
await userEvent.type(filenameInput, 'custom.safetensors')
await userEvent.type(urlInput, '?download=true')
expect(filenameInput).toHaveValue('custom.safetensors')
})
it('shows a hint for a url on a non-allowlisted host', async () => {
mountDialog()
await userEvent.type(
screen.getByLabelText('URL'),
'https://example.com/model.safetensors'
)
expect(
screen.getByText('This host may not be on the download allowlist.')
).toBeInTheDocument()
})
it('disables submit until url, folder, and a valid filename are set', async () => {
mountDialog()
const submit = screen.getByRole('button', { name: 'Download' })
expect(submit).toBeDisabled()
await userEvent.type(
screen.getByLabelText('URL'),
'https://huggingface.co/org/model/resolve/main/model.safetensors'
)
expect(submit).toBeDisabled()
await userEvent.selectOptions(
screen.getByTestId('folder-select'),
'checkpoints'
)
expect(submit).toBeEnabled()
})
it('requires a known model extension', async () => {
mountDialog()
await userEvent.type(
screen.getByLabelText('URL'),
'https://huggingface.co/org/model/resolve/main/readme.txt'
)
await userEvent.selectOptions(
screen.getByTestId('folder-select'),
'checkpoints'
)
const submit = screen.getByRole('button', { name: 'Download' })
expect(submit).toBeDisabled()
})
it('closes without submitting when cancel is clicked', async () => {
const { emitted } = mountDialog()
await fillValidForm()
await userEvent.click(screen.getByRole('button', { name: 'Cancel' }))
expect(mockEnqueue).not.toHaveBeenCalled()
expect(emitted('update:open')?.at(-1)).toEqual([false])
})
it('enqueues the download and closes on success', async () => {
mockEnqueue.mockResolvedValue({ download_id: 'd1', accepted: true })
const { emitted } = mountDialog()
await fillValidForm()
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
expect(mockEnqueue).toHaveBeenCalledWith({
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
model_id: 'checkpoints/model.safetensors',
allow_any_extension: false
})
await vi.waitFor(() => {
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'success' })
)
})
expect(emitted('update:open')?.at(-1)).toEqual([false])
})
it('shows an info toast and closes when the model is already downloading', async () => {
mockEnqueue.mockRejectedValue(
new DownloadApiError('exists', 'ALREADY_DOWNLOADING', 409)
)
const { emitted } = mountDialog()
await fillValidForm()
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
await vi.waitFor(() => {
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'info' })
)
})
expect(emitted('update:open')?.at(-1)).toEqual([false])
})
it('shows an info toast and closes when the model is already installed', async () => {
mockEnqueue.mockRejectedValue(
new DownloadApiError('already there', 'ALREADY_AVAILABLE', 409)
)
const { emitted } = mountDialog()
await fillValidForm()
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
await vi.waitFor(() => {
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'info' })
)
})
expect(emitted('update:open')?.at(-1)).toEqual([false])
})
it('shows an inline error for other API failures and stays open', async () => {
mockEnqueue.mockRejectedValue(
new DownloadApiError('url not allowed', 'URL_NOT_ALLOWED', 400)
)
const { emitted } = mountDialog()
await fillValidForm()
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
await vi.waitFor(() => {
expect(screen.getByText('url not allowed')).toBeInTheDocument()
})
expect(emitted('update:open')).toBeUndefined()
})
it('shows a generic error message for a non-api error', async () => {
mockEnqueue.mockRejectedValue(new Error('network down'))
mountDialog()
await fillValidForm()
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
await vi.waitFor(() => {
expect(screen.getByText('network down')).toBeInTheDocument()
})
})
})

View File

@@ -1,237 +0,0 @@
<template>
<Dialog v-model:open="isOpen">
<DialogPortal>
<DialogOverlay class="bg-black/70" />
<DialogContent
size="md"
class="flex flex-col gap-4 p-6"
@open-auto-focus="onOpen"
>
<DialogHeader>
<DialogTitle>{{ $t('modelManager.addModel') }}</DialogTitle>
<DialogDescription>
{{ $t('modelManager.addModelDescription') }}
</DialogDescription>
</DialogHeader>
<div class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground" for="download-url">
{{ $t('modelManager.url') }}
</label>
<Input
id="download-url"
v-model="url"
:placeholder="$t('modelManager.urlPlaceholder')"
@update:model-value="onUrlChanged"
/>
<p v-if="hostHint" class="text-xs text-amber-400">{{ hostHint }}</p>
</div>
<div class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground">
{{ $t('modelManager.folder') }}
</label>
<SingleSelect
v-model="directory"
:options="folderOptions"
:label="$t('modelManager.selectFolder')"
:loading="isLoadingFolders"
/>
</div>
<div class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground" for="download-filename">
{{ $t('modelManager.filename') }}
</label>
<Input
id="download-filename"
v-model="filename"
:placeholder="$t('modelManager.filenamePlaceholder')"
@update:model-value="onFilenameEdited"
/>
</div>
<!-- TODO: re-enable once we think we'd want to allow any extension
<label class="flex w-fit items-center gap-2 text-xs text-muted-foreground">
<input v-model="allowAnyExtension" type="checkbox" class="size-4" />
{{ $t('modelManager.allowAnyExtension') }}
</label>
-->
<p
v-if="modelId"
class="truncate rounded-md bg-secondary-background px-2 py-1 text-xs text-muted-foreground"
>
{{ modelId }}
</p>
<p v-if="errorMessage" class="text-xs text-red-400">
{{ errorMessage }}
</p>
<DialogFooter>
<Button variant="secondary" @click="isOpen = false">
{{ $t('g.cancel') }}
</Button>
<Button
variant="primary"
:disabled="!canSubmit"
:loading="isSubmitting"
@click="submit"
>
{{ $t('modelManager.download') }}
</Button>
</DialogFooter>
</DialogContent>
</DialogPortal>
</Dialog>
</template>
<script setup lang="ts">
import { computed, ref } from 'vue'
import { useI18n } from 'vue-i18n'
import Button from '@/components/ui/button/Button.vue'
import Dialog from '@/components/ui/dialog/Dialog.vue'
import DialogContent from '@/components/ui/dialog/DialogContent.vue'
import DialogDescription from '@/components/ui/dialog/DialogDescription.vue'
import DialogFooter from '@/components/ui/dialog/DialogFooter.vue'
import DialogHeader from '@/components/ui/dialog/DialogHeader.vue'
import DialogOverlay from '@/components/ui/dialog/DialogOverlay.vue'
import DialogPortal from '@/components/ui/dialog/DialogPortal.vue'
import DialogTitle from '@/components/ui/dialog/DialogTitle.vue'
import Input from '@/components/ui/input/Input.vue'
import SingleSelect from '@/components/ui/single-select/SingleSelect.vue'
import { useModelTypes } from '@/platform/assets/composables/useModelTypes'
import { useToastStore } from '@/platform/updates/common/toastStore'
import { useModelDownloadStore } from '../stores/modelDownloadStore'
import { DownloadApiError } from '../types'
import {
buildModelId,
filenameFromUrl,
hasModelExtension,
isLikelyAllowedHost,
isValidPathSegment
} from '../utils/modelId'
const isOpen = defineModel<boolean>('open', { required: true })
const { t } = useI18n()
const store = useModelDownloadStore()
const {
modelTypes,
isLoading: isLoadingFolders,
fetchModelTypes
} = useModelTypes()
const url = ref('')
const directory = ref<string | undefined>(undefined)
const filename = ref('')
const isFilenameUserEdited = ref(false)
const allowAnyExtension = ref(false)
const isSubmitting = ref(false)
const errorMessage = ref('')
const folderOptions = computed(() => modelTypes.value)
const modelId = computed(() =>
directory.value && filename.value
? buildModelId(directory.value, filename.value)
: ''
)
const hostHint = computed(() =>
url.value && !isLikelyAllowedHost(url.value)
? t('modelManager.hostNotAllowedHint')
: ''
)
const canSubmit = computed(
() =>
!!url.value &&
!!directory.value &&
isValidPathSegment(filename.value) &&
(allowAnyExtension.value || hasModelExtension(filename.value))
)
function onOpen() {
void fetchModelTypes()
errorMessage.value = ''
}
function onUrlChanged() {
if (!isFilenameUserEdited.value) {
filename.value = filenameFromUrl(url.value)
}
}
function onFilenameEdited() {
isFilenameUserEdited.value = true
}
function reset() {
url.value = ''
directory.value = undefined
filename.value = ''
isFilenameUserEdited.value = false
allowAnyExtension.value = false
errorMessage.value = ''
}
async function submit() {
if (!canSubmit.value) return
isSubmitting.value = true
errorMessage.value = ''
try {
await store.enqueue({
url: url.value,
model_id: modelId.value,
allow_any_extension: allowAnyExtension.value
})
useToastStore().add({
severity: 'success',
summary: t('modelManager.downloadQueued'),
detail: modelId.value,
life: 4000
})
reset()
isOpen.value = false
} catch (error) {
handleEnqueueError(error)
} finally {
isSubmitting.value = false
}
}
function handleEnqueueError(error: unknown) {
if (error instanceof DownloadApiError) {
if (error.is('ALREADY_AVAILABLE')) {
useToastStore().add({
severity: 'info',
summary: t('modelManager.alreadyInstalled'),
detail: modelId.value,
life: 4000
})
reset()
isOpen.value = false
return
}
if (error.is('ALREADY_DOWNLOADING')) {
useToastStore().add({
severity: 'info',
summary: t('modelManager.alreadyDownloading'),
detail: modelId.value,
life: 4000
})
reset()
isOpen.value = false
return
}
errorMessage.value = error.message
return
}
errorMessage.value =
error instanceof Error ? error.message : t('modelManager.actionFailed')
}
</script>

View File

@@ -1,313 +0,0 @@
import { render, screen } from '@testing-library/vue'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { defineComponent, nextTick, reactive, ref } from 'vue'
import { createI18n } from 'vue-i18n'
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
import enMessages from '@/locales/en/main.json' with { type: 'json' }
import type { HostCredentialView } from '../types'
import HostCredentialsDialog from './HostCredentialsDialog.vue'
const mockCloseDialog = vi.fn()
const mockToastAdd = vi.fn()
const mockCredentialsStore = reactive({
credentials: ref<HostCredentialView[]>([]),
isLoading: ref(false),
fetchCredentials: vi.fn(),
upsert: vi.fn(),
remove: vi.fn()
})
vi.mock('../stores/hostCredentialsStore', () => ({
useHostCredentialsStore: () => mockCredentialsStore
}))
vi.mock('@/stores/dialogStore', () => ({
useDialogStore: () => ({ closeDialog: mockCloseDialog })
}))
vi.mock('@/platform/updates/common/toastStore', () => ({
useToastStore: () => ({ add: mockToastAdd })
}))
vi.mock('@/components/dialog/confirm/confirmDialog')
const mockShowConfirmDialog = vi.mocked(showConfirmDialog)
interface CapturedConfirmOptions {
footerProps?: {
onConfirm?: () => void | Promise<void>
onCancel?: () => void
}
}
function capturedOptions(): CapturedConfirmOptions {
return mockShowConfirmDialog.mock
.calls[0][0] as unknown as CapturedConfirmOptions
}
const i18n = createI18n({
legacy: false,
locale: 'en',
messages: { en: enMessages },
missingWarn: false,
fallbackWarn: false
})
const stubs = {
Dialog: { template: '<div><slot /></div>' },
DialogPortal: { template: '<div><slot /></div>' },
DialogOverlay: { template: '<div />' },
DialogContent: defineComponent({
emits: ['open-auto-focus'],
mounted() {
this.$emit('open-auto-focus')
},
template: '<div><slot /></div>'
}),
DialogHeader: { template: '<div><slot /></div>' },
DialogTitle: { template: '<div><slot /></div>' },
DialogDescription: { template: '<div><slot /></div>' },
SingleSelect: {
props: ['modelValue', 'options', 'label'],
emits: ['update:modelValue'],
template: `
<select
data-testid="scheme-select"
:value="modelValue ?? ''"
@change="$emit('update:modelValue', $event.target.value)"
>
<option v-for="opt in options" :key="opt.value" :value="opt.value">
{{ opt.name }}
</option>
</select>
`
}
}
function createCredential(
overrides: Partial<HostCredentialView> = {}
): HostCredentialView {
return {
id: 'c1',
host: 'huggingface.co',
auth_scheme: 'bearer',
header_name: null,
query_param: null,
label: null,
match_subdomains: false,
enabled: true,
secret_last4: '1234',
created_at: 0,
updated_at: 0,
...overrides
}
}
function mountDialog(props: { open?: boolean; prefillHost?: string } = {}) {
return render(HostCredentialsDialog, {
props: { open: true, ...props },
global: { plugins: [i18n], stubs }
})
}
describe('HostCredentialsDialog', () => {
beforeEach(() => {
vi.clearAllMocks()
mockCredentialsStore.credentials = []
mockCredentialsStore.fetchCredentials.mockResolvedValue(undefined)
mockShowConfirmDialog.mockReturnValue(
{} as ReturnType<typeof showConfirmDialog>
)
})
it('renders existing credentials with host, scheme, and last 4 of the secret', () => {
mockCredentialsStore.credentials = [
createCredential({
host: 'huggingface.co',
auth_scheme: 'bearer',
secret_last4: '9f21'
})
]
mountDialog()
expect(
screen.getByText('huggingface.co · Bearer token · ••••9f21')
).toBeInTheDocument()
})
it('shows a disabled marker for a disabled credential', () => {
mockCredentialsStore.credentials = [createCredential({ enabled: false })]
mountDialog()
expect(screen.getByText(/Disabled/)).toBeInTheDocument()
})
it('prefills the host from the prefillHost prop on open', async () => {
mountDialog({ prefillHost: 'civitai.com' })
await nextTick()
expect(screen.getByLabelText('Host')).toHaveValue('civitai.com')
})
it('populates the form when editing an existing credential', async () => {
mockCredentialsStore.credentials = [
createCredential({
id: 'c1',
host: 'civitai.com',
auth_scheme: 'header',
header_name: 'X-Api-Key',
label: 'My Civitai key'
})
]
mountDialog()
await userEvent.click(screen.getByTitle('Edit'))
expect(screen.getByLabelText('Host')).toHaveValue('civitai.com')
expect(screen.getByLabelText('Label')).toHaveValue('My Civitai key')
expect(screen.getByText('Update credential')).toBeInTheDocument()
})
it('disables submit until host and secret are filled', async () => {
mountDialog()
const submit = screen.getByRole('button', { name: 'Save' })
expect(submit).toBeDisabled()
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
expect(submit).toBeDisabled()
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
expect(submit).toBeEnabled()
})
it('requires a query parameter name when the scheme is query', async () => {
mountDialog()
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
await userEvent.selectOptions(screen.getByTestId('scheme-select'), 'query')
const submit = screen.getByRole('button', { name: 'Save' })
expect(submit).toBeDisabled()
await userEvent.type(screen.getByLabelText('Query parameter'), 'token')
expect(submit).toBeEnabled()
})
it('shows the header name field and a subdomain warning for the header scheme', async () => {
mountDialog()
await userEvent.selectOptions(screen.getByTestId('scheme-select'), 'header')
expect(screen.getByLabelText('Header name')).toBeInTheDocument()
await userEvent.click(screen.getByLabelText('Match subdomains'))
expect(
screen.getByText(
'Not recommended: hubs redirect to sibling CDN hosts that must not receive your key.'
)
).toBeInTheDocument()
})
it('cancels an in-progress edit and resets the form', async () => {
mockCredentialsStore.credentials = [
createCredential({ id: 'c1', host: 'civitai.com' })
]
mountDialog()
await userEvent.click(screen.getByTitle('Edit'))
expect(screen.getByText('Update credential')).toBeInTheDocument()
await userEvent.click(screen.getByRole('button', { name: 'Cancel' }))
expect(screen.getByText('Add credential')).toBeInTheDocument()
expect(screen.getByLabelText('Host')).toHaveValue('')
})
it('submits the form payload and resets on success', async () => {
mockCredentialsStore.upsert.mockResolvedValue(createCredential())
mountDialog()
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
await userEvent.click(screen.getByRole('button', { name: 'Save' }))
expect(mockCredentialsStore.upsert).toHaveBeenCalledWith({
host: 'huggingface.co',
secret: 's3cret',
auth_scheme: 'bearer',
header_name: null,
query_param: null,
label: null,
match_subdomains: false
})
await vi.waitFor(() => {
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'success' })
)
})
expect(screen.getByLabelText('Host')).toHaveValue('')
})
it('shows an inline error message when saving fails', async () => {
mockCredentialsStore.upsert.mockRejectedValue(new Error('save failed'))
mountDialog()
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
await userEvent.click(screen.getByRole('button', { name: 'Save' }))
await vi.waitFor(() => {
expect(screen.getByText('save failed')).toBeInTheDocument()
})
})
it('shows an inline error when the initial credentials fetch fails', async () => {
mockCredentialsStore.fetchCredentials.mockRejectedValue(
new Error('offline')
)
mountDialog()
await vi.waitFor(() => {
expect(screen.getByText('offline')).toBeInTheDocument()
})
})
it('deletes a credential through a confirm dialog', async () => {
mockCredentialsStore.remove.mockResolvedValue(undefined)
mockCredentialsStore.credentials = [createCredential({ id: 'c1' })]
mountDialog()
await userEvent.click(screen.getByTitle('Delete'))
await capturedOptions().footerProps?.onConfirm?.()
expect(mockCredentialsStore.remove).toHaveBeenCalledWith('c1')
expect(mockCloseDialog).toHaveBeenCalled()
})
it('shows an error toast when deletion fails', async () => {
mockCredentialsStore.remove.mockRejectedValue(new Error('boom'))
mockCredentialsStore.credentials = [createCredential({ id: 'c1' })]
mountDialog()
await userEvent.click(screen.getByTitle('Delete'))
await capturedOptions().footerProps?.onConfirm?.()
await vi.waitFor(() => {
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'error' })
)
})
})
it('does not delete when the confirm dialog is dismissed', async () => {
mockCredentialsStore.credentials = [createCredential({ id: 'c1' })]
mountDialog()
await userEvent.click(screen.getByTitle('Delete'))
capturedOptions().footerProps?.onCancel?.()
expect(mockCredentialsStore.remove).not.toHaveBeenCalled()
})
})

View File

@@ -1,333 +0,0 @@
<template>
<Dialog v-model:open="isOpen">
<DialogPortal>
<DialogOverlay class="bg-black/70" />
<DialogContent
size="md"
class="flex flex-col gap-4 p-6"
@open-auto-focus="onOpen"
>
<DialogHeader>
<DialogTitle>{{ $t('modelManager.credentials.title') }}</DialogTitle>
<DialogDescription>
{{ $t('modelManager.credentials.description') }}
</DialogDescription>
</DialogHeader>
<div v-if="credentials.length" class="flex flex-col gap-2">
<div
v-for="credential in credentials"
:key="credential.id"
class="flex items-center gap-2 rounded-lg border border-border-default bg-secondary-background px-3 py-2"
>
<div class="flex min-w-0 flex-1 flex-col">
<span class="truncate text-sm font-medium text-base-foreground">
{{ credential.label || credential.host }}
</span>
<span class="truncate text-xs text-muted-foreground">
{{ credential.host }} ·
{{
$t(
`modelManager.credentials.scheme.${credential.auth_scheme}`
)
}}
<template v-if="credential.secret_last4">
· {{ credential.secret_last4 }}
</template>
<template v-if="!credential.enabled">
· {{ $t('modelManager.credentials.disabled') }}
</template>
</span>
</div>
<Button
variant="textonly"
size="icon"
:title="$t('modelManager.credentials.edit')"
@click="editCredential(credential)"
>
<i class="icon-[lucide--pencil] size-4" />
</Button>
<Button
variant="textonly"
size="icon"
:title="$t('g.delete')"
@click="confirmDelete(credential)"
>
<i class="icon-[lucide--trash-2] size-4 text-red-400" />
</Button>
</div>
</div>
<div class="flex flex-col gap-3 border-t border-border-default pt-3">
<span class="text-sm font-medium text-base-foreground">
{{
form.id
? $t('modelManager.credentials.update')
: $t('modelManager.credentials.add')
}}
</span>
<div class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground" for="cred-host">
{{ $t('modelManager.credentials.host') }}
</label>
<Input
id="cred-host"
v-model="form.host"
:placeholder="HOST_EXAMPLE"
/>
</div>
<div class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground" for="cred-secret">
{{ $t('modelManager.credentials.secret') }}
</label>
<Input
id="cred-secret"
v-model="form.secret"
type="password"
:placeholder="$t('modelManager.credentials.secretPlaceholder')"
/>
</div>
<div class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground">
{{ $t('modelManager.credentials.authScheme') }}
</label>
<SingleSelect v-model="form.auth_scheme" :options="schemeOptions" />
</div>
<div v-if="form.auth_scheme === 'header'" class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground" for="cred-header">
{{ $t('modelManager.credentials.headerName') }}
</label>
<Input
id="cred-header"
v-model="form.header_name"
:placeholder="HEADER_NAME_EXAMPLE"
/>
</div>
<div v-if="form.auth_scheme === 'query'" class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground" for="cred-query">
{{ $t('modelManager.credentials.queryParam') }}
</label>
<Input
id="cred-query"
v-model="form.query_param"
:placeholder="QUERY_PARAM_EXAMPLE"
/>
</div>
<div class="flex flex-col gap-1">
<label class="text-xs text-muted-foreground" for="cred-label">
{{ $t('modelManager.credentials.label') }}
</label>
<Input id="cred-label" v-model="form.label" />
</div>
<label class="flex items-center gap-2 text-xs text-muted-foreground">
<input
v-model="form.match_subdomains"
type="checkbox"
class="size-4"
/>
{{ $t('modelManager.credentials.matchSubdomains') }}
</label>
<p v-if="form.match_subdomains" class="text-xs text-amber-400">
{{ $t('modelManager.credentials.matchSubdomainsWarning') }}
</p>
<p v-if="errorMessage" class="text-xs text-red-400">
{{ errorMessage }}
</p>
<div class="flex justify-end gap-2">
<Button v-if="form.id" variant="secondary" @click="resetForm">
{{ $t('g.cancel') }}
</Button>
<Button
variant="primary"
:disabled="!canSubmit"
:loading="isSubmitting"
@click="submit"
>
{{ $t('modelManager.credentials.save') }}
</Button>
</div>
</div>
</DialogContent>
</DialogPortal>
</Dialog>
</template>
<script setup lang="ts">
import { computed, reactive, ref } from 'vue'
import { useI18n } from 'vue-i18n'
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
import Button from '@/components/ui/button/Button.vue'
import Dialog from '@/components/ui/dialog/Dialog.vue'
import DialogContent from '@/components/ui/dialog/DialogContent.vue'
import DialogDescription from '@/components/ui/dialog/DialogDescription.vue'
import DialogHeader from '@/components/ui/dialog/DialogHeader.vue'
import DialogOverlay from '@/components/ui/dialog/DialogOverlay.vue'
import DialogPortal from '@/components/ui/dialog/DialogPortal.vue'
import DialogTitle from '@/components/ui/dialog/DialogTitle.vue'
import Input from '@/components/ui/input/Input.vue'
import SingleSelect from '@/components/ui/single-select/SingleSelect.vue'
import { useToastStore } from '@/platform/updates/common/toastStore'
import { storeToRefs } from 'pinia'
import { useDialogStore } from '@/stores/dialogStore'
import { useHostCredentialsStore } from '../stores/hostCredentialsStore'
import { AUTH_SCHEMES } from '../types'
import type { AuthScheme, HostCredentialView } from '../types'
const HOST_EXAMPLE = 'huggingface.co'
const HEADER_NAME_EXAMPLE = 'Authorization'
const QUERY_PARAM_EXAMPLE = 'token'
const { prefillHost = '' } = defineProps<{ prefillHost?: string }>()
const isOpen = defineModel<boolean>('open', { required: true })
const { t } = useI18n()
const store = useHostCredentialsStore()
const { credentials } = storeToRefs(store)
const dialogStore = useDialogStore()
const isSubmitting = ref(false)
const errorMessage = ref('')
interface CredentialForm {
id: string | null
host: string
secret: string
auth_scheme: AuthScheme
header_name: string
query_param: string
label: string
match_subdomains: boolean
}
function emptyForm(): CredentialForm {
return {
id: null,
host: '',
secret: '',
auth_scheme: 'bearer',
header_name: '',
query_param: '',
label: '',
match_subdomains: false
}
}
const form = reactive<CredentialForm>(emptyForm())
const schemeOptions = computed(() =>
AUTH_SCHEMES.map((scheme) => ({
value: scheme,
name: t(`modelManager.credentials.scheme.${scheme}`)
}))
)
const canSubmit = computed(
() =>
!!form.host.trim() &&
!!form.secret &&
(form.auth_scheme !== 'query' || !!form.query_param.trim()) &&
(form.auth_scheme !== 'header' || !!form.header_name.trim())
)
async function onOpen() {
resetForm()
if (prefillHost) {
form.host = prefillHost
}
try {
await store.fetchCredentials()
} catch (error) {
errorMessage.value =
error instanceof Error ? error.message : t('modelManager.actionFailed')
}
}
function resetForm() {
Object.assign(form, emptyForm())
errorMessage.value = ''
}
function editCredential(credential: HostCredentialView) {
Object.assign(form, {
id: credential.id,
host: credential.host,
secret: '',
auth_scheme: credential.auth_scheme,
header_name: credential.header_name ?? '',
query_param: credential.query_param ?? '',
label: credential.label ?? '',
match_subdomains: credential.match_subdomains
})
errorMessage.value = ''
}
async function submit() {
if (!canSubmit.value) return
isSubmitting.value = true
errorMessage.value = ''
try {
await store.upsert({
host: form.host,
secret: form.secret,
auth_scheme: form.auth_scheme,
header_name: form.auth_scheme === 'header' ? form.header_name : null,
query_param: form.auth_scheme === 'query' ? form.query_param : null,
label: form.label || null,
match_subdomains: form.match_subdomains
})
useToastStore().add({
severity: 'success',
summary: t('modelManager.credentials.saved'),
detail: form.host,
life: 4000
})
resetForm()
} catch (error) {
errorMessage.value =
error instanceof Error ? error.message : t('modelManager.actionFailed')
} finally {
isSubmitting.value = false
}
}
function confirmDelete(credential: HostCredentialView) {
const dialog = showConfirmDialog({
headerProps: { title: t('modelManager.credentials.deleteTitle') },
props: {
promptText: t('modelManager.credentials.deleteMessage', {
host: credential.host
})
},
footerProps: {
confirmText: t('g.delete'),
confirmVariant: 'destructive' as const,
onCancel: () => dialogStore.closeDialog(dialog),
onConfirm: async () => {
dialogStore.closeDialog(dialog)
try {
await store.remove(credential.id)
} catch (error) {
useToastStore().add({
severity: 'error',
summary: t('modelManager.actionFailed'),
detail: error instanceof Error ? error.message : String(error),
life: 5000
})
}
}
}
})
}
</script>

View File

@@ -1,336 +0,0 @@
import { render, screen } from '@testing-library/vue'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createI18n } from 'vue-i18n'
import enMessages from '@/locales/en/main.json' with { type: 'json' }
import type { DownloadStatus } from '../types'
import ModelDownloadRow from './ModelDownloadRow.vue'
const mockPause = vi.fn()
const mockResume = vi.fn()
const mockCancel = vi.fn()
const mockRaisePriority = vi.fn()
const mockRemove = vi.fn()
vi.mock('../composables/useModelDownloadActions', () => ({
useModelDownloadActions: () => ({
pause: mockPause,
resume: mockResume,
cancel: mockCancel,
raisePriority: mockRaisePriority,
remove: mockRemove,
toastError: vi.fn()
})
}))
const i18n = createI18n({
legacy: false,
locale: 'en',
messages: { en: enMessages },
missingWarn: false,
fallbackWarn: false
})
function createDownload(
overrides: Partial<DownloadStatus> = {}
): DownloadStatus {
return {
download_id: 'd1',
model_id: 'loras/x.safetensors',
url: 'https://huggingface.co/org/x.safetensors',
status: 'active',
priority: 0,
total_bytes: 2048,
bytes_done: 1024,
progress: 0.5,
speed_bps: 512,
eta_seconds: 125,
segments: null,
error: null,
created_at: 0,
updated_at: 0,
...overrides
}
}
function mountRow(
download: DownloadStatus,
onOpenCredentials?: (host: string) => void
) {
return render(ModelDownloadRow, {
props: {
download,
...(onOpenCredentials ? { onOpenCredentials } : {})
},
global: { plugins: [i18n] }
})
}
describe('ModelDownloadRow', () => {
beforeEach(() => {
vi.clearAllMocks()
})
it('splits the model id into directory and filename', () => {
mountRow(createDownload({ model_id: 'loras/x.safetensors' }))
expect(screen.getByText('x.safetensors')).toBeInTheDocument()
expect(screen.getByText('loras')).toBeInTheDocument()
})
it('renders an empty directory when the model id has no folder', () => {
mountRow(createDownload({ model_id: 'x.safetensors' }))
expect(screen.getByText('x.safetensors')).toBeInTheDocument()
})
it('formats the meta line with percent, size, speed, and eta while active', () => {
mountRow(createDownload({ status: 'active' }))
expect(
screen.getByText('50% · 1 KB / 2 KB · 512 B/s · 2:05')
).toBeInTheDocument()
})
it('renders an empty meta line when no progress metrics are known yet', () => {
mountRow(
createDownload({
status: 'queued',
progress: null,
total_bytes: null,
bytes_done: 0,
speed_bps: null,
eta_seconds: null
})
)
expect(screen.getByTestId('meta-line')).toBeEmptyDOMElement()
})
it('omits the eta once a download is no longer active', () => {
mountRow(
createDownload({
status: 'paused',
progress: 0.5,
eta_seconds: 125
})
)
expect(screen.getByText('50% · 1 KB / 2 KB · 512 B/s')).toBeInTheDocument()
})
describe('progress bar visibility', () => {
it('shows a progress bar for an active download with known progress', () => {
mountRow(createDownload({ status: 'active', progress: 0.5 }))
expect(screen.getByTestId('progress-bar')).toBeInTheDocument()
})
it('hides the progress bar for a cancelled download', () => {
mountRow(createDownload({ status: 'cancelled', progress: 0.5 }))
expect(screen.queryByTestId('progress-bar')).not.toBeInTheDocument()
expect(screen.getByText('Cancelled')).toBeInTheDocument()
})
})
describe('action buttons by status', () => {
it('shows pause and cancel for a queued download, plus raise priority', () => {
mountRow(createDownload({ status: 'queued' }))
expect(screen.getByTitle('Raise priority')).toBeInTheDocument()
expect(screen.getByTitle('Pause')).toBeInTheDocument()
expect(screen.getByTitle('Cancel')).toBeInTheDocument()
expect(screen.queryByTitle('Resume')).not.toBeInTheDocument()
expect(screen.queryByTitle('Remove from list')).not.toBeInTheDocument()
})
it('shows only resume for a failed download without an auth error', () => {
mountRow(createDownload({ status: 'failed', error: 'disk full' }))
expect(screen.getByTitle('Resume')).toBeInTheDocument()
expect(screen.queryByTitle('Pause')).not.toBeInTheDocument()
expect(screen.queryByTitle('Cancel')).not.toBeInTheDocument()
expect(screen.queryByTitle('Remove from list')).not.toBeInTheDocument()
expect(screen.queryByTitle('Add credentials')).not.toBeInTheDocument()
})
it('shows the remove action for terminal downloads', () => {
mountRow(createDownload({ status: 'completed' }))
expect(screen.getByTitle('Remove from list')).toBeInTheDocument()
expect(screen.queryByTitle('Cancel')).not.toBeInTheDocument()
expect(screen.queryByTitle('Resume')).not.toBeInTheDocument()
})
})
describe('auth errors', () => {
it('shows the credentials button and a host-specific hint', async () => {
const onOpenCredentials = vi.fn()
mountRow(
createDownload({
status: 'failed',
url: 'https://huggingface.co/org/x.safetensors',
error: '401 Unauthorized'
}),
onOpenCredentials
)
expect(
screen.getByText(
'huggingface.co needs an API key. Add one in the Credentials Manager, then resume.'
)
).toBeInTheDocument()
await userEvent.click(screen.getByTitle('Add credentials'))
expect(onOpenCredentials).toHaveBeenCalledWith('huggingface.co')
})
it('falls back to a hostless hint when the url cannot be parsed', () => {
mountRow(
createDownload({
status: 'failed',
url: 'not-a-url',
error: '403 forbidden'
})
)
expect(
screen.getByText(
'This host/model needs an API key. Add one in the Credentials Manager, then resume.'
)
).toBeInTheDocument()
})
it('shows the raw error message for a non-auth failure', () => {
mountRow(createDownload({ status: 'failed', error: 'disk full' }))
expect(screen.getByText('disk full')).toBeInTheDocument()
expect(screen.queryByTitle('Add credentials')).not.toBeInTheDocument()
})
it('hides a leftover error while the download is not yet terminal', () => {
mountRow(createDownload({ status: 'active', error: 'disk full' }))
expect(screen.queryByText('disk full')).not.toBeInTheDocument()
})
})
describe('gated models', () => {
const gatedError =
'https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/vae/diffusion_pytorch_model.safetensors is a gated model — Access to model black-forest-labs/FLUX.2-dev is restricted. You must have access to it and be authenticated to access it.'
it('shows the gated hint, credentials button, and a link to accept the license on the model page', () => {
mountRow(
createDownload({
status: 'failed',
url: 'https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/vae/diffusion_pytorch_model.safetensors',
error: gatedError
})
)
expect(
screen.getByText(
"This model is gated. Accept its license on the model's page, then add an API key and resume."
)
).toBeInTheDocument()
expect(screen.getByTitle('Add credentials')).toBeInTheDocument()
const link = screen.getByRole('link', { name: 'Accept license' })
expect(link).toHaveAttribute(
'href',
'https://huggingface.co/black-forest-labs/FLUX.2-dev'
)
})
it('derives the model page from the error when the download url is a cdn link', () => {
mountRow(
createDownload({
status: 'failed',
url: 'https://cas-bridge.xethub.hf.co/xet-bridge-us/abc/def',
error: gatedError
})
)
expect(
screen.getByRole('link', { name: 'Accept license' })
).toHaveAttribute(
'href',
'https://huggingface.co/black-forest-labs/FLUX.2-dev'
)
})
it('hides the raw backend error text for gated failures', () => {
mountRow(
createDownload({
status: 'failed',
url: 'https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/model.safetensors',
error: gatedError
})
)
expect(screen.queryByText(gatedError)).not.toBeInTheDocument()
})
it('omits the accept-license link when no huggingface url is present', () => {
mountRow(
createDownload({
status: 'failed',
url: 'https://example.com/model.safetensors',
error: 'Access to this model is restricted, request access.'
})
)
expect(
screen.queryByRole('link', { name: 'Accept license' })
).not.toBeInTheDocument()
})
})
describe('user actions', () => {
it('pauses on click', async () => {
const download = createDownload({ status: 'active' })
mountRow(download)
await userEvent.click(screen.getByTitle('Pause'))
expect(mockPause).toHaveBeenCalledWith(download)
})
it('resumes on click', async () => {
const download = createDownload({ status: 'paused' })
mountRow(download)
await userEvent.click(screen.getByTitle('Resume'))
expect(mockResume).toHaveBeenCalledWith(download)
})
it('cancels on click', async () => {
const download = createDownload({ status: 'active' })
mountRow(download)
await userEvent.click(screen.getByTitle('Cancel'))
expect(mockCancel).toHaveBeenCalledWith(download)
})
it('raises priority by 1 on click', async () => {
const download = createDownload({ status: 'queued', priority: 2 })
mountRow(download)
await userEvent.click(screen.getByTitle('Raise priority'))
expect(mockRaisePriority).toHaveBeenCalledWith(download, 1)
})
it('removes the download on click', async () => {
const download = createDownload({
download_id: 'd1',
status: 'completed'
})
mountRow(download)
await userEvent.click(screen.getByTitle('Remove from list'))
expect(mockRemove).toHaveBeenCalledWith(download)
})
})
})

View File

@@ -1,245 +0,0 @@
<template>
<div
:class="
cn(
'relative flex flex-col gap-1 overflow-hidden rounded-lg border border-border-default bg-secondary-background px-3 py-2',
isCancelled && 'opacity-60'
)
"
>
<div
v-if="showProgressBar"
:class="progressBarContainerClass"
data-testid="progress-bar"
>
<div :class="progressBarPrimaryClass" :style="barStyle" />
</div>
<div class="relative flex items-center gap-2">
<div class="flex min-w-0 flex-1 flex-col">
<span class="truncate text-sm font-medium text-base-foreground">
{{ filename }}
</span>
<span class="truncate text-xs text-muted-foreground">
{{ directory }}
</span>
</div>
<div class="flex shrink-0 items-center gap-0.5">
<template v-if="canRaisePriority">
<Button
variant="textonly"
size="icon"
:title="$t('modelManager.raisePriority')"
@click="actions.raisePriority(download, 1)"
>
<i class="icon-[lucide--chevron-up] size-4" />
</Button>
</template>
<Button
v-if="canPause"
variant="textonly"
size="icon"
:title="$t('g.pause')"
@click="actions.pause(download)"
>
<i class="icon-[lucide--pause] size-4" />
</Button>
<Button
v-if="isAuthError"
variant="textonly"
size="icon"
:title="$t('modelManager.addCredentials')"
@click="emit('openCredentials', host)"
>
<i class="icon-[lucide--key-round] size-4" />
</Button>
<Button
v-if="canResume"
variant="textonly"
size="icon"
:title="$t('modelManager.resume')"
@click="actions.resume(download)"
>
<i class="icon-[lucide--play] size-4" />
</Button>
<Button
v-if="canCancel"
variant="textonly"
size="icon"
:title="$t('g.cancel')"
@click="actions.cancel(download)"
>
<i class="icon-[lucide--x] size-4 text-red-400" />
</Button>
<Button
v-if="isTerminal"
variant="textonly"
size="icon"
:title="$t('modelManager.removeFromList')"
@click="actions.remove(download)"
>
<i class="icon-[lucide--x] size-4" />
</Button>
</div>
</div>
<div
class="relative flex items-center justify-between gap-2 text-xs text-muted-foreground"
>
<span class="flex items-center gap-1">
<i v-if="isCancelled" class="icon-[lucide--ban] size-3.5" />
{{ statusLabel }}
</span>
<span class="truncate" data-testid="meta-line">{{ metaLine }}</span>
</div>
<p
v-if="isGatedModel"
class="relative text-xs wrap-break-word text-amber-400"
>
{{ $t('modelManager.gatedModelHint') }}
<a
v-if="modelPageUrl"
:href="modelPageUrl"
target="_blank"
rel="noopener noreferrer"
class="underline"
>
{{ $t('modelManager.openModelPage') }}
</a>
</p>
<p
v-else-if="isAuthError"
class="relative text-xs wrap-break-word text-amber-400"
>
{{
host
? $t('modelManager.authErrorHint', { host })
: $t('modelManager.authErrorHintNoHost')
}}
</p>
<p
v-else-if="isFailed && download.error"
class="relative text-xs wrap-break-word text-red-400"
>
{{ download.error }}
</p>
</div>
</template>
<script setup lang="ts">
import { cn } from '@comfyorg/tailwind-utils'
import { computed } from 'vue'
import { useI18n } from 'vue-i18n'
import Button from '@/components/ui/button/Button.vue'
import { useProgressBarBackground } from '@/composables/useProgressBarBackground'
import { formatSize } from '@/utils/formatUtil'
import { useModelDownloadActions } from '../composables/useModelDownloadActions'
import { downloadProgressFraction } from '../stores/modelDownloadStore'
import type { DownloadStatus } from '../types'
import { directoryOf, filenameOf } from '../utils/modelId'
const { download } = defineProps<{ download: DownloadStatus }>()
const emit = defineEmits<{ openCredentials: [host: string] }>()
const { t } = useI18n()
const actions = useModelDownloadActions()
const {
progressBarContainerClass,
progressBarPrimaryClass,
progressPercentStyle
} = useProgressBarBackground()
const directory = computed(() => directoryOf(download.model_id))
const filename = computed(() => filenameOf(download.model_id))
const host = computed(() => {
try {
return new URL(download.url).hostname
} catch {
return ''
}
})
const percent = computed(() => {
const fraction = downloadProgressFraction(download)
return fraction == null ? undefined : Math.round(fraction * 100)
})
const barStyle = computed(() => progressPercentStyle(percent.value))
const isTerminal = computed(() =>
['completed', 'cancelled'].includes(download.status)
)
const isCancelled = computed(() => download.status === 'cancelled')
const showProgressBar = computed(
() =>
download.status !== 'failed' &&
!isCancelled.value &&
percent.value !== undefined
)
const canPause = computed(() => ['queued', 'active'].includes(download.status))
const canResume = computed(() => ['paused', 'failed'].includes(download.status))
const canCancel = computed(
() => !isTerminal.value && download.status !== 'failed'
)
const canRaisePriority = computed(() => download.status === 'queued')
const GATED_MODEL_PATTERN = /gated|restricted|request access|must have access/i
const AUTH_ERROR_PATTERN =
/api key|credential|token|unauthor|forbidden|\b401\b|\b403\b/i
const isFailed = computed(() => download.status === 'failed')
const isGatedModel = computed(
() => isFailed.value && !!download.error?.match(GATED_MODEL_PATTERN)
)
const isAuthError = computed(
() =>
isFailed.value &&
(isGatedModel.value || !!download.error?.match(AUTH_ERROR_PATTERN))
)
const HF_MODEL_URL_PATTERN =
/https?:\/\/huggingface\.co\/([^/\s?#]+)\/([^/\s?#]+)/i
const modelPageUrl = computed(() => {
const match = `${download.url} ${download.error ?? ''}`.match(
HF_MODEL_URL_PATTERN
)
if (!match) return ''
const [, owner, repo] = match
return `https://huggingface.co/${owner}/${repo}`
})
const statusLabel = computed(() => t(`modelManager.status.${download.status}`))
const metaLine = computed(() => {
const parts: string[] = []
if (percent.value !== undefined) parts.push(`${percent.value}%`)
if (download.total_bytes != null) {
parts.push(
`${formatSize(download.bytes_done)} / ${formatSize(download.total_bytes)}`
)
}
if (download.speed_bps) parts.push(`${formatSize(download.speed_bps)}/s`)
if (download.eta_seconds != null && download.status === 'active') {
parts.push(formatEta(download.eta_seconds))
}
return parts.join(' · ')
})
function formatEta(seconds: number): string {
const total = Math.max(0, Math.round(seconds))
const hours = Math.floor(total / 3600)
const minutes = Math.floor((total % 3600) / 60)
const secs = total % 60
const paddedSecs = secs.toString().padStart(2, '0')
if (hours > 0) {
return `${hours}:${minutes.toString().padStart(2, '0')}:${paddedSecs}`
}
return `${minutes}:${paddedSecs}`
}
</script>

View File

@@ -1,178 +0,0 @@
import { render, screen } from '@testing-library/vue'
import userEvent from '@testing-library/user-event'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { reactive, ref } from 'vue'
import { createI18n } from 'vue-i18n'
import enMessages from '@/locales/en/main.json' with { type: 'json' }
import type { DownloadStatus } from '../types'
import ModelManagerSidebarTab from './ModelManagerSidebarTab.vue'
const mockHydrate = vi.fn()
const mockClearHistory = vi.fn()
const mockStore = reactive({
downloadList: ref<DownloadStatus[]>([]),
activeDownloads: ref<DownloadStatus[]>([]),
historyDownloads: ref<DownloadStatus[]>([]),
hydrate: mockHydrate
})
vi.mock('../stores/modelDownloadStore', () => ({
useModelDownloadStore: () => mockStore
}))
vi.mock('../composables/useModelDownloadActions', () => ({
useModelDownloadActions: () => ({ clearHistory: mockClearHistory })
}))
vi.mock('./ModelDownloadRow.vue', () => ({
default: {
name: 'ModelDownloadRow',
props: ['download'],
emits: ['openCredentials'],
template:
'<div><span>{{ download.download_id }}</span>' +
'<button @click="$emit(\'openCredentials\', download.model_id)">open-credentials</button></div>'
}
}))
vi.mock('./AddModelByUrlDialog.vue', () => ({
default: {
name: 'AddModelByUrlDialog',
props: ['open'],
template: '<div data-testid="add-model-dialog">{{ open }}</div>'
}
}))
vi.mock('./HostCredentialsDialog.vue', () => ({
default: {
name: 'HostCredentialsDialog',
props: ['open', 'prefillHost'],
template:
'<div data-testid="credentials-dialog">{{ open }}:{{ prefillHost }}</div>'
}
}))
vi.mock('@/components/sidebar/tabs/SidebarTabTemplate.vue', () => ({
default: {
name: 'SidebarTabTemplate',
template: '<div><slot name="tool-buttons" /><slot name="body" /></div>'
}
}))
const i18n = createI18n({
legacy: false,
locale: 'en',
messages: { en: enMessages },
missingWarn: false,
fallbackWarn: false
})
function createDownload(
overrides: Partial<DownloadStatus> = {}
): DownloadStatus {
return {
download_id: 'd1',
model_id: 'loras/x.safetensors',
url: 'https://huggingface.co/org/x.safetensors',
status: 'active',
priority: 0,
total_bytes: null,
bytes_done: 0,
progress: null,
speed_bps: null,
eta_seconds: null,
segments: null,
error: null,
created_at: 0,
updated_at: 0,
...overrides
}
}
function mountTab() {
return render(ModelManagerSidebarTab, { global: { plugins: [i18n] } })
}
describe('ModelManagerSidebarTab', () => {
beforeEach(() => {
vi.clearAllMocks()
mockStore.downloadList = []
mockStore.activeDownloads = []
mockStore.historyDownloads = []
})
it('hydrates the store on mount', () => {
mountTab()
expect(mockHydrate).toHaveBeenCalled()
})
it('shows the empty state when there are no downloads', () => {
mountTab()
expect(screen.getByText('No downloads yet')).toBeInTheDocument()
})
it('renders active downloads under the Active section', () => {
const download = createDownload({ download_id: 'd1' })
mockStore.activeDownloads = [download]
mockStore.downloadList = [download]
mountTab()
expect(screen.getByText('Active')).toBeInTheDocument()
expect(screen.getByText('d1')).toBeInTheDocument()
expect(screen.queryByText('No downloads yet')).not.toBeInTheDocument()
})
it('renders history downloads under the History section with a clear action', async () => {
const download = createDownload({ download_id: 'd2' })
mockStore.historyDownloads = [download]
mockStore.downloadList = [download]
mountTab()
expect(screen.getByText('History')).toBeInTheDocument()
expect(screen.getByText('d2')).toBeInTheDocument()
await userEvent.click(screen.getByText('Clear history'))
expect(mockClearHistory).toHaveBeenCalled()
})
it('opens the add-model dialog from the toolbar button', async () => {
mountTab()
expect(screen.getByTestId('add-model-dialog')).toHaveTextContent('false')
await userEvent.click(screen.getByTitle('Add model'))
expect(screen.getByTestId('add-model-dialog')).toHaveTextContent('true')
})
it('opens the add-model dialog from the empty state button', async () => {
mountTab()
await userEvent.click(screen.getByText('Add model'))
expect(screen.getByTestId('add-model-dialog')).toHaveTextContent('true')
})
it('opens the credentials dialog with an empty prefill from the toolbar button', async () => {
mountTab()
await userEvent.click(screen.getByTitle('Credentials Manager'))
expect(screen.getByTestId('credentials-dialog')).toHaveTextContent('true:')
})
it('opens the credentials dialog prefilled with the row host', async () => {
const download = createDownload({
download_id: 'd1',
model_id: 'loras/x.safetensors'
})
mockStore.activeDownloads = [download]
mockStore.downloadList = [download]
mountTab()
await userEvent.click(screen.getByText('open-credentials'))
expect(screen.getByTestId('credentials-dialog')).toHaveTextContent(
'true:loras/x.safetensors'
)
})
})

View File

@@ -1,103 +0,0 @@
<template>
<SidebarTabTemplate :title="$t('modelManager.title')">
<template #tool-buttons>
<Button
variant="textonly"
size="icon"
:title="$t('modelManager.credentials.title')"
@click="openCredentials('')"
>
<i class="icon-[lucide--key-round] size-4" />
</Button>
<Button
variant="textonly"
size="icon"
:title="$t('modelManager.addModel')"
@click="addOpen = true"
>
<i class="icon-[lucide--plus] size-4" />
</Button>
</template>
<template #body>
<div class="flex flex-col gap-4 p-3">
<section v-if="activeDownloads.length" class="flex flex-col gap-2">
<h3 class="text-xs font-semibold text-muted-foreground uppercase">
{{ $t('modelManager.active') }}
</h3>
<ModelDownloadRow
v-for="download in activeDownloads"
:key="download.download_id"
:download
@open-credentials="openCredentials"
/>
</section>
<section v-if="historyDownloads.length" class="flex flex-col gap-2">
<div class="flex items-center justify-between">
<h3 class="text-xs font-semibold text-muted-foreground uppercase">
{{ $t('modelManager.history') }}
</h3>
<Button variant="link" size="sm" @click="actions.clearHistory()">
{{ $t('modelManager.clearHistory') }}
</Button>
</div>
<ModelDownloadRow
v-for="download in historyDownloads"
:key="download.download_id"
:download
@open-credentials="openCredentials"
/>
</section>
<div
v-if="!store.downloadList.length"
class="flex flex-col items-center gap-3 py-10 text-center text-muted-foreground"
>
<i class="icon-[lucide--download] size-8" />
<p class="text-sm">{{ $t('modelManager.empty') }}</p>
<Button variant="primary" size="sm" @click="addOpen = true">
{{ $t('modelManager.addModel') }}
</Button>
</div>
</div>
</template>
</SidebarTabTemplate>
<AddModelByUrlDialog v-model:open="addOpen" />
<HostCredentialsDialog
v-model:open="credentialsOpen"
:prefill-host="prefillHost"
/>
</template>
<script setup lang="ts">
import { storeToRefs } from 'pinia'
import { onMounted, ref } from 'vue'
import SidebarTabTemplate from '@/components/sidebar/tabs/SidebarTabTemplate.vue'
import Button from '@/components/ui/button/Button.vue'
import AddModelByUrlDialog from './AddModelByUrlDialog.vue'
import HostCredentialsDialog from './HostCredentialsDialog.vue'
import ModelDownloadRow from './ModelDownloadRow.vue'
import { useModelDownloadActions } from '../composables/useModelDownloadActions'
import { useModelDownloadStore } from '../stores/modelDownloadStore'
const store = useModelDownloadStore()
const actions = useModelDownloadActions()
const { activeDownloads, historyDownloads } = storeToRefs(store)
const addOpen = ref(false)
const credentialsOpen = ref(false)
const prefillHost = ref('')
function openCredentials(host: string) {
prefillHost.value = host
credentialsOpen.value = true
}
onMounted(() => {
void store.hydrate()
})
</script>

View File

@@ -1,51 +0,0 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import * as downloadApi from '../api/modelDownloadApi'
import { useModelAvailability } from './useModelAvailability'
vi.mock('../api/modelDownloadApi', () => ({
checkAvailability: vi.fn()
}))
describe('useModelAvailability', () => {
beforeEach(() => {
vi.resetAllMocks()
})
it('short-circuits without calling the API for an empty map', async () => {
const { check, results } = useModelAvailability()
await check({})
expect(downloadApi.checkAvailability).not.toHaveBeenCalled()
expect(results.value).toEqual({})
})
it('stores the availability map from the response', async () => {
vi.mocked(downloadApi.checkAvailability).mockResolvedValue({
models: {
'loras/x.safetensors': { state: 'available', url_allowed: true }
}
})
const { check, results, isChecking } = useModelAvailability()
await check({ 'loras/x.safetensors': 'https://h.co/x' })
expect(results.value['loras/x.safetensors']).toEqual({
state: 'available',
url_allowed: true
})
expect(isChecking.value).toBe(false)
})
it('captures errors and resets the checking flag', async () => {
vi.mocked(downloadApi.checkAvailability).mockRejectedValue(
new Error('boom')
)
const { check, error, isChecking } = useModelAvailability()
await expect(check({ 'loras/x.safetensors': 'u' })).rejects.toThrow('boom')
expect(error.value).toBeInstanceOf(Error)
expect(isChecking.value).toBe(false)
})
})

View File

@@ -1,41 +0,0 @@
import { ref, shallowRef } from 'vue'
import { checkAvailability } from '../api/modelDownloadApi'
import type { AvailabilityEntry } from '../types'
/**
* Bulk-checks whether the models declared by a workflow are present,
* downloading, or missing. Pass a `model_id -> source URL` map and badge
* each model from the returned entries.
*/
export function useModelAvailability() {
const results = shallowRef<Record<string, AvailabilityEntry>>({})
const isChecking = ref(false)
const error = ref<unknown>(null)
async function check(models: Record<string, string>) {
if (Object.keys(models).length === 0) {
results.value = {}
return results.value
}
isChecking.value = true
error.value = null
try {
const response = await checkAvailability(models)
results.value = response.models
return response.models
} catch (e) {
error.value = e
throw e
} finally {
isChecking.value = false
}
}
return {
results,
isChecking,
error,
check
}
}

View File

@@ -1,168 +0,0 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
import type { DownloadStatus } from '../types'
import { DownloadApiError } from '../types'
import { useModelDownloadActions } from './useModelDownloadActions'
const mockStore = {
pause: vi.fn(),
resume: vi.fn(),
cancel: vi.fn(),
setPriority: vi.fn(),
hydrate: vi.fn()
}
const mockToastAdd = vi.fn()
const mockCloseDialog = vi.fn()
vi.mock('vue-i18n', () => ({
useI18n: () => ({ t: (key: string) => key })
}))
vi.mock('../stores/modelDownloadStore', () => ({
useModelDownloadStore: () => mockStore
}))
vi.mock('@/platform/updates/common/toastStore', () => ({
useToastStore: () => ({ add: mockToastAdd })
}))
vi.mock('@/stores/dialogStore', () => ({
useDialogStore: () => ({ closeDialog: mockCloseDialog })
}))
vi.mock('@/components/dialog/confirm/confirmDialog')
const mockShowConfirmDialog = vi.mocked(showConfirmDialog)
interface CapturedConfirmOptions {
footerProps?: {
confirmVariant?: string
onConfirm?: () => void | Promise<void>
onCancel?: () => void
}
}
function capturedOptions(): CapturedConfirmOptions {
return mockShowConfirmDialog.mock
.calls[0][0] as unknown as CapturedConfirmOptions
}
function createDownload(
overrides: Partial<DownloadStatus> = {}
): DownloadStatus {
return {
download_id: 'd1',
model_id: 'loras/x.safetensors',
url: 'https://huggingface.co/x.safetensors',
status: 'active',
priority: 0,
total_bytes: null,
bytes_done: 0,
progress: null,
speed_bps: null,
eta_seconds: null,
segments: null,
error: null,
created_at: 0,
updated_at: 0,
...overrides
}
}
describe('useModelDownloadActions', () => {
beforeEach(() => {
vi.clearAllMocks()
mockShowConfirmDialog.mockReturnValue(
{} as ReturnType<typeof showConfirmDialog>
)
})
it('pauses a download by id', async () => {
mockStore.pause.mockResolvedValue(undefined)
const { pause } = useModelDownloadActions()
await pause(createDownload({ download_id: 'd1' }))
expect(mockStore.pause).toHaveBeenCalledWith('d1')
})
it('resumes a download by id', async () => {
mockStore.resume.mockResolvedValue(undefined)
const { resume } = useModelDownloadActions()
await resume(createDownload({ download_id: 'd1' }))
expect(mockStore.resume).toHaveBeenCalledWith('d1')
})
it('raises priority relative to the current value', async () => {
mockStore.setPriority.mockResolvedValue(undefined)
const { raisePriority } = useModelDownloadActions()
await raisePriority(createDownload({ download_id: 'd1', priority: 2 }), 1)
expect(mockStore.setPriority).toHaveBeenCalledWith('d1', 3)
})
it('shows an error toast and re-hydrates the store when an action fails', async () => {
mockStore.pause.mockRejectedValue(new Error('offline'))
mockStore.hydrate.mockResolvedValue(undefined)
const { pause } = useModelDownloadActions()
await pause(createDownload())
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ severity: 'error', detail: 'offline' })
)
expect(mockStore.hydrate).toHaveBeenCalled()
})
it('uses the DownloadApiError message in the toast', async () => {
mockStore.pause.mockRejectedValue(
new DownloadApiError('nope', 'URL_NOT_ALLOWED', 400)
)
mockStore.hydrate.mockResolvedValue(undefined)
const { pause } = useModelDownloadActions()
await pause(createDownload())
expect(mockToastAdd).toHaveBeenCalledWith(
expect.objectContaining({ detail: 'nope' })
)
})
it('swallows a failed re-hydrate after an action error', async () => {
mockStore.pause.mockRejectedValue(new Error('offline'))
mockStore.hydrate.mockRejectedValue(new Error('still offline'))
const { pause } = useModelDownloadActions()
await expect(pause(createDownload())).resolves.toBeUndefined()
})
describe('cancel', () => {
it('opens a destructive confirm dialog and only cancels on confirm', async () => {
mockStore.cancel.mockResolvedValue(undefined)
const { cancel } = useModelDownloadActions()
cancel(createDownload({ download_id: 'd1' }))
expect(capturedOptions().footerProps?.confirmVariant).toBe('destructive')
await capturedOptions().footerProps?.onConfirm?.()
expect(mockStore.cancel).toHaveBeenCalledWith('d1')
expect(mockCloseDialog).toHaveBeenCalled()
})
it('does not cancel when the confirm dialog is dismissed', () => {
const { cancel } = useModelDownloadActions()
cancel(createDownload({ download_id: 'd1' }))
capturedOptions().footerProps?.onCancel?.()
expect(mockStore.cancel).not.toHaveBeenCalled()
expect(mockCloseDialog).toHaveBeenCalled()
})
})
})

View File

@@ -1,97 +0,0 @@
import { useI18n } from 'vue-i18n'
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
import { useToastStore } from '@/platform/updates/common/toastStore'
import { useDialogStore } from '@/stores/dialogStore'
import { useModelDownloadStore } from '../stores/modelDownloadStore'
import type { DownloadStatus } from '../types'
import { DownloadApiError } from '../types'
/**
* Wraps download store mutations with user feedback: optimistic-friendly error
* toasts and a confirmation prompt for the destructive cancel action (which
* deletes the partial file).
*/
export function useModelDownloadActions() {
const store = useModelDownloadStore()
const dialogStore = useDialogStore()
const { t } = useI18n()
function toastError(error: unknown) {
const detail =
error instanceof DownloadApiError || error instanceof Error
? error.message
: String(error)
useToastStore().add({
severity: 'error',
summary: t('modelManager.actionFailed'),
detail,
life: 5000
})
}
async function run(action: () => Promise<void>) {
try {
await action()
} catch (error) {
toastError(error)
// The store applies optimistic status/priority patches before the API
// call; re-fetch the authoritative state so a failed mutation doesn't
// leave the row stuck in the wrong local state.
try {
await store.hydrate()
} catch {
// Server unreachable; the next poll will reconcile.
}
}
}
const pause = (download: DownloadStatus) =>
run(() => store.pause(download.download_id))
const resume = (download: DownloadStatus) =>
run(() => store.resume(download.download_id))
const raisePriority = (download: DownloadStatus, delta: number) =>
run(() =>
store.setPriority(download.download_id, download.priority + delta)
)
const remove = (download: DownloadStatus) =>
run(() => store.remove(download.download_id))
const clearHistory = () => run(() => store.clearHistory())
function cancel(download: DownloadStatus) {
const dialog = showConfirmDialog({
headerProps: { title: t('modelManager.cancelConfirmTitle') },
props: {
promptText: t(
'modelManager.cancelConfirmMessage',
{ name: download.model_id },
{ escapeParameter: false }
)
},
footerProps: {
confirmText: t('modelManager.cancelConfirm'),
confirmVariant: 'destructive' as const,
onCancel: () => dialogStore.closeDialog(dialog),
onConfirm: async () => {
dialogStore.closeDialog(dialog)
await run(() => store.cancel(download.download_id))
}
}
})
}
return {
pause,
resume,
cancel,
raisePriority,
remove,
clearHistory,
toastError
}
}

View File

@@ -1,93 +0,0 @@
import { ref } from 'vue'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { CompletedDownload } from '../stores/modelDownloadStore'
import { useModelDownloadEffects } from './useModelDownloadEffects'
const mockLastCompletedDownload = ref<CompletedDownload | null>(null)
const mockRefreshModelFolder = vi.fn()
const mockRefreshMissingModels = vi.fn()
vi.mock('../stores/modelDownloadStore', () => ({
useModelDownloadStore: () => ({
get lastCompletedDownload() {
return mockLastCompletedDownload.value
}
})
}))
vi.mock('@/stores/modelStore', () => ({
useModelStore: () => ({ refreshModelFolder: mockRefreshModelFolder })
}))
vi.mock('@/platform/missingModel/missingModelStore', () => ({
useMissingModelStore: () => ({
refreshMissingModels: mockRefreshMissingModels
})
}))
describe('useModelDownloadEffects', () => {
beforeEach(() => {
vi.clearAllMocks()
mockLastCompletedDownload.value = null
mockRefreshModelFolder.mockResolvedValue(undefined)
})
it('does nothing until a download completes', () => {
useModelDownloadEffects()
expect(mockRefreshModelFolder).not.toHaveBeenCalled()
expect(mockRefreshMissingModels).not.toHaveBeenCalled()
})
it('refreshes the model folder and re-scans missing models on completion', async () => {
useModelDownloadEffects()
mockLastCompletedDownload.value = {
downloadId: 'd1',
modelId: 'loras/x.safetensors',
directory: 'loras',
timestamp: 1
}
await vi.waitFor(() => {
expect(mockRefreshModelFolder).toHaveBeenCalledWith('loras')
})
expect(mockRefreshMissingModels).toHaveBeenCalled()
})
it('skips the folder refresh when the directory is unknown', async () => {
useModelDownloadEffects()
mockLastCompletedDownload.value = {
downloadId: 'd1',
modelId: 'unknown.safetensors',
directory: '',
timestamp: 1
}
await vi.waitFor(() => {
expect(mockRefreshMissingModels).toHaveBeenCalled()
})
expect(mockRefreshModelFolder).not.toHaveBeenCalled()
})
it('still re-scans missing models when the folder refresh fails', async () => {
const consoleWarn = vi.spyOn(console, 'warn').mockImplementation(() => {})
mockRefreshModelFolder.mockRejectedValue(new Error('boom'))
useModelDownloadEffects()
mockLastCompletedDownload.value = {
downloadId: 'd1',
modelId: 'loras/x.safetensors',
directory: 'loras',
timestamp: 1
}
await vi.waitFor(() => {
expect(mockRefreshMissingModels).toHaveBeenCalled()
})
expect(consoleWarn).toHaveBeenCalled()
consoleWarn.mockRestore()
})
})

View File

@@ -1,45 +0,0 @@
import { watch } from 'vue'
import { useMissingModelStore } from '@/platform/missingModel/missingModelStore'
import { useModelStore } from '@/stores/modelStore'
import { useModelDownloadStore } from '../stores/modelDownloadStore'
/**
* Side effects that run when a server-side download finishes: refresh the
* affected model folder so the file is recognized, and re-scan missing models
* so node errors for the now-present model clear automatically.
*
* Mounted once at the app root so it runs regardless of whether the Model
* Manager panel is open.
*/
export function useModelDownloadEffects() {
const store = useModelDownloadStore()
watch(
() => store.lastCompletedDownload,
async (completed) => {
if (!completed) return
if (completed.directory) {
try {
await useModelStore().refreshModelFolder(completed.directory)
} catch (error) {
console.warn(
'[ModelManager] Failed to refresh model folder after download',
error
)
}
}
try {
await useMissingModelStore().refreshMissingModels()
} catch (error) {
console.warn(
'[ModelManager] Failed to re-scan missing models after download',
error
)
}
}
)
}

View File

@@ -1,44 +0,0 @@
import { describe, expect, it, vi } from 'vitest'
const mockActiveDownloadCount = { value: 0 }
vi.mock('../stores/modelDownloadStore', () => ({
useModelDownloadStore: () => ({
get activeDownloadCount() {
return mockActiveDownloadCount.value
}
})
}))
vi.mock('../components/ModelManagerSidebarTab.vue', () => ({
default: { name: 'ModelManagerSidebarTab' }
}))
import { useModelManagerSidebarTab } from './useModelManagerSidebarTab'
describe('useModelManagerSidebarTab', () => {
it('returns the expected sidebar tab extension shape', () => {
const tab = useModelManagerSidebarTab()
expect(tab.id).toBe('model-manager')
expect(tab.type).toBe('vue')
expect(tab.title).toBe('modelManager.title')
expect(tab.tooltip).toBe('modelManager.title')
expect(tab.label).toBe('modelManager.title')
})
it('shows no badge when there are no active downloads', () => {
mockActiveDownloadCount.value = 0
const tab = useModelManagerSidebarTab()
expect(typeof tab.iconBadge).toBe('function')
expect((tab.iconBadge as () => string | null)()).toBeNull()
})
it('shows the active download count as a badge', () => {
mockActiveDownloadCount.value = 3
const tab = useModelManagerSidebarTab()
expect((tab.iconBadge as () => string | null)()).toBe('3')
})
})

View File

@@ -1,22 +0,0 @@
import { markRaw } from 'vue'
import type { SidebarTabExtension } from '@/types/extensionTypes'
import ModelManagerSidebarTab from '../components/ModelManagerSidebarTab.vue'
import { useModelDownloadStore } from '../stores/modelDownloadStore'
export function useModelManagerSidebarTab(): SidebarTabExtension {
return {
id: 'model-manager',
icon: 'icon-[lucide--download]',
title: 'modelManager.title',
tooltip: 'modelManager.title',
label: 'modelManager.title',
component: markRaw(ModelManagerSidebarTab),
type: 'vue',
iconBadge: () => {
const count = useModelDownloadStore().activeDownloadCount
return count > 0 ? count.toString() : null
}
}
}

View File

@@ -1,191 +0,0 @@
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import * as downloadApi from '../api/modelDownloadApi'
import type { HostCredentialUpsert, HostCredentialView } from '../types'
import { useHostCredentialsStore } from './hostCredentialsStore'
vi.mock('../api/modelDownloadApi', () => ({
listCredentials: vi.fn(),
upsertCredential: vi.fn(),
deleteCredential: vi.fn()
}))
function createCredential(
overrides: Partial<HostCredentialView> = {}
): HostCredentialView {
return {
id: 'c1',
host: 'huggingface.co',
auth_scheme: 'bearer',
header_name: null,
query_param: null,
label: null,
match_subdomains: false,
enabled: true,
secret_last4: '1234',
created_at: 0,
updated_at: 0,
...overrides
}
}
describe('useHostCredentialsStore', () => {
beforeEach(() => {
setActivePinia(createTestingPinia({ stubActions: false }))
vi.resetAllMocks()
})
describe('fetchCredentials', () => {
it('toggles isLoading and stores the result', async () => {
const credential = createCredential()
vi.mocked(downloadApi.listCredentials).mockResolvedValue([credential])
const store = useHostCredentialsStore()
const promise = store.fetchCredentials()
expect(store.isLoading).toBe(true)
await promise
expect(store.isLoading).toBe(false)
expect(store.credentials).toEqual([credential])
})
it('clears isLoading even when the request fails', async () => {
vi.mocked(downloadApi.listCredentials).mockRejectedValue(
new Error('boom')
)
const store = useHostCredentialsStore()
await expect(store.fetchCredentials()).rejects.toThrow('boom')
expect(store.isLoading).toBe(false)
})
})
describe('upsert', () => {
it('normalizes the host and appends a new credential', async () => {
const created = createCredential({ id: 'new', host: 'huggingface.co' })
vi.mocked(downloadApi.upsertCredential).mockResolvedValue(created)
const store = useHostCredentialsStore()
const body: HostCredentialUpsert = {
host: ' HuggingFace.co ',
secret: 's3cret'
}
const result = await store.upsert(body)
expect(downloadApi.upsertCredential).toHaveBeenCalledWith(
expect.objectContaining({ host: 'huggingface.co', secret: 's3cret' })
)
expect(result).toEqual(created)
expect(store.credentials).toEqual([created])
})
it('replaces an existing credential by id', async () => {
const original = createCredential({ id: 'c1', label: 'Old' })
const updated = createCredential({ id: 'c1', label: 'New' })
vi.mocked(downloadApi.upsertCredential).mockResolvedValue(updated)
const store = useHostCredentialsStore()
store.credentials.push(original)
await store.upsert({ host: 'huggingface.co', secret: 's' })
expect(store.credentials).toEqual([updated])
})
})
describe('remove', () => {
it('deletes and filters out the credential by id', async () => {
vi.mocked(downloadApi.deleteCredential).mockResolvedValue(undefined)
const store = useHostCredentialsStore()
store.credentials.push(
createCredential({ id: 'c1' }),
createCredential({ id: 'c2' })
)
await store.remove('c1')
expect(downloadApi.deleteCredential).toHaveBeenCalledWith('c1')
expect(store.credentials.map((c) => c.id)).toEqual(['c2'])
})
})
describe('enabledCredentialForHost', () => {
it('matches an exact enabled host case-insensitively', () => {
const store = useHostCredentialsStore()
store.credentials.push(
createCredential({ host: 'HuggingFace.co', enabled: true })
)
expect(store.enabledCredentialForHost('huggingface.co')?.host).toBe(
'HuggingFace.co'
)
})
it('returns undefined for a disabled exact match', () => {
const store = useHostCredentialsStore()
store.credentials.push(
createCredential({ host: 'huggingface.co', enabled: false })
)
expect(store.enabledCredentialForHost('huggingface.co')).toBeUndefined()
})
it('falls back to a subdomain match only when match_subdomains is set', () => {
const store = useHostCredentialsStore()
store.credentials.push(
createCredential({
host: 'huggingface.co',
match_subdomains: true,
enabled: true
})
)
expect(store.enabledCredentialForHost('cdn.huggingface.co')?.host).toBe(
'huggingface.co'
)
})
it('skips a disabled subdomain credential and returns a later enabled match', () => {
const store = useHostCredentialsStore()
store.credentials.push(
createCredential({
id: 'disabled',
host: 'huggingface.co',
match_subdomains: true,
enabled: false
}),
createCredential({
id: 'enabled',
host: 'huggingface.co',
match_subdomains: true,
enabled: true
})
)
expect(store.enabledCredentialForHost('cdn.huggingface.co')?.id).toBe(
'enabled'
)
})
it('does not subdomain-match when match_subdomains is false', () => {
const store = useHostCredentialsStore()
store.credentials.push(
createCredential({
host: 'huggingface.co',
match_subdomains: false,
enabled: true
})
)
expect(
store.enabledCredentialForHost('cdn.huggingface.co')
).toBeUndefined()
})
it('returns undefined when no host matches', () => {
const store = useHostCredentialsStore()
expect(store.enabledCredentialForHost('example.com')).toBeUndefined()
})
})
})

View File

@@ -1,75 +0,0 @@
import { defineStore } from 'pinia'
import { computed, ref } from 'vue'
import {
deleteCredential,
listCredentials,
upsertCredential
} from '../api/modelDownloadApi'
import type { HostCredentialUpsert, HostCredentialView } from '../types'
function normalizeHost(host: string): string {
return host.trim().toLowerCase()
}
export const useHostCredentialsStore = defineStore('hostCredentials', () => {
const credentials = ref<HostCredentialView[]>([])
const isLoading = ref(false)
const credentialsByHost = computed(
() => new Map(credentials.value.map((c) => [normalizeHost(c.host), c]))
)
async function fetchCredentials() {
isLoading.value = true
try {
credentials.value = await listCredentials()
} finally {
isLoading.value = false
}
}
async function upsert(
body: HostCredentialUpsert
): Promise<HostCredentialView> {
const view = await upsertCredential({
...body,
host: normalizeHost(body.host)
})
const index = credentials.value.findIndex((c) => c.id === view.id)
credentials.value =
index === -1
? [...credentials.value, view]
: credentials.value.map((c) => (c.id === view.id ? view : c))
return view
}
async function remove(id: string) {
await deleteCredential(id)
credentials.value = credentials.value.filter((c) => c.id !== id)
}
function enabledCredentialForHost(
host: string
): HostCredentialView | undefined {
const normalized = normalizeHost(host)
const exact = credentialsByHost.value.get(normalized)
if (exact) return exact.enabled ? exact : undefined
return credentials.value.find(
(c) =>
c.enabled &&
c.match_subdomains &&
normalized.endsWith(`.${normalizeHost(c.host)}`)
)
}
return {
credentials,
isLoading,
fetchCredentials,
upsert,
remove,
enabledCredentialForHost
}
})

View File

@@ -1,388 +0,0 @@
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
const flushPromises = () => vi.advanceTimersByTimeAsync(0)
import * as downloadApi from '../api/modelDownloadApi'
import type { DownloadState, DownloadStatus } from '../types'
import {
downloadProgressFraction,
useModelDownloadStore
} from './modelDownloadStore'
type ProgressHandler = (e: CustomEvent<DownloadStatus>) => void
const eventHandler = vi.hoisted(() => {
const state: { current: ProgressHandler | null } = { current: null }
return state
})
vi.mock('@/scripts/api', () => ({
api: {
addEventListener: vi.fn((_event: string, handler: ProgressHandler) => {
eventHandler.current = handler
})
}
}))
vi.mock('../api/modelDownloadApi', () => ({
enqueueDownload: vi.fn(),
listDownloads: vi.fn(),
pauseDownload: vi.fn(),
resumeDownload: vi.fn(),
cancelDownload: vi.fn(),
setDownloadPriority: vi.fn(),
deleteDownload: vi.fn(),
clearDownloads: vi.fn()
}))
function createStatus(overrides: Partial<DownloadStatus> = {}): DownloadStatus {
return {
download_id: 'd1',
model_id: 'loras/x.safetensors',
url: 'https://huggingface.co/x.safetensors',
status: 'active',
priority: 0,
total_bytes: 1000,
bytes_done: 250,
progress: 0.25,
speed_bps: 100,
eta_seconds: 10,
segments: null,
error: null,
created_at: 1,
updated_at: 2,
...overrides
}
}
function dispatch(status: DownloadStatus) {
if (!eventHandler.current) throw new Error('handler not registered')
eventHandler.current(new CustomEvent('download_progress', { detail: status }))
}
describe('useModelDownloadStore', () => {
beforeEach(() => {
setActivePinia(createTestingPinia({ stubActions: false }))
vi.useFakeTimers({ shouldAdvanceTime: false })
vi.resetAllMocks()
eventHandler.current = null
})
afterEach(() => {
vi.useRealTimers()
})
it('upserts rows from download_progress events', () => {
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', bytes_done: 100 }))
dispatch(createStatus({ download_id: 'd1', bytes_done: 900 }))
dispatch(createStatus({ download_id: 'd2', status: 'queued' }))
expect(store.downloadList).toHaveLength(2)
expect(
store.downloadList.find((d) => d.download_id === 'd1')?.bytes_done
).toBe(900)
})
it('splits active and terminal downloads', () => {
const store = useModelDownloadStore()
const states: DownloadState[] = [
'queued',
'active',
'paused',
'verifying',
'completed',
'failed',
'cancelled'
]
states.forEach((status, idx) =>
dispatch(createStatus({ download_id: `d${idx}`, status }))
)
expect(store.activeDownloads.map((d) => d.status)).toEqual([
'queued',
'active',
'paused',
'verifying'
])
expect(store.historyDownloads.map((d) => d.status)).toEqual([
'completed',
'failed',
'cancelled'
])
expect(store.activeDownloadCount).toBe(4)
})
it('inserts an optimistic queued row on enqueue', async () => {
vi.mocked(downloadApi.enqueueDownload).mockResolvedValue({
download_id: 'new-id',
accepted: true
})
const store = useModelDownloadStore()
const result = await store.enqueue({
url: 'https://huggingface.co/x.safetensors',
model_id: 'loras/x.safetensors'
})
expect(result.download_id).toBe('new-id')
const row = store.downloadList.find((d) => d.download_id === 'new-id')
expect(row?.status).toBe('queued')
expect(row?.model_id).toBe('loras/x.safetensors')
})
it('optimistically updates status when pausing', async () => {
vi.mocked(downloadApi.pauseDownload).mockResolvedValue()
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
await store.pause('d1')
expect(store.downloadList.find((d) => d.download_id === 'd1')?.status).toBe(
'paused'
)
expect(downloadApi.pauseDownload).toHaveBeenCalledWith('d1')
})
it('optimistically updates status when resuming', async () => {
vi.mocked(downloadApi.resumeDownload).mockResolvedValue()
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'paused' }))
await store.resume('d1')
expect(store.downloadList.find((d) => d.download_id === 'd1')?.status).toBe(
'queued'
)
expect(downloadApi.resumeDownload).toHaveBeenCalledWith('d1')
})
it('marks a download cancelled after the API call resolves', async () => {
vi.mocked(downloadApi.cancelDownload).mockResolvedValue()
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
await store.cancel('d1')
expect(store.downloadList.find((d) => d.download_id === 'd1')?.status).toBe(
'cancelled'
)
expect(downloadApi.cancelDownload).toHaveBeenCalledWith('d1')
})
it('optimistically updates priority and calls the API', async () => {
vi.mocked(downloadApi.setDownloadPriority).mockResolvedValue()
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'queued', priority: 0 }))
await store.setPriority('d1', 5)
expect(
store.downloadList.find((d) => d.download_id === 'd1')?.priority
).toBe(5)
expect(downloadApi.setDownloadPriority).toHaveBeenCalledWith('d1', 5)
})
it('is a no-op when patching priority for an unknown id', async () => {
vi.mocked(downloadApi.setDownloadPriority).mockResolvedValue()
const store = useModelDownloadStore()
await store.setPriority('missing', 5)
expect(store.downloadList).toHaveLength(0)
expect(downloadApi.setDownloadPriority).toHaveBeenCalledWith('missing', 5)
})
it('finds a download by model id', () => {
const store = useModelDownloadStore()
dispatch(
createStatus({ download_id: 'd1', model_id: 'loras/x.safetensors' })
)
expect(store.findByModelId('loras/x.safetensors')?.download_id).toBe('d1')
expect(store.findByModelId('loras/missing.safetensors')).toBeUndefined()
})
describe('hydrate', () => {
it('replaces the download map with the fetched list', async () => {
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'stale', status: 'active' }))
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
createStatus({ download_id: 'fresh', status: 'active' })
])
await store.hydrate()
expect(store.downloadList.map((d) => d.download_id)).toEqual(['fresh'])
})
it('records a completion when a refreshed row transitions to completed', async () => {
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
createStatus({
download_id: 'd1',
model_id: 'loras/x.safetensors',
status: 'completed'
})
])
await store.hydrate()
expect(store.lastCompletedDownload).toMatchObject({
downloadId: 'd1',
modelId: 'loras/x.safetensors',
directory: 'loras'
})
})
it('does not re-record a completion for a row that was already completed', async () => {
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'completed' }))
const firstTimestamp = store.lastCompletedDownload?.timestamp
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
createStatus({ download_id: 'd1', status: 'completed' })
])
await store.hydrate()
expect(store.lastCompletedDownload?.timestamp).toBe(firstTimestamp)
})
})
it('records the last completed download once on the completing transition', () => {
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
expect(store.lastCompletedDownload).toBeNull()
dispatch(
createStatus({
download_id: 'd1',
model_id: 'loras/x.safetensors',
status: 'completed'
})
)
expect(store.lastCompletedDownload).toMatchObject({
downloadId: 'd1',
modelId: 'loras/x.safetensors',
directory: 'loras'
})
const firstTimestamp = store.lastCompletedDownload?.timestamp
dispatch(createStatus({ download_id: 'd1', status: 'completed' }))
expect(store.lastCompletedDownload?.timestamp).toBe(firstTimestamp)
})
it('refetches when a download fails without an error message', async () => {
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
createStatus({ download_id: 'd1', status: 'failed', error: 'gated' })
])
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
dispatch(createStatus({ download_id: 'd1', status: 'failed', error: null }))
await flushPromises()
expect(downloadApi.listDownloads).toHaveBeenCalledOnce()
expect(store.downloadList.find((d) => d.download_id === 'd1')?.error).toBe(
'gated'
)
})
it('does not refetch when the failure event already carries an error', async () => {
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
dispatch(
createStatus({ download_id: 'd1', status: 'failed', error: 'disk full' })
)
await flushPromises()
expect(downloadApi.listDownloads).not.toHaveBeenCalled()
expect(store.downloadList.find((d) => d.download_id === 'd1')?.error).toBe(
'disk full'
)
})
it('deletes a row through the backend so it stays gone', async () => {
vi.mocked(downloadApi.deleteDownload).mockResolvedValue()
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
await store.remove('d2')
expect(downloadApi.deleteDownload).toHaveBeenCalledWith('d2')
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d1'])
})
it('clears every history row in one backend call, leaving active downloads', async () => {
vi.mocked(downloadApi.clearDownloads).mockResolvedValue(2)
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
dispatch(createStatus({ download_id: 'd3', status: 'failed' }))
await store.clearHistory()
expect(downloadApi.clearDownloads).toHaveBeenCalledOnce()
expect(downloadApi.deleteDownload).not.toHaveBeenCalled()
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d1'])
})
it('keeps history rows locally when the bulk clear fails', async () => {
vi.mocked(downloadApi.clearDownloads).mockRejectedValue(new Error('boom'))
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
await expect(store.clearHistory()).rejects.toThrow('boom')
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d2'])
})
it('keeps the row locally when the backend delete fails', async () => {
vi.mocked(downloadApi.deleteDownload).mockRejectedValue(new Error('boom'))
const store = useModelDownloadStore()
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
await expect(store.remove('d2')).rejects.toThrow('boom')
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d2'])
})
it('polls the list when active downloads go stale', async () => {
vi.mocked(downloadApi.listDownloads).mockResolvedValue([])
useModelDownloadStore()
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
await vi.advanceTimersByTimeAsync(15_000)
expect(downloadApi.listDownloads).toHaveBeenCalled()
})
describe('downloadProgressFraction', () => {
it('uses live progress when present', () => {
expect(downloadProgressFraction(createStatus({ progress: 0.4 }))).toBe(
0.4
)
})
it('derives from bytes when progress is null', () => {
expect(
downloadProgressFraction(
createStatus({ progress: null, bytes_done: 500, total_bytes: 1000 })
)
).toBe(0.5)
})
it('returns null when total is unknown', () => {
expect(
downloadProgressFraction(
createStatus({ progress: null, total_bytes: null })
)
).toBeNull()
})
})
})

View File

@@ -1,229 +0,0 @@
import { useIntervalFn } from '@vueuse/core'
import { defineStore } from 'pinia'
import { computed, ref } from 'vue'
import { api } from '@/scripts/api'
import {
cancelDownload,
clearDownloads,
deleteDownload,
enqueueDownload,
listDownloads,
pauseDownload,
resumeDownload,
setDownloadPriority
} from '../api/modelDownloadApi'
import type {
DownloadState,
DownloadStatus,
EnqueueRequest,
EnqueueResponse
} from '../types'
import { directoryOf } from '../utils/modelId'
const ACTIVE_STATES: ReadonlySet<DownloadState> = new Set([
'queued',
'active',
'paused',
'verifying'
])
const POLL_INTERVAL_MS = 5_000
const STALE_THRESHOLD_MS = 8_000
/**
* Static completion fraction (0..1) for a download. Live values (`progress`)
* are only populated while a worker is running; for paused/terminal rows we
* derive it from `bytes_done / total_bytes`.
*/
export function downloadProgressFraction(
download: DownloadStatus
): number | null {
if (download.progress != null) return download.progress
if (download.total_bytes && download.total_bytes > 0) {
return download.bytes_done / download.total_bytes
}
return null
}
function optimisticRow(
downloadId: string,
request: EnqueueRequest
): DownloadStatus {
const now = Math.floor(Date.now() / 1000)
return {
download_id: downloadId,
model_id: request.model_id,
url: request.url,
status: 'queued',
priority: request.priority ?? 0,
total_bytes: null,
bytes_done: 0,
progress: null,
speed_bps: null,
eta_seconds: null,
segments: null,
error: null,
created_at: now,
updated_at: now
}
}
export interface CompletedDownload {
downloadId: string
modelId: string
directory: string
timestamp: number
}
export const useModelDownloadStore = defineStore('modelDownload', () => {
const downloads = ref<Map<string, DownloadStatus>>(new Map())
const lastWsUpdate = ref(0)
const lastCompletedDownload = ref<CompletedDownload | null>(null)
const downloadList = computed(() => Array.from(downloads.value.values()))
const activeDownloads = computed(() =>
downloadList.value.filter((d) => ACTIVE_STATES.has(d.status))
)
const historyDownloads = computed(() =>
downloadList.value.filter((d) => !ACTIVE_STATES.has(d.status))
)
const hasActiveDownloads = computed(() => activeDownloads.value.length > 0)
const activeDownloadCount = computed(() => activeDownloads.value.length)
function upsert(status: DownloadStatus) {
downloads.value.set(status.download_id, status)
}
function findByModelId(modelId: string): DownloadStatus | undefined {
return downloadList.value.find((d) => d.model_id === modelId)
}
function handleProgress(e: CustomEvent<DownloadStatus>) {
lastWsUpdate.value = Date.now()
const previous = downloads.value.get(e.detail.download_id)
upsert(e.detail)
if (e.detail.status === 'completed' && previous?.status !== 'completed') {
lastCompletedDownload.value = {
downloadId: e.detail.download_id,
modelId: e.detail.model_id,
directory: directoryOf(e.detail.model_id),
timestamp: Date.now()
}
}
// The live failure event carries the terminal status but not the error
// text, and polling stops once nothing is active. Refetch so the failure
// reason surfaces without the user reopening the panel.
if (
e.detail.status === 'failed' &&
previous?.status !== 'failed' &&
!e.detail.error
) {
void hydrate().catch(() => {})
}
}
async function hydrate() {
const previous = downloads.value
const list = await listDownloads()
downloads.value = new Map(list.map((d) => [d.download_id, d]))
for (const download of list) {
const prior = previous.get(download.download_id)
if (
download.status === 'completed' &&
prior &&
prior.status !== 'completed'
) {
lastCompletedDownload.value = {
downloadId: download.download_id,
modelId: download.model_id,
directory: directoryOf(download.model_id),
timestamp: Date.now()
}
}
}
}
async function enqueue(request: EnqueueRequest): Promise<EnqueueResponse> {
const response = await enqueueDownload(request)
upsert(optimisticRow(response.download_id, request))
return response
}
function patchStatus(id: string, status: DownloadState) {
const existing = downloads.value.get(id)
if (existing) {
downloads.value.set(id, { ...existing, status })
}
}
async function pause(id: string) {
patchStatus(id, 'paused')
await pauseDownload(id)
}
async function resume(id: string) {
patchStatus(id, 'queued')
await resumeDownload(id)
}
async function cancel(id: string) {
await cancelDownload(id)
patchStatus(id, 'cancelled')
}
async function setPriority(id: string, priority: number) {
const existing = downloads.value.get(id)
if (existing) {
downloads.value.set(id, { ...existing, priority })
}
await setDownloadPriority(id, priority)
}
async function remove(id: string) {
await deleteDownload(id)
downloads.value.delete(id)
}
async function clearHistory() {
const clearedIds = historyDownloads.value.map((d) => d.download_id)
await clearDownloads()
for (const id of clearedIds) {
downloads.value.delete(id)
}
}
async function pollIfStale() {
if (!hasActiveDownloads.value) return
if (Date.now() - lastWsUpdate.value < STALE_THRESHOLD_MS) return
try {
await hydrate()
} catch {
// Server unreachable; retry on next interval
}
}
useIntervalFn(() => void pollIfStale(), POLL_INTERVAL_MS)
api.addEventListener('download_progress', handleProgress)
return {
downloadList,
activeDownloads,
historyDownloads,
hasActiveDownloads,
activeDownloadCount,
lastCompletedDownload,
upsert,
findByModelId,
hydrate,
enqueue,
pause,
resume,
cancel,
setPriority,
remove,
clearHistory
}
})

View File

@@ -1,130 +0,0 @@
import { z } from 'zod'
import type { DownloadState, DownloadStatus } from '@/schemas/apiSchema'
export type { DownloadState, DownloadStatus }
/**
* Known model file extensions accepted by the backend without
* `allow_any_extension`. Mirrors the server's extension allowlist and is
* used for instant client-side validation only — the server stays the source
* of truth.
*/
export const MODEL_EXTENSIONS = [
'.safetensors',
'.sft',
'.ckpt',
'.pth',
'.pt',
'.gguf',
'.bin'
] as const
/**
* Hosts the backend allows out of the box. Admins can extend this
* server-side, so this list is only for optimistic client-side hints; rely on
* the server's `URL_NOT_ALLOWED` / `url_allowed` as the source of truth.
*/
export const DEFAULT_ALLOWED_HOSTS = [
'huggingface.co',
'civitai.com',
'localhost',
'127.0.0.1'
] as const
export interface EnqueueRequest {
url: string
model_id: string
priority?: number
expected_sha256?: string | null
allow_any_extension?: boolean
credential_id?: string | null
}
export interface EnqueueResponse {
download_id: string
accepted: boolean
}
export const AUTH_SCHEMES = ['bearer', 'header', 'query'] as const
export type AuthScheme = (typeof AUTH_SCHEMES)[number]
const zHostCredentialView = z.object({
id: z.string(),
host: z.string(),
auth_scheme: z.enum(AUTH_SCHEMES),
header_name: z.string().nullable(),
query_param: z.string().nullable(),
label: z.string().nullable(),
match_subdomains: z.boolean(),
enabled: z.boolean(),
secret_last4: z.string().nullable(),
created_at: z.number(),
updated_at: z.number()
})
export type HostCredentialView = z.infer<typeof zHostCredentialView>
export interface HostCredentialUpsert {
host: string
secret: string
auth_scheme?: AuthScheme
header_name?: string | null
query_param?: string | null
label?: string | null
match_subdomains?: boolean
enabled?: boolean
}
interface AvailabilityBase {
url_allowed: boolean
}
export type AvailabilityEntry =
| (AvailabilityBase & { state: 'available' })
| (AvailabilityBase & { state: 'missing' })
| (AvailabilityBase & {
state: 'downloading'
download_id: string
progress: number | null
bytes_done: number
total_bytes: number | null
speed_bps: number | null
})
export interface AvailabilityResponse {
models: Record<string, AvailabilityEntry>
}
const DOWNLOAD_ERROR_CODES = [
'INVALID_JSON',
'INVALID_BODY',
'URL_NOT_ALLOWED',
'INVALID_MODEL_ID',
'INVALID_CREDENTIAL',
'ALREADY_AVAILABLE',
'ALREADY_DOWNLOADING',
'DOWNLOAD_ACTIVE',
'NOT_FOUND'
] as const
export type DownloadErrorCode = (typeof DOWNLOAD_ERROR_CODES)[number]
/**
* Error envelope returned by every download-manager endpoint on failure.
* `code` is the stable machine-readable discriminator; `message` is
* user-facing; `details` is an open object (do not assume a shape).
*/
export class DownloadApiError extends Error {
constructor(
message: string,
public readonly code: string,
public readonly status: number,
public readonly details?: Record<string, unknown>
) {
super(message)
this.name = 'DownloadApiError'
}
is(code: DownloadErrorCode): boolean {
return this.code === code
}
}

View File

@@ -1,99 +0,0 @@
import { describe, expect, it } from 'vitest'
import {
buildModelId,
filenameFromUrl,
hasModelExtension,
hostFromUrl,
isLikelyAllowedHost,
isValidPathSegment
} from './modelId'
describe('modelId utils', () => {
describe('isValidPathSegment', () => {
it('accepts valid filenames', () => {
expect(isValidPathSegment('my_lora.safetensors')).toBe(true)
expect(isValidPathSegment('sdxl-base.ckpt')).toBe(true)
})
it('rejects traversal, leading dots, and slashes', () => {
expect(isValidPathSegment('..')).toBe(false)
expect(isValidPathSegment('.hidden')).toBe(false)
expect(isValidPathSegment('nested/path')).toBe(false)
expect(isValidPathSegment('')).toBe(false)
})
})
describe('hasModelExtension', () => {
it('matches known extensions case-insensitively', () => {
expect(hasModelExtension('model.SAFETENSORS')).toBe(true)
expect(hasModelExtension('weights.gguf')).toBe(true)
})
it('rejects unknown extensions', () => {
expect(hasModelExtension('readme.txt')).toBe(false)
expect(hasModelExtension('archive.zip')).toBe(false)
})
})
describe('buildModelId', () => {
it('joins directory and filename with a single slash', () => {
expect(buildModelId('loras', 'x.safetensors')).toBe('loras/x.safetensors')
})
it('collapses redundant slashes at the join boundary', () => {
expect(buildModelId('loras/', 'x.safetensors')).toBe(
'loras/x.safetensors'
)
expect(buildModelId('loras', '/x.safetensors')).toBe(
'loras/x.safetensors'
)
expect(buildModelId('loras//', '//x.safetensors')).toBe(
'loras/x.safetensors'
)
})
})
describe('hostFromUrl', () => {
it('extracts the lowercased host', () => {
expect(hostFromUrl('https://HuggingFace.co/a/b')).toBe('huggingface.co')
})
it('returns null for invalid URLs', () => {
expect(hostFromUrl('not a url')).toBeNull()
})
})
describe('isLikelyAllowedHost', () => {
it('allows built-in hosts and their subdomains', () => {
expect(isLikelyAllowedHost('https://huggingface.co/x')).toBe(true)
expect(isLikelyAllowedHost('https://cdn.huggingface.co/x')).toBe(true)
expect(isLikelyAllowedHost('https://civitai.com/api/download/1')).toBe(
true
)
})
it('flags unknown hosts', () => {
expect(isLikelyAllowedHost('https://example.com/x')).toBe(false)
expect(isLikelyAllowedHost('garbage')).toBe(false)
})
it('does not treat a fake subdomain of an allowed IP literal as allowed', () => {
expect(isLikelyAllowedHost('https://127.0.0.1/x')).toBe(true)
expect(isLikelyAllowedHost('https://evil.127.0.0.1/x')).toBe(false)
})
})
describe('filenameFromUrl', () => {
it('returns the decoded trailing path segment', () => {
expect(filenameFromUrl('https://h.co/a/my%20model.safetensors')).toBe(
'my model.safetensors'
)
})
it('returns empty string when no segment is present', () => {
expect(filenameFromUrl('https://h.co/')).toBe('')
expect(filenameFromUrl('bad')).toBe('')
})
})
})

View File

@@ -1,79 +0,0 @@
import { DEFAULT_ALLOWED_HOSTS, MODEL_EXTENSIONS } from '../types'
const SEGMENT_PATTERN = /^[A-Za-z0-9][A-Za-z0-9._-]*$/
export function isValidPathSegment(segment: string): boolean {
return SEGMENT_PATTERN.test(segment)
}
export function hasModelExtension(filename: string): boolean {
const lower = filename.toLowerCase()
return MODEL_EXTENSIONS.some((ext) => lower.endsWith(ext))
}
export function buildModelId(directory: string, filename: string): string {
return `${directory.replace(/\/+$/, '')}/${filename.replace(/^\/+/, '')}`
}
/**
* Directory portion of a `directory/filename` model id, or an empty string when
* the id has no directory prefix.
*/
export function directoryOf(modelId: string): string {
const slash = modelId.indexOf('/')
return slash === -1 ? '' : modelId.slice(0, slash)
}
/**
* Filename portion of a `directory/filename` model id, falling back to the full
* id when there is no directory prefix.
*/
export function filenameOf(modelId: string): string {
const slash = modelId.indexOf('/')
return slash === -1 ? modelId : modelId.slice(slash + 1)
}
/**
* Lowercased host of a URL, or `null` when the URL is unparseable.
*/
export function hostFromUrl(url: string): string | null {
try {
return new URL(url).hostname.toLowerCase()
} catch {
return null
}
}
const IPV4_PATTERN = /^\d{1,3}(?:\.\d{1,3}){3}$/
function isIpLiteral(host: string): boolean {
return IPV4_PATTERN.test(host) || host.includes(':')
}
/**
* Optimistic client-side allowlist hint. The server can extend the
* allowlist, so a `false` here is advisory — defer to `URL_NOT_ALLOWED`.
*/
export function isLikelyAllowedHost(url: string): boolean {
const host = hostFromUrl(url)
if (!host) return false
return DEFAULT_ALLOWED_HOSTS.some(
(allowed) =>
host === allowed ||
(!isIpLiteral(allowed) && host.endsWith(`.${allowed}`))
)
}
/**
* Best-effort filename guess from a URL path, for prefilling the model_id
* filename field. May be empty when the URL has no usable trailing segment.
*/
export function filenameFromUrl(url: string): string {
try {
const { pathname } = new URL(url)
const last = pathname.split('/').filter(Boolean).pop() ?? ''
return decodeURIComponent(last)
} catch {
return ''
}
}

View File

@@ -30,6 +30,39 @@ describe('TelemetryRegistry', () => {
expect(b.trackSearchQuery).toHaveBeenCalledExactlyOnceWith(payload)
})
it('dispatches trackBeginCheckout with intent metadata to every provider', () => {
const a: TelemetryProvider = { trackBeginCheckout: vi.fn() }
const b: TelemetryProvider = {}
const registry = new TelemetryRegistry()
registry.registerProvider(a)
registry.registerProvider(b)
const metadata = {
user_id: 'user-1',
tier: 'pro' as const,
cycle: 'monthly' as const,
checkout_type: 'new' as const,
payment_intent_source: 'subscribe_to_run' as const
}
registry.trackBeginCheckout(metadata)
expect(a.trackBeginCheckout).toHaveBeenCalledExactlyOnceWith(metadata)
})
it('dispatches trackAddApiCreditButtonClicked with its source', () => {
const provider: TelemetryProvider = {
trackAddApiCreditButtonClicked: vi.fn()
}
const registry = new TelemetryRegistry()
registry.registerProvider(provider)
registry.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
expect(
provider.trackAddApiCreditButtonClicked
).toHaveBeenCalledExactlyOnceWith({ source: 'credits_panel' })
})
it('skips providers that do not implement trackSearchQuery', () => {
const empty: TelemetryProvider = {}
const registry = new TelemetryRegistry()

View File

@@ -1,6 +1,7 @@
import type { AuditLog } from '@/services/customerEventsService'
import type {
AddCreditsClickMetadata,
AuthMetadata,
BeginCheckoutMetadata,
DefaultViewSetMetadata,
@@ -99,8 +100,10 @@ export class TelemetryRegistry implements TelemetryDispatcher {
this.dispatch((provider) => provider.trackMonthlySubscriptionCancelled?.())
}
trackAddApiCreditButtonClicked(): void {
this.dispatch((provider) => provider.trackAddApiCreditButtonClicked?.())
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
this.dispatch((provider) =>
provider.trackAddApiCreditButtonClicked?.(metadata)
)
}
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {

View File

@@ -313,6 +313,42 @@ describe('PostHogTelemetryProvider', () => {
)
})
it('captures begin_checkout with intent metadata', async () => {
const provider = createProvider()
await vi.dynamicImportSettled()
provider.trackBeginCheckout({
user_id: 'user-1',
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
})
expect(hoisted.mockCapture).toHaveBeenCalledWith(
TelemetryEvents.BEGIN_CHECKOUT,
{
user_id: 'user-1',
tier: 'pro',
cycle: 'monthly',
checkout_type: 'new',
payment_intent_source: 'subscribe_to_run'
}
)
})
it('captures add-credit clicks with their source', async () => {
const provider = createProvider()
await vi.dynamicImportSettled()
provider.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
expect(hoisted.mockCapture).toHaveBeenCalledWith(
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
{ source: 'credits_panel' }
)
})
it('captures share attribution events', async () => {
const provider = createProvider()
await vi.dynamicImportSettled()

View File

@@ -10,7 +10,9 @@ import { remoteConfig } from '@/platform/remoteConfig/remoteConfig'
import type { RemoteConfig } from '@/platform/remoteConfig/types'
import type {
AddCreditsClickMetadata,
AuthMetadata,
BeginCheckoutMetadata,
DefaultViewSetMetadata,
EnterLinearMetadata,
ShareFlowMetadata,
@@ -350,8 +352,12 @@ export class PostHogTelemetryProvider implements TelemetryProvider {
this.trackEvent(eventName, metadata)
}
trackAddApiCreditButtonClicked(): void {
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
}
trackBeginCheckout(metadata: BeginCheckoutMetadata): void {
this.trackEvent(TelemetryEvents.BEGIN_CHECKOUT, metadata)
}
trackMonthlySubscriptionSucceeded(

View File

@@ -115,6 +115,17 @@ describe('HostTelemetrySink', () => {
)
})
it('forwards add-credit clicks with their source', () => {
new HostTelemetrySink().trackAddApiCreditButtonClicked({
source: 'avatar_menu'
})
expect(state.capture).toHaveBeenCalledExactlyOnceWith(
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
{ source: 'avatar_menu' }
)
})
it('does nothing when the host bridge is absent', () => {
delete window.__comfyDesktop2

View File

@@ -10,6 +10,7 @@ import {
import type { AuditLog } from '@/services/customerEventsService'
import type {
AddCreditsClickMetadata,
AuthMetadata,
BeginCheckoutMetadata,
DefaultViewSetMetadata,
@@ -126,8 +127,8 @@ export class HostTelemetrySink implements TelemetryProvider {
this.capture(TelemetryEvents.MONTHLY_SUBSCRIPTION_CANCELLED)
}
trackAddApiCreditButtonClicked(): void {
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
}
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {

View File

@@ -12,12 +12,29 @@
* 3. Check dist/assets/*.js files contain no tracking code
*/
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import type { AuditLog } from '@/services/customerEventsService'
import type { AppMode } from '@/utils/appMode'
export type PaymentIntentSource =
| 'subscription_required'
| 'out_of_credits'
| 'top_up_blocked'
| 'deep_link'
| 'subscribe_to_run'
| 'subscribe_now_button'
| 'upgrade_to_add_credits'
| 'settings_billing_panel'
| 'avatar_menu_plans'
| 'team_members_panel'
| 'invite_member_upsell'
| 'upload_model_upgrade'
| 'team_upgrade_resume'
export type SubscriptionCheckoutType = 'new' | 'change'
export type SubscriptionCheckoutTier = TierKey | 'team'
/**
* Authentication metadata for sign-up tracking
*/
@@ -426,16 +443,23 @@ export interface CheckoutAttributionMetadata {
export interface SubscriptionMetadata {
current_tier?: string
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
}
export interface AddCreditsClickMetadata {
source: 'credits_panel' | 'avatar_menu' | 'settings_billing_panel'
}
export interface BeginCheckoutMetadata
extends Record<string, unknown>, CheckoutAttributionMetadata {
user_id: string
tier: TierKey
tier: SubscriptionCheckoutTier
cycle: BillingCycle
checkout_type: 'new' | 'change'
checkout_type: SubscriptionCheckoutType
checkout_attempt_id?: string
billing_op_id?: string
previous_tier?: TierKey
payment_intent_source?: PaymentIntentSource
}
interface EcommerceItemMetadata {
@@ -457,8 +481,9 @@ export interface SubscriptionSuccessMetadata extends Record<string, unknown> {
checkout_attempt_id: string
tier: TierKey
cycle: BillingCycle
checkout_type: 'new' | 'change'
checkout_type: SubscriptionCheckoutType
previous_tier?: TierKey
payment_intent_source?: PaymentIntentSource
value: number
currency: string
ecommerce: EcommerceMetadata
@@ -489,7 +514,7 @@ export interface TelemetryProvider {
metadata?: SubscriptionSuccessMetadata
): void
trackMonthlySubscriptionCancelled?(): void
trackAddApiCreditButtonClicked?(): void
trackAddApiCreditButtonClicked?(metadata?: AddCreditsClickMetadata): void
trackApiCreditTopupButtonPurchaseClicked?(amount: number): void
trackApiCreditTopupSucceeded?(): void
trackWorkspaceInviteSent?(metadata: WorkspaceInviteMetadata): void

View File

@@ -321,7 +321,7 @@ const handleOpenWorkspaceSettings = () => {
}
const handleOpenPlansAndPricing = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
emit('close')
}
@@ -336,13 +336,12 @@ const handleOpenPlanAndCreditsSettings = () => {
}
const handleUpgradeToAddCredits = () => {
subscriptionDialog.showPricingTable()
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
emit('close')
}
const handleTopUp = () => {
// Track purchase credits entry from avatar popover
useTelemetry()?.trackAddApiCreditButtonClicked()
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
dialogService.showTopUpCreditsDialog()
emit('close')
}

View File

@@ -391,12 +391,13 @@ const showZeroState = computed(
)
function handleSubscribeWorkspace() {
showSubscriptionDialog()
showSubscriptionDialog({ reason: 'settings_billing_panel' })
}
function handleUpgrade() {
if (isFreeTierPlan.value) showPricingTable()
else showSubscriptionDialog()
if (isFreeTierPlan.value)
showPricingTable({ reason: 'settings_billing_panel' })
else showSubscriptionDialog({ reason: 'settings_billing_panel' })
}
function handleViewMoreDetails() {

View File

@@ -113,7 +113,7 @@ import { cn } from '@comfyorg/tailwind-utils'
import { useEventListener } from '@vueuse/core'
import Button from '@/components/ui/button/Button.vue'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
import SubscriptionAddPaymentPreviewWorkspace from './SubscriptionAddPaymentPreviewWorkspace.vue'
@@ -123,7 +123,7 @@ import UnifiedPricingTable from './UnifiedPricingTable.vue'
const { onClose, reason, initialPlanMode } = defineProps<{
onClose: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
initialPlanMode?: 'personal' | 'team'
}>()
@@ -152,7 +152,7 @@ const {
handleConfirmTransition,
handleTeamSubscribe,
handleResubscribe
} = useSubscriptionCheckout(emit)
} = useSubscriptionCheckout(emit, reason)
// Backspace mirrors the back arrow on the confirm step, but never while an
// editable element is focused (let it delete text there).

View File

@@ -5,7 +5,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
import { ref } from 'vue'
import { createI18n } from 'vue-i18n'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import SubscriptionRequiredDialogContentWorkspace from './SubscriptionRequiredDialogContentWorkspace.vue'
@@ -17,25 +17,10 @@ const mockHandleResubscribe = vi.fn()
const mockHandleSuccessClose = vi.fn()
const mockCheckoutStep = ref<'pricing' | 'preview' | 'success'>('pricing')
const mockPreviewData = ref<{ transition_type: string } | null>(null)
const mockUseSubscriptionCheckout = vi.hoisted(() => vi.fn())
vi.mock('@/platform/workspace/composables/useSubscriptionCheckout', () => ({
useSubscriptionCheckout: () => ({
checkoutStep: mockCheckoutStep,
isLoadingPreview: ref(false),
loadingTier: ref(null),
isSubscribing: ref(false),
isResubscribing: ref(false),
previewData: mockPreviewData,
selectedTierKey: ref('standard'),
selectedBillingCycle: ref('yearly'),
isPolling: ref(false),
handleSubscribeClick: mockHandleSubscribeClick,
handleBackToPricing: mockHandleBackToPricing,
handleAddCreditCard: mockHandleAddCreditCard,
handleConfirmTransition: mockHandleConfirmTransition,
handleResubscribe: mockHandleResubscribe,
handleSuccessClose: mockHandleSuccessClose
})
useSubscriptionCheckout: mockUseSubscriptionCheckout
}))
const i18n = createI18n({
@@ -91,7 +76,7 @@ const SuccessStub = {
function renderComponent(
props: {
onClose?: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
isPersonal?: boolean
} = {}
) {
@@ -121,6 +106,23 @@ function renderComponent(
describe('SubscriptionRequiredDialogContentWorkspace', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseSubscriptionCheckout.mockReturnValue({
checkoutStep: mockCheckoutStep,
isLoadingPreview: ref(false),
loadingTier: ref(null),
isSubscribing: ref(false),
isResubscribing: ref(false),
previewData: mockPreviewData,
selectedTierKey: ref('standard'),
selectedBillingCycle: ref('yearly'),
isPolling: ref(false),
handleSubscribeClick: mockHandleSubscribeClick,
handleBackToPricing: mockHandleBackToPricing,
handleAddCreditCard: mockHandleAddCreditCard,
handleConfirmTransition: mockHandleConfirmTransition,
handleResubscribe: mockHandleResubscribe,
handleSuccessClose: mockHandleSuccessClose
})
mockCheckoutStep.value = 'pricing'
mockPreviewData.value = null
})
@@ -132,6 +134,15 @@ describe('SubscriptionRequiredDialogContentWorkspace', () => {
expect(screen.queryByTestId('transition-preview')).not.toBeInTheDocument()
})
it('passes the reason into subscription checkout', () => {
renderComponent({ reason: 'out_of_credits' })
expect(mockUseSubscriptionCheckout).toHaveBeenCalledWith(
expect.any(Function),
'out_of_credits'
)
})
it('shows the team workspace header by default', () => {
renderComponent()
expect(screen.getByText('Team Workspace')).toBeInTheDocument()

View File

@@ -116,7 +116,7 @@
import { cn } from '@comfyorg/tailwind-utils'
import Button from '@/components/ui/button/Button.vue'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
import PricingTableWorkspace from './PricingTableWorkspace.vue'
@@ -130,7 +130,7 @@ const {
isPersonal = false
} = defineProps<{
onClose: () => void
reason?: SubscriptionDialogReason
reason?: PaymentIntentSource
isPersonal?: boolean
}>()
@@ -154,7 +154,7 @@ const {
handleConfirmTransition,
handleResubscribe,
handleSuccessClose
} = useSubscriptionCheckout(emit)
} = useSubscriptionCheckout(emit, reason)
</script>
<style scoped>

View File

@@ -61,6 +61,9 @@ function onDismiss() {
function onUpgrade() {
dialogStore.closeDialog({ key: 'invite-member-upsell' })
subscriptionDialog.show({ planMode: 'team' })
subscriptionDialog.show({
planMode: 'team',
reason: 'invite_member_upsell'
})
}
</script>

View File

@@ -277,7 +277,7 @@ export function useMembersPanel() {
}
function showTeamPlans() {
subscriptionDialog.show({ planMode: 'team' })
subscriptionDialog.show({ planMode: 'team', reason: 'team_members_panel' })
}
return {

View File

@@ -1,8 +1,9 @@
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { computed } from 'vue'
import { computed, reactive } from 'vue'
import type { PaymentIntentSource } from '@/platform/telemetry/types'
import type { Plan } from '@/platform/workspace/api/workspaceApi'
import { findPlanSlug } from './useSubscriptionCheckout'
@@ -75,7 +76,9 @@ const {
mockPlans,
mockResubscribe,
mockToastAdd,
mockStartOperation
mockStartOperation,
mockTrackBeginCheckout,
mockUserId
} = vi.hoisted(() => ({
mockSubscribe: vi.fn(),
mockPreviewSubscribe: vi.fn(),
@@ -84,7 +87,9 @@ const {
mockPlans: { value: [] as Plan[] },
mockResubscribe: vi.fn(),
mockToastAdd: vi.fn(),
mockStartOperation: vi.fn()
mockStartOperation: vi.fn(),
mockTrackBeginCheckout: vi.fn(),
mockUserId: { value: 'user-1' as string | null }
}))
vi.mock('@/composables/billing/useBillingContext', () => ({
@@ -119,7 +124,14 @@ vi.mock('primevue/usetoast', () => ({
}))
vi.mock('@/platform/telemetry', () => ({
useTelemetry: () => ({ trackMonthlySubscriptionSucceeded: vi.fn() })
useTelemetry: () => ({
trackMonthlySubscriptionSucceeded: vi.fn(),
trackBeginCheckout: mockTrackBeginCheckout
})
}))
vi.mock('@/stores/authStore', () => ({
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
}))
vi.mock('vue-i18n', async (importOriginal) => {
@@ -135,10 +147,10 @@ vi.mock('vue-i18n', async (importOriginal) => {
describe('useSubscriptionCheckout', () => {
let emit: ReturnType<typeof vi.fn>
async function setup() {
async function setup(paymentIntentSource?: PaymentIntentSource) {
const { useSubscriptionCheckout } =
await import('./useSubscriptionCheckout')
return useSubscriptionCheckout(emit as never)
return useSubscriptionCheckout(emit as never, paymentIntentSource)
}
beforeEach(() => {
@@ -146,6 +158,7 @@ describe('useSubscriptionCheckout', () => {
vi.clearAllMocks()
mockPlans.value = allPlans()
mockStartOperation.mockResolvedValue({ status: 'succeeded' })
mockUserId.value = 'user-1'
emit = vi.fn()
})
@@ -459,6 +472,13 @@ describe('useSubscriptionCheckout', () => {
cancelUrl: 'https://platform.comfy.org/payment/failed'
})
expect(checkout.checkoutStep.value).toBe('success')
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({
tier: 'team',
checkout_type: 'new',
billing_op_id: 'op-team-1'
})
)
})
it('uses the annual plan slug for the yearly cycle', async () => {
@@ -553,6 +573,39 @@ describe('useSubscriptionCheckout', () => {
detail: 'Team payment failed'
})
)
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
it('keeps team checkout_type as change when the preview request fails', async () => {
const checkout = await setup()
mockPreviewSubscribe.mockRejectedValueOnce(new Error('not supported'))
await checkout.handleSubscribeTeamClick({
stop: {
id: 'team_1400',
usd: 1400,
credits: 295_400,
discountedUsd: 1295
},
billingCycle: 'monthly',
isChange: true
})
mockSubscribe.mockResolvedValueOnce({
status: 'subscribed',
billing_op_id: 'op-team-change'
})
mockFetchStatus.mockResolvedValueOnce(undefined)
mockFetchBalance.mockResolvedValueOnce(undefined)
await checkout.handleTeamSubscribe()
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
expect.objectContaining({
tier: 'team',
cycle: 'monthly',
checkout_type: 'change',
billing_op_id: 'op-team-change'
})
)
})
})
@@ -603,6 +656,47 @@ describe('useSubscriptionCheckout', () => {
expect(checkout.checkoutStep.value).toBe('success')
})
it('skips begin_checkout when no user id is available', async () => {
mockUserId.value = null
const checkout = await setup('subscribe_to_run')
checkout.selectedTierKey.value = 'standard'
checkout.selectedBillingCycle.value = 'yearly'
mockSubscribe.mockResolvedValueOnce({
status: 'subscribed',
billing_op_id: 'op-1'
})
mockFetchStatus.mockResolvedValueOnce(undefined)
mockFetchBalance.mockResolvedValueOnce(undefined)
await checkout.handleAddCreditCard()
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
mockUserId.value = 'user-1'
})
it('fires begin_checkout carrying the payment intent source', async () => {
const checkout = await setup('subscribe_to_run')
checkout.selectedTierKey.value = 'standard'
checkout.selectedBillingCycle.value = 'yearly'
mockSubscribe.mockResolvedValueOnce({
status: 'subscribed',
billing_op_id: 'op-1'
})
mockFetchStatus.mockResolvedValueOnce(undefined)
mockFetchBalance.mockResolvedValueOnce(undefined)
await checkout.handleAddCreditCard()
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
user_id: 'user-1',
tier: 'standard',
cycle: 'yearly',
checkout_type: 'new',
billing_op_id: 'op-1',
payment_intent_source: 'subscribe_to_run'
})
})
it('opens payment URL when needs_payment_method', async () => {
const checkout = await setup()
checkout.selectedTierKey.value = 'standard'
@@ -720,6 +814,7 @@ describe('useSubscriptionCheckout', () => {
detail: 'Payment failed'
})
)
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
})
})

View File

@@ -9,16 +9,26 @@ import type { TeamPlanSelection } from '@/platform/cloud/subscription/constants/
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import { useTelemetry } from '@/platform/telemetry'
import type {
PaymentIntentSource,
SubscriptionCheckoutType
} from '@/platform/telemetry/types'
import type {
Plan,
PreviewSubscribeResponse,
SubscribeResponse
} from '@/platform/workspace/api/workspaceApi'
import { useBillingOperationStore } from '@/platform/workspace/stores/billingOperationStore'
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
type CheckoutStep = 'pricing' | 'preview' | 'success'
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
interface SelectedTeamCheckout {
stop: TeamPlanSelection
checkoutType: SubscriptionCheckoutType
}
/**
* Which screen the `preview` step shows. Only a change prorates: a team change
* carries `previewData` (handleSubscribeTeamClick sets it solely for an immediate
@@ -45,9 +55,12 @@ export function findPlanSlug(
return plan?.slug ?? null
}
export function useSubscriptionCheckout(emit: {
(e: 'close', subscribed: boolean): void
}) {
export function useSubscriptionCheckout(
emit: {
(e: 'close', subscribed: boolean): void
},
paymentIntentSource?: PaymentIntentSource
) {
const { t } = useI18n()
const toast = useToast()
const {
@@ -68,13 +81,16 @@ export function useSubscriptionCheckout(emit: {
const isResubscribing = ref(false)
const previewData = ref<PreviewSubscribeResponse | null>(null)
const selectedTierKey = ref<CheckoutTierKey | null>(null)
const selectedTeamStop = ref<TeamPlanSelection | null>(null)
const selectedTeamCheckout = ref<SelectedTeamCheckout | null>(null)
const selectedBillingCycle = ref<BillingCycle>('yearly')
const isPolling = computed(() => billingOperationStore.hasPendingOperations)
const isTeamCheckout = computed(() => selectedTeamStop.value !== null)
const selectedTeamStop = computed(
() => selectedTeamCheckout.value?.stop ?? null
)
const isTeamCheckout = computed(() => selectedTeamCheckout.value !== null)
const previewVariant = computed<PreviewVariant>(() => {
if (selectedTeamStop.value) {
if (selectedTeamCheckout.value) {
return previewData.value ? 'team-change' : 'team-new'
}
if (previewData.value) {
@@ -154,7 +170,10 @@ export function useSubscriptionCheckout(emit: {
billingCycle: BillingCycle
isChange?: boolean
}) {
selectedTeamStop.value = payload.stop
selectedTeamCheckout.value = {
stop: payload.stop,
checkoutType: payload.isChange ? 'change' : 'new'
}
selectedBillingCycle.value = payload.billingCycle
selectedTierKey.value = null
previewData.value = null
@@ -182,7 +201,7 @@ export function useSubscriptionCheckout(emit: {
function handleBackToPricing() {
checkoutStep.value = 'pricing'
previewData.value = null
selectedTeamStop.value = null
selectedTeamCheckout.value = null
}
function handleSuccessClose() {
@@ -190,20 +209,34 @@ export function useSubscriptionCheckout(emit: {
}
async function handleSubscription() {
if (!selectedTierKey.value) return
const tierKey = selectedTierKey.value
if (!tierKey) return
const billingCycle = selectedBillingCycle.value
const checkoutType =
previewData.value &&
previewData.value.transition_type !== 'new_subscription'
? 'change'
: 'new'
isSubscribing.value = true
try {
const planSlug = getApiPlanSlug(
selectedTierKey.value,
selectedBillingCycle.value
)
const planSlug = getApiPlanSlug(tierKey, billingCycle)
if (!planSlug) return
const response = await subscribe(planSlug, {
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
})
if (response) {
trackWorkspaceCheckoutStarted({
tier: tierKey,
cycle: billingCycle,
checkoutType,
billingOpId: response.billing_op_id,
paymentIntentSource
})
}
await handleSubscribeResponse(response)
} catch (error) {
showSubscribeError(error)
@@ -269,8 +302,8 @@ export function useSubscriptionCheckout(emit: {
}
async function handleTeamSubscription() {
const stop = selectedTeamStop.value
if (!stop?.id) {
const teamCheckout = selectedTeamCheckout.value
if (!teamCheckout?.stop.id) {
toast.add({
severity: 'error',
summary: t('subscription.teamPlan.name'),
@@ -279,16 +312,28 @@ export function useSubscriptionCheckout(emit: {
return
}
const { stop, checkoutType } = teamCheckout
const billingCycle = selectedBillingCycle.value
isSubscribing.value = true
try {
const planSlug = getTeamPlanSlug(selectedBillingCycle.value)
const planSlug = getTeamPlanSlug(billingCycle)
const response = await subscribe(planSlug, {
teamCreditStopId: stop.id,
billingCycle: selectedBillingCycle.value,
billingCycle,
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
})
if (response) {
trackWorkspaceCheckoutStarted({
tier: 'team',
cycle: billingCycle,
checkoutType,
billingOpId: response.billing_op_id,
paymentIntentSource
})
}
await handleSubscribeResponse(response)
} catch (error) {
showSubscribeError(error)

View File

@@ -2,6 +2,7 @@ import { computed, ref, shallowRef } from 'vue'
import { useBillingPlans } from '@/platform/cloud/subscription/composables/useBillingPlans'
import { useSubscriptionDialog } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type {
BillingBalanceResponse,
BillingStatusResponse,
@@ -275,12 +276,12 @@ export function useWorkspaceBilling(): BillingState & BillingActions {
async function requireActiveSubscription(): Promise<void> {
await fetchStatus()
if (!isActiveSubscription.value) {
subscriptionDialog.show()
subscriptionDialog.show({ reason: 'subscription_required' })
}
}
function showSubscriptionDialog(): void {
subscriptionDialog.show()
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
subscriptionDialog.show(options)
}
return {

View File

@@ -0,0 +1,38 @@
import { useTelemetry } from '@/platform/telemetry'
import type {
PaymentIntentSource,
SubscriptionCheckoutTier,
SubscriptionCheckoutType
} from '@/platform/telemetry/types'
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
import { useAuthStore } from '@/stores/authStore'
interface TrackWorkspaceCheckoutStartedOptions {
tier: SubscriptionCheckoutTier
cycle: BillingCycle
checkoutType: SubscriptionCheckoutType
billingOpId: string
paymentIntentSource?: PaymentIntentSource
}
export function trackWorkspaceCheckoutStarted({
tier,
cycle,
checkoutType,
billingOpId,
paymentIntentSource
}: TrackWorkspaceCheckoutStartedOptions) {
const { userId } = useAuthStore()
if (!userId) return
useTelemetry()?.trackBeginCheckout({
user_id: userId,
tier,
cycle,
checkout_type: checkoutType,
billing_op_id: billingOpId,
...(paymentIntentSource
? { payment_intent_source: paymentIntentSource }
: {})
})
}

View File

@@ -154,39 +154,6 @@ const zAssetDownloadWsMessage = z.object({
error: z.string().optional()
})
const DOWNLOAD_STATES = [
'queued',
'active',
'paused',
'verifying',
'completed',
'failed',
'cancelled'
] as const
const zDownloadSegment = z.object({
idx: z.number().int().nonnegative(),
bytes_done: z.number().int().nonnegative(),
length: z.number().int().nonnegative()
})
const zDownloadStatus = z.object({
download_id: z.string(),
model_id: z.string(),
url: z.string(),
status: z.enum(DOWNLOAD_STATES),
priority: z.number().int(),
total_bytes: z.number().int().nonnegative().nullable(),
bytes_done: z.number().int().nonnegative(),
progress: z.number().nullable(),
speed_bps: z.number().nonnegative().nullable(),
eta_seconds: z.number().nonnegative().nullable(),
segments: z.array(zDownloadSegment).nullable(),
error: z.string().nullable(),
created_at: z.number(),
updated_at: z.number()
})
const zAssetExportWsMessage = z.object({
task_id: z.string(),
export_name: z.string().optional(),
@@ -221,8 +188,6 @@ export type ProgressStateWsMessage = z.infer<typeof zProgressStateWsMessage>
export type FeatureFlagsWsMessage = z.infer<typeof zFeatureFlagsWsMessage>
export type AssetDownloadWsMessage = z.infer<typeof zAssetDownloadWsMessage>
export type AssetExportWsMessage = z.infer<typeof zAssetExportWsMessage>
export type DownloadState = (typeof DOWNLOAD_STATES)[number]
export type DownloadStatus = z.infer<typeof zDownloadStatus>
// End of ws messages
export type NotificationWsMessage = z.infer<typeof zNotificationWsMessage>

View File

@@ -34,7 +34,6 @@ import type {
AssetDownloadWsMessage,
AssetExportWsMessage,
CustomNodesI18n,
DownloadStatus,
EmbeddingsResponse,
ExecutedWsMessage,
ExecutingWsMessage,
@@ -187,7 +186,6 @@ interface BackendApiCalls {
feature_flags: FeatureFlagsWsMessage
asset_download: AssetDownloadWsMessage
asset_export: AssetExportWsMessage
download_progress: DownloadStatus
}
/** Dictionary of all api calls */

View File

@@ -0,0 +1,286 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
interface AxiosLikeError extends Error {
isAxiosError: true
response?: {
status: number
data?: {
message?: string
}
}
}
const mockClient = vi.hoisted(() => ({
get: vi.fn(),
post: vi.fn()
}))
const mockAxios = vi.hoisted(() => ({
create: vi.fn(() => mockClient),
isAxiosError: vi.fn(
(error: unknown): error is AxiosLikeError =>
typeof error === 'object' &&
error !== null &&
'isAxiosError' in error &&
error.isAxiosError === true
)
}))
vi.mock('axios', () => ({
default: mockAxios
}))
import { useComfyRegistryService } from './comfyRegistryService'
function response<T>(data: T) {
return { data }
}
function axiosError(
message: string,
responseData?: AxiosLikeError['response']
): AxiosLikeError {
const error = new Error(message) as AxiosLikeError
error.isAxiosError = true
if (responseData) error.response = responseData
return error
}
describe('useComfyRegistryService', () => {
beforeEach(() => {
mockClient.get.mockReset()
mockClient.post.mockReset()
mockAxios.isAxiosError.mockClear()
})
it('configures the registry axios client with repeated query params', () => {
expect(mockAxios.create).toHaveBeenCalledWith({
baseURL: 'https://api.comfy.org',
headers: {
'Content-Type': 'application/json'
},
paramsSerializer: {
indexes: null
}
})
})
it('returns response data and clears loading state for successful requests', async () => {
mockClient.get.mockResolvedValueOnce(response({ nodes: [] }))
const service = useComfyRegistryService()
const result = await service.search({ search: 'manager' })
expect(result).toEqual({ nodes: [] })
expect(mockClient.get).toHaveBeenCalledWith('/nodes/search', {
params: { search: 'manager' },
signal: undefined
})
expect(service.error.value).toBeNull()
expect(service.isLoading.value).toBe(false)
})
it('skips node definition requests when pack id or version is missing', async () => {
const service = useComfyRegistryService()
await expect(
service.getNodeDefs({ packId: '', version: '1.0.0' })
).resolves.toBeNull()
await expect(
service.getNodeDefs({ packId: 'pack', version: '' })
).resolves.toBeNull()
expect(mockClient.get).not.toHaveBeenCalled()
})
it('passes query params and abort signals through node definition requests', async () => {
const signal = new AbortController().signal
mockClient.get.mockResolvedValueOnce(response([{ name: 'KSampler' }]))
const service = useComfyRegistryService()
const result = await service.getNodeDefs(
{ packId: 'pack', version: '1.0.0', page: 2 },
signal
)
expect(result).toEqual([{ name: 'KSampler' }])
expect(mockClient.get).toHaveBeenCalledWith(
'/nodes/pack/versions/1.0.0/comfy-nodes',
{
params: { page: 2 },
signal
}
)
})
it('routes publisher, pack, and review methods to their registry endpoints', async () => {
mockClient.get
.mockResolvedValueOnce(response({ id: 'publisher' }))
.mockResolvedValueOnce(response([{ id: 'pack' }]))
.mockResolvedValueOnce(response([{ version: '1.0.0' }]))
.mockResolvedValueOnce(response({ id: 'version' }))
.mockResolvedValueOnce(response({ id: 'pack' }))
.mockResolvedValueOnce(response({ id: 'pack' }))
.mockResolvedValueOnce(response({ id: 'pack' }))
mockClient.post
.mockResolvedValueOnce(response({ id: 'reviewed' }))
.mockResolvedValueOnce(response({ node_versions: [] }))
const service = useComfyRegistryService()
const signal = new AbortController().signal
await expect(
service.getPublisherById('publisher', signal)
).resolves.toEqual({ id: 'publisher' })
await expect(
service.listPacksForPublisher('publisher', true, signal)
).resolves.toEqual([{ id: 'pack' }])
await expect(
service.getPackVersions(
'pack',
{ statuses: ['NodeVersionStatusActive'] },
signal
)
).resolves.toEqual([{ version: '1.0.0' }])
await expect(
service.getPackByVersion('pack', 'version', signal)
).resolves.toEqual({ id: 'version' })
await expect(service.getPackById('pack', signal)).resolves.toEqual({
id: 'pack'
})
await expect(
service.inferPackFromNodeName('KSampler', signal)
).resolves.toEqual({ id: 'pack' })
await expect(service.listAllPacks({ page: 1 }, signal)).resolves.toEqual({
id: 'pack'
})
await expect(service.postPackReview('pack', 5, signal)).resolves.toEqual({
id: 'reviewed'
})
await expect(
service.getBulkNodeVersions(
[{ node_id: 'pack', version: '1.0.0' }],
signal
)
).resolves.toEqual({ node_versions: [] })
expect(mockClient.get).toHaveBeenNthCalledWith(1, '/publishers/publisher', {
signal
})
expect(mockClient.get).toHaveBeenNthCalledWith(
2,
'/publishers/publisher/nodes',
{
params: { include_banned: true },
signal
}
)
expect(mockClient.get).toHaveBeenNthCalledWith(3, '/nodes/pack/versions', {
params: { statuses: ['NodeVersionStatusActive'] },
signal
})
expect(mockClient.get).toHaveBeenNthCalledWith(
4,
'/nodes/pack/versions/version',
{ signal }
)
expect(mockClient.get).toHaveBeenNthCalledWith(5, '/nodes/pack', {
signal
})
expect(mockClient.get).toHaveBeenNthCalledWith(
6,
'/comfy-nodes/KSampler/node',
{ signal }
)
expect(mockClient.get).toHaveBeenNthCalledWith(7, '/nodes', {
params: { page: 1 },
signal
})
expect(mockClient.post).toHaveBeenNthCalledWith(
1,
'/nodes/pack/reviews',
null,
{ params: { star: 5 }, signal }
)
expect(mockClient.post).toHaveBeenNthCalledWith(
2,
'/bulk/nodes/versions',
{ node_versions: [{ node_id: 'pack', version: '1.0.0' }] },
{ signal }
)
})
it('omits include_banned when listing publisher packs without banned packs', async () => {
mockClient.get.mockResolvedValueOnce(response([]))
const service = useComfyRegistryService()
await service.listPacksForPublisher('publisher', false)
expect(mockClient.get).toHaveBeenCalledWith('/publishers/publisher/nodes', {
params: undefined,
signal: undefined
})
})
it.for([
{ status: 400, expected: 'Bad request: Invalid input' },
{ status: 401, expected: 'Unauthorized: Authentication required' },
{ status: 403, expected: 'Forbidden: Access denied' },
{ status: 404, expected: 'Not found: Resource not found' },
{ status: 409, expected: 'Conflict: Resource conflict' },
{ status: 500, expected: 'Server error: Internal server error' },
{ status: 418, expected: 'Failed to perform search: teapot' }
])(
'normalizes axios response status $status',
async ({ status, expected }) => {
mockClient.get.mockRejectedValueOnce(
axiosError('Request failed', {
status,
data: status === 418 ? { message: 'teapot' } : {}
})
)
const service = useComfyRegistryService()
await expect(service.search()).resolves.toBeNull()
expect(service.error.value).toBe(expected)
expect(service.isLoading.value).toBe(false)
}
)
it('uses route-specific errors before generic status messages', async () => {
mockClient.get.mockRejectedValueOnce(
axiosError('Request failed', {
status: 404,
data: { message: 'ignored' }
})
)
const service = useComfyRegistryService()
await expect(service.getPackById('missing')).resolves.toBeNull()
expect(service.error.value).toBe(
'Pack not found: The pack with ID missing does not exist'
)
})
it('normalizes network, thrown Error, unknown, and abort failures', async () => {
const service = useComfyRegistryService()
mockClient.get.mockRejectedValueOnce(axiosError('Network down'))
await expect(service.search()).resolves.toBeNull()
expect(service.error.value).toBe('Failed to perform search: Network down')
mockClient.get.mockRejectedValueOnce(new Error('boom'))
await expect(service.search()).resolves.toBeNull()
expect(service.error.value).toBe('Failed to perform search: boom')
mockClient.get.mockRejectedValueOnce('bad')
await expect(service.search()).resolves.toBeNull()
expect(service.error.value).toBe(
'Failed to perform search: Unknown error occurred'
)
mockClient.get.mockRejectedValueOnce(new DOMException('', 'AbortError'))
await expect(service.search()).resolves.toBeNull()
expect(service.error.value).toBeNull()
})
})

View File

@@ -18,7 +18,7 @@ import type {
} from '@/stores/dialogStore'
import type { ComponentAttrs } from 'vue-component-type-helpers'
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
import type { WorkspaceRole } from '@/platform/workspace/api/workspaceApi'
// Lazy loaders for dialogs - components are loaded on first use
@@ -442,9 +442,9 @@ export const useDialogService = () => {
})
}
async function showSubscriptionRequiredDialog(options?: {
reason?: SubscriptionDialogReason
}) {
async function showSubscriptionRequiredDialog(
options?: SubscriptionDialogOptions
) {
if (!isCloud || !window.__CONFIG__?.subscription_required) {
return
}

View File

@@ -234,6 +234,54 @@ describe('useRegistrySearchGateway', () => {
const gateway = useRegistrySearchGateway()
expect(gateway).toBeDefined()
})
it('waits for the circuit breaker timeout before retrying a failed provider', async () => {
vi.useFakeTimers()
vi.setSystemTime(new Date('2024-01-01T00:00:00Z'))
vi.mocked(useAlgoliaSearchProvider).mockImplementation(() => {
throw new Error('Algolia init failed')
})
const registryResult = {
nodePacks: [{ id: 'registry-1', name: 'Registry Pack' }],
querySuggestions: []
}
const mockRegistryProvider = {
searchPacks: vi.fn().mockRejectedValue(new Error('Registry failed')),
clearSearchCache: vi.fn(),
getSortValue: vi.fn(),
getSortableFields: vi.fn().mockReturnValue([])
}
vi.mocked(useComfyRegistrySearchProvider).mockReturnValue(
mockRegistryProvider
)
const gateway = useRegistrySearchGateway()
for (let attempt = 0; attempt < 3; attempt++) {
await expect(
gateway.searchPacks('test', { pageSize: 10, pageNumber: 0 })
).rejects.toThrow('All search providers failed')
}
expect(mockRegistryProvider.searchPacks).toHaveBeenCalledTimes(3)
await expect(
gateway.searchPacks('test', { pageSize: 10, pageNumber: 0 })
).rejects.toThrow('All search providers failed')
expect(mockRegistryProvider.searchPacks).toHaveBeenCalledTimes(3)
vi.setSystemTime(new Date('2024-01-01T00:01:01Z'))
mockRegistryProvider.searchPacks.mockResolvedValueOnce(registryResult)
await expect(
gateway.searchPacks('test', { pageSize: 10, pageNumber: 0 })
).resolves.toBe(registryResult)
expect(mockRegistryProvider.searchPacks).toHaveBeenCalledTimes(4)
})
})
describe('Cache management', () => {

View File

@@ -126,6 +126,19 @@ describe('useAssetDownloadStore', () => {
})
})
it('keeps the first placeholder when the same task is tracked twice', () => {
const store = useAssetDownloadStore()
store.trackDownload('task-123', 'checkpoints', 'first.safetensors')
store.trackDownload('task-123', 'loras', 'second.safetensors')
expect(store.downloadList).toHaveLength(1)
expect(store.downloadList[0]).toMatchObject({
modelType: 'checkpoints',
assetName: 'first.safetensors'
})
})
it('handles out-of-order messages where completed arrives before progress', () => {
const store = useAssetDownloadStore()
@@ -179,6 +192,19 @@ describe('useAssetDownloadStore', () => {
expect(store.finishedDownloads[0].status).toBe('completed')
})
it('skips polling when active downloads have fresh progress', async () => {
const store = useAssetDownloadStore()
dispatch(createDownloadMessage({ status: 'running' }))
await vi.advanceTimersByTimeAsync(9_999)
dispatch(createDownloadMessage({ status: 'running', progress: 75 }))
await vi.advanceTimersByTimeAsync(1)
expect(taskService.getTask).not.toHaveBeenCalled()
expect(store.activeDownloads).toHaveLength(1)
expect(store.activeDownloads[0].progress).toBe(75)
})
it('polls and marks failed downloads', async () => {
const store = useAssetDownloadStore()
@@ -311,5 +337,22 @@ describe('useAssetDownloadStore', () => {
expect(store.sessionDownloadCount).toBe(0)
expect(store.isDownloadedThisSession('asset-456')).toBe(false)
})
it('does not acknowledge unrelated completed downloads', () => {
const store = useAssetDownloadStore()
dispatch(
createDownloadMessage({
status: 'completed',
progress: 100,
asset_id: 'asset-456'
})
)
store.acknowledgeAsset('other-asset')
expect(store.sessionDownloadCount).toBe(1)
expect(store.isDownloadedThisSession('asset-456')).toBe(true)
})
})
})

View File

@@ -1,4 +1,5 @@
import { createTestingPinia } from '@pinia/testing'
import { fromAny } from '@total-typescript/shoehorn'
import { setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { nextTick, watch } from 'vue'
@@ -11,6 +12,7 @@ import type {
} from '@/platform/assets/schemas/assetSchema'
import type { JobListItem } from '@/platform/remote/comfyui/jobs/jobTypes'
import { assetService } from '@/platform/assets/services/assetService'
import { useAssetDownloadStore } from '@/stores/assetDownloadStore'
// Mock the api module
vi.mock('@/scripts/api', () => ({
@@ -96,6 +98,10 @@ const mockOutputOverrides = vi.hoisted(() => ({
value: null as MockOutput[] | null
}))
const mockAssetMapperOptions = vi.hoisted(() => ({
omitCreatedAtForIds: new Set<string>()
}))
// Mock TaskItemImpl
const PREVIEWABLE_MEDIA_TYPES = new Set(['images', 'video', 'audio'])
@@ -169,11 +175,14 @@ vi.mock('@/platform/assets/composables/media/assetMappers', () => ({
})),
mapTaskOutputToAssetItem: vi.fn((task, output) => {
const index = parseInt(task.jobId.split('_')[1]) || 0
const createdAt = new Date(Date.now() - index * 1000).toISOString()
return {
id: task.jobId,
name: output.filename,
size: 0,
created_at: new Date(Date.now() - index * 1000).toISOString(),
...(!mockAssetMapperOptions.omitCreatedAtForIds.has(task.jobId) && {
created_at: createdAt
}),
tags: ['output'],
preview_url: output.url,
user_metadata: {}
@@ -205,6 +214,7 @@ describe('assetsStore - Refactored (Option A)', () => {
setActivePinia(createTestingPinia({ stubActions: false }))
store = useAssetsStore()
vi.clearAllMocks()
mockAssetMapperOptions.omitCreatedAtForIds.clear()
})
describe('Initial Load', () => {
@@ -272,6 +282,17 @@ describe('assetsStore - Refactored (Option A)', () => {
'prompt_2'
])
})
it('should skip unfinished jobs and completed jobs without previews', async () => {
vi.mocked(api.getHistory).mockResolvedValue([
{ ...createMockJobItem(0), status: 'in_progress' },
{ ...createMockJobItem(1), preview_output: undefined }
])
await store.updateHistory()
expect(store.historyAssets).toEqual([])
})
})
describe('Pagination', () => {
@@ -328,6 +349,46 @@ describe('assetsStore - Refactored (Option A)', () => {
expect(uniqueAssetIds.size).toBe(store.historyAssets.length)
})
it('should insert newer paginated items in sorted order', async () => {
vi.mocked(api.getHistory).mockResolvedValueOnce(
Array.from({ length: 200 }, (_, i) => createMockJobItem(i))
)
await store.updateHistory()
vi.mocked(api.getHistory).mockResolvedValueOnce([createMockJobItem(-1)])
await store.loadMoreHistory()
expect(store.historyAssets[0].id).toBe('prompt_-1')
})
it('sorts paginated items when the incoming asset has no timestamp', async () => {
vi.mocked(api.getHistory).mockResolvedValueOnce(
Array.from({ length: 200 }, (_, i) => createMockJobItem(i))
)
await store.updateHistory()
mockAssetMapperOptions.omitCreatedAtForIds.add('prompt_200')
vi.mocked(api.getHistory).mockResolvedValueOnce([createMockJobItem(200)])
await store.loadMoreHistory()
expect(store.historyAssets.at(-1)?.id).toBe('prompt_200')
})
it('sorts paginated items when an existing asset has no timestamp', async () => {
for (let i = 0; i < 200; i++) {
mockAssetMapperOptions.omitCreatedAtForIds.add(`prompt_${i}`)
}
vi.mocked(api.getHistory).mockResolvedValueOnce(
Array.from({ length: 200 }, (_, i) => createMockJobItem(i))
)
await store.updateHistory()
vi.mocked(api.getHistory).mockResolvedValueOnce([createMockJobItem(-1)])
await store.loadMoreHistory()
expect(store.historyAssets[0].id).toBe('prompt_-1')
})
it('should stop loading when no more items', async () => {
// First batch - less than BATCH_SIZE
const firstBatch = Array.from({ length: 50 }, (_, i) =>
@@ -494,6 +555,29 @@ describe('assetsStore - Refactored (Option A)', () => {
expect(store.historyLoading).toBe(false)
expect(store.historyError).toBe(error)
})
it('should preserve existing history when refresh fails', async () => {
vi.mocked(api.getHistory).mockResolvedValueOnce([createMockJobItem(0)])
await store.updateHistory()
const error = new Error('API error')
vi.mocked(api.getHistory).mockRejectedValueOnce(error)
await store.updateHistory()
expect(store.historyAssets).toHaveLength(1)
expect(store.historyError).toBe(error)
})
it('should keep empty history when loadMore fails before any load', async () => {
const error = new Error('API error')
vi.mocked(api.getHistory).mockRejectedValueOnce(error)
await store.loadMoreHistory()
expect(store.historyAssets).toEqual([])
expect(store.historyError).toBe(error)
})
})
describe('Memory Management', () => {
@@ -924,6 +1008,43 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
vi.mocked(assetService.getAssetsForNodeType)
).toHaveBeenCalledTimes(2)
})
it('ignores a model response after the category is invalidated', async () => {
const store = useAssetsStore()
let resolveFetch!: (assets: AssetItem[]) => void
vi.mocked(assetService.getAssetsForNodeType).mockReturnValueOnce(
new Promise((resolve) => {
resolveFetch = resolve
})
)
const request = store.updateModelsForNodeType('CheckpointLoaderSimple')
store.invalidateCategory('checkpoints')
resolveFetch([createMockAsset('stale-response')])
await request
expect(store.getAssets('CheckpointLoaderSimple')).toEqual([])
})
it('ignores a model rejection after the category is invalidated', async () => {
const store = useAssetsStore()
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
let rejectFetch!: (error: Error) => void
vi.mocked(assetService.getAssetsForNodeType).mockReturnValueOnce(
new Promise((_resolve, reject) => {
rejectFetch = reject
})
)
const request = store.updateModelsForNodeType('CheckpointLoaderSimple')
store.invalidateCategory('checkpoints')
rejectFetch(new Error('stale rejection'))
await request
expect(store.getError('CheckpointLoaderSimple')).toBeUndefined()
expect(consoleSpy).not.toHaveBeenCalled()
consoleSpy.mockRestore()
})
})
describe('shallowReactive state reactivity', () => {
@@ -966,6 +1087,10 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
it('should return empty array for unknown node types', () => {
const store = useAssetsStore()
expect(store.getAssets('UnknownNodeType')).toEqual([])
expect(store.isModelLoading('UnknownNodeType')).toBe(false)
expect(store.getError('UnknownNodeType')).toBeUndefined()
expect(store.hasMore('UnknownNodeType')).toBe(false)
expect(store.hasAssetKey('UnknownNodeType')).toBe(false)
})
it('should not fetch for unknown node types', async () => {
@@ -975,6 +1100,63 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
vi.mocked(assetService.getAssetsForNodeType)
).not.toHaveBeenCalled()
})
it('should refresh an already loaded category', async () => {
const store = useAssetsStore()
const nodeType = 'CheckpointLoaderSimple'
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
createMockAsset('first')
])
await store.updateModelsForNodeType(nodeType)
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
createMockAsset('second')
])
await store.updateModelsForNodeType(nodeType)
expect(store.getAssets(nodeType).map((asset) => asset.id)).toEqual([
'second'
])
})
it('reports hasMore for a loaded category', async () => {
const store = useAssetsStore()
const nodeType = 'CheckpointLoaderSimple'
expect(store.hasMore(nodeType)).toBe(false)
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
createMockAsset('only-page')
])
await store.updateModelsForNodeType(nodeType)
expect(store.hasMore(nodeType)).toBe(false)
})
it('should record model loading errors', async () => {
const store = useAssetsStore()
const error = new Error('model fetch failed')
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
vi.mocked(assetService.getAssetsForNodeType).mockRejectedValueOnce(error)
await store.updateModelsForNodeType('CheckpointLoaderSimple')
expect(store.getError('CheckpointLoaderSimple')).toBe(error)
expect(store.isModelLoading('CheckpointLoaderSimple')).toBe(false)
consoleSpy.mockRestore()
})
it('should wrap non-error model loading failures', async () => {
const store = useAssetsStore()
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {})
vi.mocked(assetService.getAssetsForNodeType).mockRejectedValueOnce('boom')
await store.updateModelsForNodeType('CheckpointLoaderSimple')
expect(store.getError('CheckpointLoaderSimple')?.message).toBe('boom')
consoleSpy.mockRestore()
})
})
describe('invalidateCategory', () => {
@@ -1129,7 +1311,140 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
})
})
describe('completed download refresh', () => {
it('refreshes provider and tag caches for the completed model type', async () => {
const store = useAssetsStore()
const downloadStore = useAssetDownloadStore()
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue([])
vi.mocked(assetService.getAssetsByTag).mockResolvedValue([])
downloadStore.lastCompletedDownload = {
taskId: 'task-1',
modelType: 'checkpoints',
timestamp: 1
}
await vi.waitFor(() =>
expect(assetService.getAssetsByTag).toHaveBeenCalledWith(
'models',
true,
expect.objectContaining({ limit: 500, offset: 0 })
)
)
expect(assetService.getAssetsForNodeType).toHaveBeenCalledWith(
'CheckpointLoaderSimple',
expect.objectContaining({ limit: 500, offset: 0 })
)
expect(assetService.getAssetsForNodeType).toHaveBeenCalledTimes(1)
expect(assetService.getAssetsByTag).toHaveBeenCalledWith(
'checkpoints',
true,
expect.objectContaining({ limit: 500, offset: 0 })
)
expect(store.hasCategory('tag:models')).toBe(true)
})
})
describe('updateAssetMetadata optimistic cache', () => {
it('still writes metadata when a cache key is unresolved', async () => {
const store = useAssetsStore()
const original = {
...createMockAsset('opt-unknown'),
user_metadata: { note: 'before' } as Record<string, unknown>
}
vi.mocked(assetService.updateAsset).mockResolvedValueOnce({
...original,
user_metadata: { note: 'after' }
})
await store.updateAssetMetadata(
original,
{ note: 'after' },
'UnknownNodeType'
)
expect(vi.mocked(assetService.updateAsset)).toHaveBeenCalledWith(
'opt-unknown',
{ user_metadata: { note: 'after' } }
)
})
it('still updates the server when the asset is not cached', async () => {
const store = useAssetsStore()
const original = {
...createMockAsset('opt-missing'),
user_metadata: { note: 'before' } as Record<string, unknown>
}
vi.mocked(assetService.updateAsset).mockResolvedValueOnce({
...original,
user_metadata: { note: 'server' }
})
await store.updateAssetMetadata(original, { note: 'after' })
expect(vi.mocked(assetService.updateAsset)).toHaveBeenCalledWith(
'opt-missing',
{ user_metadata: { note: 'after' } }
)
})
it('still updates the server when a resolved cache key has not loaded yet', async () => {
const store = useAssetsStore()
const original = {
...createMockAsset('opt-unloaded'),
user_metadata: { note: 'before' } as Record<string, unknown>
}
vi.mocked(assetService.updateAsset).mockResolvedValueOnce({
...original,
user_metadata: { note: 'server' }
})
await store.updateAssetMetadata(
original,
{ note: 'after' },
'CheckpointLoaderSimple'
)
expect(vi.mocked(assetService.updateAsset)).toHaveBeenCalledWith(
'opt-unloaded',
{ user_metadata: { note: 'after' } }
)
})
it('leaves unrelated cached assets alone during optimistic metadata update', async () => {
const store = useAssetsStore()
const cached = {
...createMockAsset('opt-cached'),
user_metadata: { note: 'cached' } as Record<string, unknown>
}
const missing = {
...createMockAsset('opt-missing-from-cache'),
user_metadata: { note: 'before' } as Record<string, unknown>
}
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
cached
])
await store.updateModelsForNodeType('CheckpointLoaderSimple')
vi.mocked(assetService.updateAsset).mockResolvedValueOnce({
...missing,
user_metadata: { note: 'server' }
})
await store.updateAssetMetadata(
missing,
{ note: 'after' },
'CheckpointLoaderSimple'
)
expect(
store.getAssets('CheckpointLoaderSimple')[0].user_metadata
).toEqual({
note: 'cached'
})
})
it('reflects the server response in the cache after a successful update', async () => {
const store = useAssetsStore()
const original = {
@@ -1237,6 +1552,31 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
'featured'
])
})
it('calls only the remove endpoint when there are no tags to add', async () => {
const store = useAssetsStore()
const asset = createMockAsset('tags-remove-only', ['models', 'archived'])
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
asset
])
await store.updateModelsForNodeType('CheckpointLoaderSimple')
vi.mocked(assetService.removeAssetTags).mockResolvedValueOnce({
total_tags: ['models']
})
await store.updateAssetTags(asset, ['models'], 'CheckpointLoaderSimple')
expect(vi.mocked(assetService.removeAssetTags)).toHaveBeenCalledWith(
'tags-remove-only',
['archived']
)
expect(vi.mocked(assetService.addAssetTags)).not.toHaveBeenCalled()
expect(store.getAssets('CheckpointLoaderSimple')[0].tags).toEqual([
'models'
])
})
})
describe('updateAssetTags partial-failure compensation', () => {
@@ -1351,6 +1691,36 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
expect(store.hasCategory('tag:models')).toBe(false)
})
it('keeps unrelated tag caches when compensation fails with a cache key', async () => {
const store = useAssetsStore()
const asset = createMockAsset('tags-target-fail', ['models', 'loras'])
const otherAsset = createMockAsset('tags-other', ['models'])
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
asset
])
await store.updateModelsForNodeType('LoraLoader')
vi.mocked(assetService.getAssetsByTag).mockResolvedValueOnce([otherAsset])
await store.updateModelsForTag('models')
vi.mocked(assetService.removeAssetTags).mockResolvedValueOnce({
removed: ['loras'],
total_tags: ['models']
})
vi.mocked(assetService.addAssetTags)
.mockRejectedValueOnce(new Error('500 add failed'))
.mockRejectedValueOnce(new Error('503 compensation failed'))
await store.updateAssetTags(
asset,
['models', 'checkpoints'],
'LoraLoader'
)
expect(store.hasCategory('loras')).toBe(false)
expect(store.hasCategory('tag:models')).toBe(true)
})
it('does not attempt compensation when only the add was attempted', async () => {
const store = useAssetsStore()
const asset = createMockAsset('tags-add-only-fail', ['models'])
@@ -1483,9 +1853,78 @@ describe('assetsStore - Deletion State and Input Mapping', () => {
const store = useAssetsStore()
expect(store.getInputName('unknown.png')).toBe('unknown.png')
})
it('ignores input assets without hashes', async () => {
mockIsCloud.value = true
try {
setActivePinia(createTestingPinia({ stubActions: false }))
const store = useAssetsStore()
vi.mocked(assetService.getAssetsByTag).mockResolvedValueOnce([
{
id: 'input-1',
name: 'plain.png',
tags: ['input']
}
])
await store.updateInputs()
expect(store.getInputName('plain.png')).toBe('plain.png')
} finally {
mockIsCloud.value = false
}
})
})
describe('updateInputs cloud routing', () => {
it('reads input files from the internal API when isCloud is false', async () => {
const fetchMock = vi.fn().mockResolvedValue(
fromAny<Response, unknown>({
ok: true,
json: async () => ['input-a.png', 'input-b.png']
})
)
vi.stubGlobal('fetch', fetchMock)
try {
const store = useAssetsStore()
await store.updateInputs()
expect(fetchMock).toHaveBeenCalledWith(
'http://localhost:3000/files/input',
{ headers: { 'Comfy-User': 'test-user' } }
)
expect(store.inputAssets.map((asset) => asset.name)).toEqual([
'input-a.png',
'input-b.png'
])
} finally {
vi.unstubAllGlobals()
}
})
it('records internal input API failures', async () => {
const fetchMock = vi.fn().mockResolvedValue(
fromAny<Response, unknown>({
ok: false
})
)
vi.stubGlobal('fetch', fetchMock)
try {
const consoleSpy = vi
.spyOn(console, 'error')
.mockImplementation(() => {})
const store = useAssetsStore()
await store.updateInputs()
expect(store.inputError).toBeInstanceOf(Error)
consoleSpy.mockRestore()
} finally {
vi.unstubAllGlobals()
}
})
it('reads from assetService.getAssetsByTag with limit 100 when isCloud is true', async () => {
mockIsCloud.value = true
try {
@@ -1586,6 +2025,18 @@ describe('assetsStore - Flat Output Assets (cloud-only)', () => {
expect(store.flatOutputHasMore).toBe(false)
})
it('does not load more flat outputs when there are no more pages', async () => {
vi.mocked(assetService.getAssetsPageByTag).mockResolvedValueOnce(
makePage([makeAsset('a1', 'one.png')])
)
const store = useAssetsStore()
await store.updateFlatOutputs()
await store.loadMoreFlatOutputs()
expect(assetService.getAssetsPageByTag).toHaveBeenCalledTimes(1)
})
it('threads the minted cursor into after on loadMore and omits offset', async () => {
vi.mocked(assetService.getAssetsPageByTag)
.mockResolvedValueOnce(
@@ -1800,4 +2251,26 @@ describe('assetsStore - Flat Output Assets (cloud-only)', () => {
expect(store.flatOutputAssets.map((x) => x.id)).toEqual(['shared-1'])
})
it('ignores concurrent load more calls while one is active', async () => {
vi.mocked(assetService.getAssetsPageByTag).mockResolvedValueOnce(
makePage([makeAsset('a1', 'f1.png')], { hasMore: true })
)
const store = useAssetsStore()
await store.updateFlatOutputs()
let resolvePage!: (page: AssetResponse) => void
vi.mocked(assetService.getAssetsPageByTag).mockReturnValueOnce(
new Promise<AssetResponse>((resolve) => {
resolvePage = resolve
})
)
const first = store.loadMoreFlatOutputs()
const second = store.loadMoreFlatOutputs()
resolvePage(makePage([makeAsset('a2', 'f2.png')]))
await Promise.all([first, second])
expect(assetService.getAssetsPageByTag).toHaveBeenCalledTimes(2)
})
})

View File

@@ -1,4 +1,5 @@
import { createTestingPinia } from '@pinia/testing'
import { fromAny } from '@total-typescript/shoehorn'
import { setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { ref } from 'vue'
@@ -177,9 +178,10 @@ describe('useComfyRegistryStore', () => {
it('should return null when fetching a pack with null ID', async () => {
const store = useComfyRegistryStore()
vi.spyOn(store.getPackById, 'call').mockResolvedValueOnce(null)
const result = await store.getPackById.call(null!)
const result = await store.getPackById.call(
fromAny<Parameters<typeof store.getPackById.call>[0], unknown>(null)
)
expect(result).toBeNull()
expect(mockRegistryService.getPackById).not.toHaveBeenCalled()
@@ -206,6 +208,56 @@ describe('useComfyRegistryStore', () => {
)
})
it('should reuse cached packs by ID', async () => {
const store = useComfyRegistryStore()
await store.getPacksByIds.call(['test-pack-id'])
const result = await store.getPacksByIds.call(['test-pack-id'])
expect(result).toEqual([mockNodePack])
expect(mockRegistryService.listAllPacks).toHaveBeenCalledTimes(1)
})
it('should ignore missing packs by ID', async () => {
mockRegistryService.listAllPacks.mockResolvedValueOnce({
nodes: [fromAny<components['schemas']['Node'], unknown>({ name: 'bad' })],
total: 1,
page: 1,
limit: 10
})
const store = useComfyRegistryStore()
const result = await store.getPacksByIds.call(['unknown-pack-id'])
expect(result).toEqual([])
})
it('should handle empty pack lookup responses', async () => {
mockRegistryService.listAllPacks.mockResolvedValueOnce(null)
const store = useComfyRegistryStore()
const result = await store.getPacksByIds.call(['unknown-pack-id'])
expect(result).toEqual([])
})
it('should filter undefined pack IDs before lookup', async () => {
const store = useComfyRegistryStore()
const result = await store.getPacksByIds.call(
fromAny<components['schemas']['Node']['id'][], unknown>([
'test-pack-id',
undefined
])
)
expect(result).toEqual([mockNodePack])
expect(mockRegistryService.listAllPacks).toHaveBeenCalledWith(
{ node_id: ['test-pack-id'] },
expect.any(Object)
)
})
describe('inferPackFromNodeName', () => {
it('should fetch a pack by comfy node name', async () => {
const store = useComfyRegistryStore()

View File

@@ -113,27 +113,6 @@ describe('commandStore', () => {
})
})
describe('unregisterCommand', () => {
it('removes a registered command', () => {
const store = useCommandStore()
store.registerCommand({ id: 'to.remove', function: vi.fn() })
store.unregisterCommand('to.remove')
expect(store.isRegistered('to.remove')).toBe(false)
})
it('is a no-op for an unregistered id', () => {
const store = useCommandStore()
store.registerCommand({ id: 'keep.me', function: vi.fn() })
store.unregisterCommand('nonexistent')
expect(store.isRegistered('keep.me')).toBe(true)
expect(store.commands).toHaveLength(1)
})
})
describe('isRegistered', () => {
it('returns false for unregistered command', () => {
const store = useCommandStore()

Some files were not shown because too many files have changed in this diff Show More