mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-07-02 21:28:08 +00:00
Compare commits
5 Commits
model_down
...
alexisroll
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
90b2b3ba1d | ||
|
|
25c4250b2f | ||
|
|
a84ec49f75 | ||
|
|
4f1a9700c4 | ||
|
|
2ec2a0e091 |
38
.github/workflows/cla.yml
vendored
38
.github/workflows/cla.yml
vendored
@@ -5,7 +5,6 @@ on:
|
||||
types: [created]
|
||||
pull_request_target:
|
||||
types: [opened, synchronize, closed]
|
||||
merge_group:
|
||||
|
||||
permissions:
|
||||
actions: write
|
||||
@@ -17,13 +16,41 @@ jobs:
|
||||
cla-assistant:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# The CLA action normally requires every commit author in a PR to sign.
|
||||
# We only want the PR author to sign, so we allowlist all other committers
|
||||
# by computing them from the PR's commits and excluding the PR author.
|
||||
- name: Build author-only allowlist
|
||||
id: allowlist
|
||||
if: >
|
||||
github.event_name == 'pull_request_target' ||
|
||||
(github.event_name == 'issue_comment' && github.event.issue.pull_request && (
|
||||
github.event.comment.body == 'recheck' ||
|
||||
github.event.comment.body == 'I have read and agree to the Contributor License Agreement'
|
||||
))
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number || github.event.issue.number }}
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
BASE_ALLOWLIST: action@github.com,actions-user,ampagent,claude,comfy-pr-bot,GitHub Action,github-actions,github-actions[bot],Glary Bot,Glary-Bot,*[bot]
|
||||
run: |
|
||||
others=$(gh api "repos/${{ github.repository }}/pulls/${PR_NUMBER}/commits" --paginate \
|
||||
--jq '.[] | (.author.login // empty), (.committer.login // empty)' \
|
||||
| sort -u | grep -vix "${PR_AUTHOR}" | paste -sd, -)
|
||||
if [ -n "$others" ]; then
|
||||
echo "allowlist=${BASE_ALLOWLIST},${others}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
echo "allowlist=${BASE_ALLOWLIST}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: CLA Assistant
|
||||
# Run on PR events, on "recheck" comment, or when someone posts the exact signing phrase.
|
||||
# IMPORTANT: this phrase must match `custom-pr-sign-comment` below.
|
||||
if: >
|
||||
github.event_name == 'pull_request_target' ||
|
||||
github.event.comment.body == 'recheck' ||
|
||||
github.event.comment.body == 'I have read and agree to the Contributor License Agreement'
|
||||
(github.event_name == 'issue_comment' && github.event.issue.pull_request && (
|
||||
github.event.comment.body == 'recheck' ||
|
||||
github.event.comment.body == 'I have read and agree to the Contributor License Agreement'
|
||||
))
|
||||
uses: contributor-assistant/github-action@ca4a40a7d1004f18d9960b404b97e5f30a505a08 # v2.6.1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -39,9 +66,10 @@ jobs:
|
||||
path-to-signatures: signatures/cla.json
|
||||
branch: main
|
||||
|
||||
# Allowlist bots so they don't need to sign (optional, comma-separated).
|
||||
# Only the PR author must sign: bots plus every non-author committer
|
||||
# are allowlisted via the "Build author-only allowlist" step above.
|
||||
# *[bot] is a catch-all for any GitHub App bot account.
|
||||
allowlist: action@github.com,actions-user,ampagent,claude,comfy-pr-bot,GitHub Action,github-actions,Glary Bot,Glary-Bot,*[bot]
|
||||
allowlist: ${{ steps.allowlist.outputs.allowlist }}
|
||||
|
||||
# Custom PR comment messages
|
||||
custom-notsigned-prcomment: |
|
||||
|
||||
@@ -1,304 +0,0 @@
|
||||
import { expect, mergeTests } from '@playwright/test'
|
||||
import type { Page } from '@playwright/test'
|
||||
|
||||
import type { ComfyPage } from '@e2e/fixtures/ComfyPage'
|
||||
import { comfyPageFixture } from '@e2e/fixtures/ComfyPage'
|
||||
import { ConfirmDialog } from '@e2e/fixtures/components/ConfirmDialog'
|
||||
import { jsonRoute } from '@e2e/fixtures/utils/jsonRoute'
|
||||
import { webSocketFixture } from '@e2e/fixtures/ws'
|
||||
|
||||
import type {
|
||||
DownloadStatus,
|
||||
EnqueueResponse,
|
||||
HostCredentialUpsert,
|
||||
HostCredentialView
|
||||
} from '@/platform/modelManager/types'
|
||||
|
||||
const test = mergeTests(comfyPageFixture, webSocketFixture)
|
||||
|
||||
const DOWNLOADS_ROUTE = /\/api\/download$/
|
||||
const ENQUEUE_ROUTE = /\/api\/download\/enqueue$/
|
||||
const CREDENTIALS_ROUTE = /\/api\/download\/credentials$/
|
||||
const CREDENTIAL_ROUTE = /\/api\/download\/credentials\/([^/?]+)$/
|
||||
const CLEAR_ROUTE = /\/api\/download\/clear$/
|
||||
|
||||
const DOWNLOAD_ID = 'e2e-download-1'
|
||||
const MODEL_URL =
|
||||
'https://huggingface.co/e2e/test/resolve/main/model.safetensors'
|
||||
const MODEL_ID = 'checkpoints/e2e-test-model.safetensors'
|
||||
|
||||
function makeDownloadStatus(
|
||||
overrides: Partial<DownloadStatus> = {}
|
||||
): DownloadStatus {
|
||||
const now = Math.floor(Date.now() / 1000)
|
||||
return {
|
||||
download_id: DOWNLOAD_ID,
|
||||
model_id: MODEL_ID,
|
||||
url: MODEL_URL,
|
||||
status: 'queued',
|
||||
priority: 0,
|
||||
total_bytes: null,
|
||||
bytes_done: 0,
|
||||
progress: null,
|
||||
speed_bps: null,
|
||||
eta_seconds: null,
|
||||
segments: null,
|
||||
error: null,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
async function enableServerSideModelDownloads(comfyPage: ComfyPage) {
|
||||
await comfyPage.page.evaluate(() => {
|
||||
window.app!.api.serverFeatureFlags.value = {
|
||||
...window.app!.api.serverFeatureFlags.value,
|
||||
server_side_model_downloads: true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/** Keeps GET /download consistent so the 5s stale-poll can't wipe test state. */
|
||||
async function mockDownloadsList(
|
||||
page: Page,
|
||||
getDownloads: () => DownloadStatus[]
|
||||
) {
|
||||
await page.route(DOWNLOADS_ROUTE, async (route) => {
|
||||
if (route.request().method() !== 'GET') return route.fallback()
|
||||
await route.fulfill(jsonRoute({ downloads: getDownloads() }))
|
||||
})
|
||||
}
|
||||
|
||||
function modelDownloaderTabButton(comfyPage: ComfyPage) {
|
||||
return comfyPage.page.locator('.model-manager-tab-button')
|
||||
}
|
||||
|
||||
function modelDownloaderBadge(comfyPage: ComfyPage) {
|
||||
return modelDownloaderTabButton(comfyPage).locator('.sidebar-icon-badge')
|
||||
}
|
||||
|
||||
function modelDownloaderPanel(comfyPage: ComfyPage) {
|
||||
return comfyPage.page.locator('.sidebar-content-container')
|
||||
}
|
||||
|
||||
async function openModelDownloaderTab(comfyPage: ComfyPage) {
|
||||
await modelDownloaderTabButton(comfyPage).click()
|
||||
const panel = modelDownloaderPanel(comfyPage)
|
||||
await expect(panel.getByText('Downloads', { exact: true })).toBeVisible()
|
||||
// Toolbar buttons are only interactive while the tab panel is hovered
|
||||
// (`group-hover/sidebar-tab`), so keep the cursor over it for later clicks.
|
||||
await panel.hover()
|
||||
}
|
||||
|
||||
test.describe('Model Downloader sidebar', { tag: '@ui' }, () => {
|
||||
test.beforeEach(async ({ comfyPage }) => {
|
||||
await comfyPage.modelLibrary.mockFoldersWithFiles({ checkpoints: [] })
|
||||
// Isolate from any real downloads already tracked by the backend.
|
||||
await mockDownloadsList(comfyPage.page, () => [])
|
||||
await enableServerSideModelDownloads(comfyPage)
|
||||
})
|
||||
|
||||
test('adds a model by URL and reflects live progress over the websocket', async ({
|
||||
comfyPage,
|
||||
getWebSocket
|
||||
}) => {
|
||||
let downloads: DownloadStatus[] = []
|
||||
await mockDownloadsList(comfyPage.page, () => downloads)
|
||||
await comfyPage.page.route(ENQUEUE_ROUTE, async (route) => {
|
||||
downloads = [makeDownloadStatus()]
|
||||
const response: EnqueueResponse = {
|
||||
download_id: DOWNLOAD_ID,
|
||||
accepted: true
|
||||
}
|
||||
await route.fulfill({
|
||||
status: 202,
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify(response)
|
||||
})
|
||||
})
|
||||
|
||||
await openModelDownloaderTab(comfyPage)
|
||||
const panel = modelDownloaderPanel(comfyPage)
|
||||
await panel.getByTitle('Add model').click()
|
||||
|
||||
const addDialog = comfyPage.page.getByRole('dialog')
|
||||
await expect(addDialog.getByText('Add model')).toBeVisible()
|
||||
|
||||
await addDialog.getByLabel('URL').fill(MODEL_URL)
|
||||
await expect(addDialog.getByLabel('Filename')).toHaveValue(
|
||||
'model.safetensors'
|
||||
)
|
||||
await addDialog.getByLabel('Filename').fill('e2e-test-model.safetensors')
|
||||
await addDialog.getByRole('combobox', { name: 'Select a folder' }).click()
|
||||
await comfyPage.page
|
||||
.getByRole('option', { name: 'Checkpoints', exact: true })
|
||||
.click()
|
||||
|
||||
await addDialog
|
||||
.getByRole('button', { name: 'Download', exact: true })
|
||||
.click()
|
||||
await expect(addDialog).toBeHidden()
|
||||
|
||||
await expect(panel.getByText('e2e-test-model.safetensors')).toBeVisible()
|
||||
await expect(panel.getByText('Queued', { exact: true })).toBeVisible()
|
||||
await expect(modelDownloaderBadge(comfyPage)).toHaveText('1')
|
||||
|
||||
const ws = await getWebSocket()
|
||||
function sendProgress(overrides: Partial<DownloadStatus>) {
|
||||
const payload = makeDownloadStatus(overrides)
|
||||
downloads = [payload]
|
||||
ws.send(JSON.stringify({ type: 'download_progress', data: payload }))
|
||||
}
|
||||
|
||||
sendProgress({
|
||||
status: 'active',
|
||||
progress: 0.4,
|
||||
bytes_done: 400,
|
||||
total_bytes: 1000,
|
||||
speed_bps: 500_000
|
||||
})
|
||||
|
||||
await expect(panel.getByText('Downloading', { exact: true })).toBeVisible()
|
||||
await expect(panel.getByText(/40%/)).toBeVisible()
|
||||
await expect(modelDownloaderBadge(comfyPage)).toHaveText('1')
|
||||
|
||||
sendProgress({
|
||||
status: 'completed',
|
||||
progress: 1,
|
||||
bytes_done: 1000,
|
||||
total_bytes: 1000,
|
||||
speed_bps: null,
|
||||
eta_seconds: null
|
||||
})
|
||||
|
||||
await expect(panel.getByText('Completed', { exact: true })).toBeVisible()
|
||||
await expect(panel.getByText('History', { exact: true })).toBeVisible()
|
||||
await expect(modelDownloaderBadge(comfyPage)).toHaveCount(0)
|
||||
})
|
||||
|
||||
test('clears history and the cleared rows stay gone after reopening the tab', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
let downloads: DownloadStatus[] = [
|
||||
makeDownloadStatus({ status: 'completed', progress: 1, bytes_done: 1000 })
|
||||
]
|
||||
await mockDownloadsList(comfyPage.page, () => downloads)
|
||||
await comfyPage.page.route(CLEAR_ROUTE, async (route) => {
|
||||
const count = downloads.length
|
||||
downloads = []
|
||||
await route.fulfill(jsonRoute({ deleted: count }))
|
||||
})
|
||||
|
||||
await openModelDownloaderTab(comfyPage)
|
||||
const panel = modelDownloaderPanel(comfyPage)
|
||||
await expect(panel.getByText('History', { exact: true })).toBeVisible()
|
||||
await expect(panel.getByText('e2e-test-model.safetensors')).toBeVisible()
|
||||
|
||||
await panel.getByRole('button', { name: 'Clear history' }).click()
|
||||
await expect(panel.getByText('No downloads yet')).toBeVisible()
|
||||
|
||||
// Reopening the tab re-runs hydrate() -> GET /api/download. The bug was
|
||||
// that the cleared row reappeared here; the persisted delete must prevent
|
||||
// the backend list from returning it again.
|
||||
await modelDownloaderTabButton(comfyPage).click()
|
||||
await openModelDownloaderTab(comfyPage)
|
||||
await expect(
|
||||
modelDownloaderPanel(comfyPage).getByText('No downloads yet')
|
||||
).toBeVisible()
|
||||
await expect(
|
||||
modelDownloaderPanel(comfyPage).getByText('e2e-test-model.safetensors')
|
||||
).toBeHidden()
|
||||
})
|
||||
|
||||
test('manages download credentials: add, edit, and delete', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
let credentials: HostCredentialView[] = []
|
||||
let nextId = 1
|
||||
|
||||
await comfyPage.page.route(CREDENTIALS_ROUTE, async (route) => {
|
||||
const method = route.request().method()
|
||||
if (method === 'GET') {
|
||||
await route.fulfill(jsonRoute({ credentials }))
|
||||
return
|
||||
}
|
||||
if (method === 'POST') {
|
||||
const body = route.request().postDataJSON() as HostCredentialUpsert
|
||||
const existing = credentials.find((c) => c.host === body.host)
|
||||
const now = Math.floor(Date.now() / 1000)
|
||||
const view: HostCredentialView = {
|
||||
id: existing?.id ?? `cred-${nextId++}`,
|
||||
host: body.host,
|
||||
auth_scheme: body.auth_scheme ?? 'bearer',
|
||||
header_name: body.header_name ?? null,
|
||||
query_param: body.query_param ?? null,
|
||||
label: body.label ?? null,
|
||||
match_subdomains: body.match_subdomains ?? false,
|
||||
enabled: body.enabled ?? true,
|
||||
secret_last4: body.secret.slice(-4),
|
||||
created_at: existing?.created_at ?? now,
|
||||
updated_at: now
|
||||
}
|
||||
credentials = existing
|
||||
? credentials.map((c) => (c.id === view.id ? view : c))
|
||||
: [...credentials, view]
|
||||
await route.fulfill(jsonRoute(view))
|
||||
return
|
||||
}
|
||||
await route.fallback()
|
||||
})
|
||||
await comfyPage.page.route(CREDENTIAL_ROUTE, async (route) => {
|
||||
if (route.request().method() !== 'DELETE') return route.fallback()
|
||||
const id = route.request().url().match(CREDENTIAL_ROUTE)?.[1]
|
||||
credentials = credentials.filter((c) => c.id !== id)
|
||||
await route.fulfill(jsonRoute({ deleted: true }))
|
||||
})
|
||||
|
||||
const dialog = comfyPage.page
|
||||
.getByRole('dialog')
|
||||
.filter({ hasText: 'Credentials Manager' })
|
||||
|
||||
// The success toast shown after saving can momentarily steal focus and
|
||||
// close the dialog, so re-opening is retried until it sticks.
|
||||
async function ensureCredentialsDialogOpen() {
|
||||
await expect(async () => {
|
||||
const panel = modelDownloaderPanel(comfyPage)
|
||||
await panel.hover()
|
||||
await panel.getByTitle('Credentials Manager').click()
|
||||
await expect(dialog).toBeVisible({ timeout: 1000 })
|
||||
}).toPass({ timeout: 10_000 })
|
||||
}
|
||||
|
||||
await openModelDownloaderTab(comfyPage)
|
||||
await ensureCredentialsDialogOpen()
|
||||
|
||||
await dialog.getByLabel('Host').fill('huggingface.co')
|
||||
await dialog.getByLabel('API key').fill('secret-key-1234')
|
||||
await dialog.getByRole('button', { name: 'Save', exact: true }).click()
|
||||
|
||||
await ensureCredentialsDialogOpen()
|
||||
await expect(
|
||||
dialog.getByText('huggingface.co · Bearer token · ••••1234')
|
||||
).toBeVisible()
|
||||
|
||||
await dialog.getByTitle('Edit').click()
|
||||
await expect(dialog.getByText('Update credential')).toBeVisible()
|
||||
await expect(dialog.getByLabel('API key')).toHaveValue('')
|
||||
await dialog.getByLabel('Label').fill('My HF Key')
|
||||
await dialog.getByLabel('API key').fill('updated-secret-5678')
|
||||
await dialog.getByRole('button', { name: 'Save', exact: true }).click()
|
||||
|
||||
await ensureCredentialsDialogOpen()
|
||||
await expect(dialog.getByText('My HF Key')).toBeVisible()
|
||||
await expect(
|
||||
dialog.getByText('huggingface.co · Bearer token · ••••5678')
|
||||
).toBeVisible()
|
||||
|
||||
await dialog.getByTitle('Delete').click()
|
||||
const confirm = new ConfirmDialog(comfyPage.page)
|
||||
await confirm.click('delete')
|
||||
|
||||
await expect(dialog.getByText('My HF Key')).toBeHidden()
|
||||
})
|
||||
})
|
||||
@@ -243,7 +243,7 @@ onMounted(() => {
|
||||
--sidebar-padding: 4px;
|
||||
--sidebar-icon-size: 1rem;
|
||||
|
||||
--sidebar-default-floating-width: 50px;
|
||||
--sidebar-default-floating-width: 48px;
|
||||
--sidebar-default-connected-width: calc(
|
||||
var(--sidebar-default-floating-width) + var(--sidebar-padding) * 2
|
||||
);
|
||||
|
||||
@@ -224,7 +224,7 @@ const handleOpenUserSettings = () => {
|
||||
}
|
||||
|
||||
const handleOpenPlansAndPricing = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
@@ -239,8 +239,7 @@ const handleOpenPlanAndCreditsSettings = () => {
|
||||
}
|
||||
|
||||
const handleTopUp = () => {
|
||||
// Track purchase credits entry from avatar popover
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked()
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
|
||||
dialogService.showTopUpCreditsDialog()
|
||||
emit('close')
|
||||
}
|
||||
@@ -254,7 +253,7 @@ const handleOpenPartnerNodesInfo = () => {
|
||||
}
|
||||
|
||||
const handleUpgradeToAddCredits = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
|
||||
@@ -21,6 +21,6 @@ const { isFreeTier } = useBillingContext()
|
||||
const subscriptionDialog = useSubscriptionDialog()
|
||||
|
||||
function handleClick() {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'subscribe_now_button' })
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -1,17 +1,12 @@
|
||||
<script setup lang="ts">
|
||||
import type { DialogContentEmits, DialogContentProps } from 'reka-ui'
|
||||
import {
|
||||
DialogContent,
|
||||
injectDialogRootContext,
|
||||
useForwardPropsEmits
|
||||
} from 'reka-ui'
|
||||
import { DialogContent, useForwardPropsEmits } from 'reka-ui'
|
||||
import type { HTMLAttributes } from 'vue'
|
||||
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
|
||||
import type { DialogContentSize } from './dialog.variants'
|
||||
import { dialogContentVariants } from './dialog.variants'
|
||||
import { useModalPointerLock } from './useModalPointerLock'
|
||||
|
||||
const {
|
||||
size,
|
||||
@@ -28,11 +23,6 @@ const {
|
||||
|
||||
const emits = defineEmits<DialogContentEmits>()
|
||||
const forwarded = useForwardPropsEmits(restProps, emits)
|
||||
|
||||
const dialogRootContext = injectDialogRootContext(null)
|
||||
if (dialogRootContext?.modal.value) {
|
||||
useModalPointerLock(() => dialogRootContext.open.value)
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
import { render } from '@testing-library/vue'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import {
|
||||
SelectContent,
|
||||
SelectPortal,
|
||||
SelectRoot,
|
||||
SelectTrigger,
|
||||
SelectViewport
|
||||
} from 'reka-ui'
|
||||
import { defineComponent, h, nextTick, ref } from 'vue'
|
||||
|
||||
import Dialog from './Dialog.vue'
|
||||
import DialogContent from './DialogContent.vue'
|
||||
import DialogPortal from './DialogPortal.vue'
|
||||
|
||||
async function flush() {
|
||||
await nextTick()
|
||||
await nextTick()
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
await nextTick()
|
||||
}
|
||||
|
||||
function mountDialogWithSelect(modal: boolean) {
|
||||
const dialogOpen = ref(true)
|
||||
const selectOpen = ref(false)
|
||||
|
||||
const Parent = defineComponent({
|
||||
setup() {
|
||||
return () =>
|
||||
h(
|
||||
Dialog,
|
||||
{
|
||||
modal,
|
||||
open: dialogOpen.value,
|
||||
'onUpdate:open': (value: boolean) => (dialogOpen.value = value)
|
||||
},
|
||||
() =>
|
||||
h(DialogPortal, null, () =>
|
||||
h(DialogContent, null, () =>
|
||||
h(
|
||||
SelectRoot,
|
||||
{
|
||||
open: selectOpen.value,
|
||||
'onUpdate:open': (value: boolean) =>
|
||||
(selectOpen.value = value)
|
||||
},
|
||||
() => [
|
||||
h(SelectTrigger, null, () => 'trigger'),
|
||||
h(SelectPortal, null, () =>
|
||||
h(SelectContent, { position: 'popper' }, () =>
|
||||
h(SelectViewport, null, () => 'items')
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
return { ...render(Parent), dialogOpen, selectOpen }
|
||||
}
|
||||
|
||||
describe('modal dialog pointer lock', () => {
|
||||
it('keeps body inert after a nested combobox popover opens and closes', async () => {
|
||||
const { selectOpen } = mountDialogWithSelect(true)
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('none')
|
||||
|
||||
selectOpen.value = true
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('none')
|
||||
|
||||
// Reka restores body pointer events when the popover layer unmounts; the
|
||||
// lock must re-assert it so the canvas behind the dialog stays inert.
|
||||
selectOpen.value = false
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('none')
|
||||
})
|
||||
|
||||
it('restores body pointer events once the dialog closes', async () => {
|
||||
const { dialogOpen } = mountDialogWithSelect(true)
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('none')
|
||||
|
||||
dialogOpen.value = false
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('')
|
||||
})
|
||||
|
||||
it('does not lock body for a non-modal dialog', async () => {
|
||||
mountDialogWithSelect(false)
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('')
|
||||
})
|
||||
|
||||
it('holds the body lock until every open modal dialog closes', async () => {
|
||||
const first = mountDialogWithSelect(true)
|
||||
const second = mountDialogWithSelect(true)
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('none')
|
||||
|
||||
// With one modal still open the shared lock must keep the body inert.
|
||||
first.dialogOpen.value = false
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('none')
|
||||
|
||||
// Once the last modal closes the lock releases and stops forcing inert.
|
||||
second.dialogOpen.value = false
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).not.toBe('none')
|
||||
})
|
||||
})
|
||||
@@ -1,54 +0,0 @@
|
||||
import { onScopeDispose, toValue, watch } from 'vue'
|
||||
import type { MaybeRefOrGetter } from 'vue'
|
||||
|
||||
/**
|
||||
* Keeps the canvas behind a modal dialog inert by holding `document.body`'s
|
||||
* pointer-events lock for as long as at least one modal dialog is open.
|
||||
*
|
||||
* Reka-UI locks body pointer events per modal layer, but a nested dismissable
|
||||
* layer that is portalled to the body — e.g. a `Select` popover inside the
|
||||
* dialog — restores the body's pointer events when it closes, even while the
|
||||
* outer modal dialog is still open. That momentarily re-enables the canvas, so
|
||||
* combobox clicks leak through to it and can select a node or dismiss the
|
||||
* dialog. Reka still performs the initial lock and final restore; the
|
||||
* `MutationObserver` only re-asserts `none` if the lock is cleared while a
|
||||
* modal dialog is still open.
|
||||
*/
|
||||
let openModalCount = 0
|
||||
let observer: MutationObserver | null = null
|
||||
|
||||
function enforceLock() {
|
||||
if (openModalCount > 0 && document.body.style.pointerEvents !== 'none') {
|
||||
document.body.style.pointerEvents = 'none'
|
||||
}
|
||||
}
|
||||
|
||||
function acquire() {
|
||||
openModalCount += 1
|
||||
if (observer) return
|
||||
observer = new MutationObserver(enforceLock)
|
||||
observer.observe(document.body, {
|
||||
attributes: true,
|
||||
attributeFilter: ['style']
|
||||
})
|
||||
}
|
||||
|
||||
function release() {
|
||||
openModalCount = Math.max(0, openModalCount - 1)
|
||||
if (openModalCount > 0) return
|
||||
observer?.disconnect()
|
||||
observer = null
|
||||
document.body.style.pointerEvents = ''
|
||||
}
|
||||
|
||||
export function useModalPointerLock(isOpen: MaybeRefOrGetter<boolean>) {
|
||||
let holding = false
|
||||
const sync = (open: boolean) => {
|
||||
if (open === holding) return
|
||||
holding = open
|
||||
if (open) acquire()
|
||||
else release()
|
||||
}
|
||||
watch(() => toValue(isOpen), sync, { immediate: true })
|
||||
onScopeDispose(() => sync(false))
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { ComputedRef, Ref } from 'vue'
|
||||
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type {
|
||||
BillingStatus,
|
||||
@@ -75,9 +76,10 @@ export interface BillingActions {
|
||||
*/
|
||||
requireActiveSubscription: () => Promise<void>
|
||||
/**
|
||||
* Shows the subscription dialog.
|
||||
* Shows the subscription dialog. Pass a reason so the paywall open and any
|
||||
* downstream checkout stay attributed to the triggering product moment.
|
||||
*/
|
||||
showSubscriptionDialog: () => void
|
||||
showSubscriptionDialog: (options?: SubscriptionDialogOptions) => void
|
||||
}
|
||||
|
||||
export interface BillingState {
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
getTierFeatures
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type {
|
||||
PreviewSubscribeOptions,
|
||||
SubscribeOptions
|
||||
@@ -281,8 +282,8 @@ function useBillingContextInternal(): BillingContext {
|
||||
return activeContext.value.requireActiveSubscription()
|
||||
}
|
||||
|
||||
function showSubscriptionDialog() {
|
||||
return activeContext.value.showSubscriptionDialog()
|
||||
function showSubscriptionDialog(options?: SubscriptionDialogOptions) {
|
||||
return activeContext.value.showSubscriptionDialog(options)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -2,6 +2,7 @@ import { computed, ref } from 'vue'
|
||||
|
||||
import { useAuthActions } from '@/composables/auth/useAuthActions'
|
||||
import { useSubscription } from '@/platform/cloud/subscription/composables/useSubscription'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type {
|
||||
BillingStatus,
|
||||
BillingSubscriptionStatus,
|
||||
@@ -189,12 +190,12 @@ export function useLegacyBilling(): BillingState & BillingActions {
|
||||
async function requireActiveSubscription(): Promise<void> {
|
||||
await fetchStatus()
|
||||
if (!isActiveSubscription.value) {
|
||||
legacyShowSubscriptionDialog()
|
||||
legacyShowSubscriptionDialog({ reason: 'subscription_required' })
|
||||
}
|
||||
}
|
||||
|
||||
function showSubscriptionDialog(): void {
|
||||
legacyShowSubscriptionDialog()
|
||||
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
|
||||
legacyShowSubscriptionDialog(options)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -503,7 +503,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
@@ -526,7 +526,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
@@ -548,7 +548,7 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
}) => {
|
||||
trackRunButton(metadata)
|
||||
if (!isActiveSubscription.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -30,8 +30,7 @@ export enum ServerFeatureFlag {
|
||||
COMFYHUB_PROFILE_GATE_ENABLED = 'comfyhub_profile_gate_enabled',
|
||||
SHOW_SIGNIN_BUTTON = 'show_signin_button',
|
||||
UNIFIED_CLOUD_AUTH = 'unified_cloud_auth',
|
||||
SIGNUP_TURNSTILE = 'signup_turnstile',
|
||||
SERVER_SIDE_MODEL_DOWNLOADS = 'server_side_model_downloads'
|
||||
SIGNUP_TURNSTILE = 'signup_turnstile'
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -182,12 +181,6 @@ export function useFeatureFlags() {
|
||||
remoteConfig.value.signup_turnstile,
|
||||
'off'
|
||||
)
|
||||
},
|
||||
get serverSideModelDownloads() {
|
||||
return api.getServerFeature(
|
||||
ServerFeatureFlag.SERVER_SIDE_MODEL_DOWNLOADS,
|
||||
false
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -1069,74 +1069,6 @@
|
||||
"update": "Update",
|
||||
"description": "Check out the latest improvements and features in this update."
|
||||
},
|
||||
"modelManager": {
|
||||
"title": "Downloads",
|
||||
"active": "Active",
|
||||
"history": "History",
|
||||
"empty": "No downloads yet",
|
||||
"clearHistory": "Clear history",
|
||||
"addModel": "Add model",
|
||||
"addModelDescription": "Download a model directly to your ComfyUI models folder.",
|
||||
"url": "URL",
|
||||
"urlPlaceholder": "https://huggingface.co/.../model.safetensors",
|
||||
"hostNotAllowedHint": "This host may not be on the download allowlist.",
|
||||
"folder": "Model folder",
|
||||
"selectFolder": "Select a folder",
|
||||
"filename": "Filename",
|
||||
"filenamePlaceholder": "model.safetensors",
|
||||
"allowAnyExtension": "Allow any file extension (advanced)",
|
||||
"download": "Download",
|
||||
"downloadQueued": "Download queued",
|
||||
"alreadyInstalled": "Model already installed",
|
||||
"alreadyDownloading": "Model is already downloading",
|
||||
"resume": "Resume",
|
||||
"raisePriority": "Raise priority",
|
||||
"removeFromList": "Remove from list",
|
||||
"addCredentials": "Add credentials",
|
||||
"authErrorHint": "{host} needs an API key. Add one in the Credentials Manager, then resume.",
|
||||
"authErrorHintNoHost": "This host/model needs an API key. Add one in the Credentials Manager, then resume.",
|
||||
"gatedModelHint": "This model is gated. Accept its license on the model's page, then add an API key and resume.",
|
||||
"openModelPage": "Accept license",
|
||||
"actionFailed": "Action failed",
|
||||
"cancelConfirmTitle": "Cancel download?",
|
||||
"cancelConfirmMessage": "This will stop the download and delete the partial file for \"{name}\".",
|
||||
"cancelConfirm": "Cancel download",
|
||||
"status": {
|
||||
"queued": "Queued",
|
||||
"active": "Downloading",
|
||||
"paused": "Paused",
|
||||
"verifying": "Verifying",
|
||||
"completed": "Completed",
|
||||
"failed": "Failed",
|
||||
"cancelled": "Cancelled"
|
||||
},
|
||||
"credentials": {
|
||||
"title": "Credentials Manager",
|
||||
"description": "Store API keys per host (e.g. HuggingFace, Civitai). Secrets are write-only and never shown again.",
|
||||
"add": "Add credential",
|
||||
"update": "Update credential",
|
||||
"edit": "Edit",
|
||||
"save": "Save",
|
||||
"saved": "Credential saved",
|
||||
"host": "Host",
|
||||
"secret": "API key",
|
||||
"secretPlaceholder": "Paste API key",
|
||||
"authScheme": "Auth scheme",
|
||||
"headerName": "Header name",
|
||||
"queryParam": "Query parameter",
|
||||
"label": "Label",
|
||||
"disabled": "Disabled",
|
||||
"matchSubdomains": "Match subdomains",
|
||||
"matchSubdomainsWarning": "Not recommended: hubs redirect to sibling CDN hosts that must not receive your key.",
|
||||
"deleteTitle": "Delete credential?",
|
||||
"deleteMessage": "Delete the stored credential for \"{host}\"?",
|
||||
"scheme": {
|
||||
"bearer": "Bearer token",
|
||||
"header": "Custom header",
|
||||
"query": "Query parameter"
|
||||
}
|
||||
}
|
||||
},
|
||||
"menu": {
|
||||
"hideMenu": "Hide Menu",
|
||||
"showMenu": "Show Menu",
|
||||
|
||||
@@ -25,6 +25,6 @@ function handleClose() {
|
||||
}
|
||||
|
||||
function handleSubscribe() {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'upload_model_upgrade' })
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -140,7 +140,10 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'creator',
|
||||
'monthly',
|
||||
false
|
||||
{
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
}
|
||||
)
|
||||
|
||||
// Shows loading affordances
|
||||
@@ -169,7 +172,10 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(mockPerformSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'creator',
|
||||
'monthly',
|
||||
false
|
||||
{
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
@@ -180,7 +186,8 @@ describe('CloudSubscriptionRedirectView', () => {
|
||||
expect(screen.getByText('Subscribe to Team Plan')).toBeInTheDocument()
|
||||
expect(mockPerformTeamSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
'team_700',
|
||||
'yearly'
|
||||
'yearly',
|
||||
{ paymentIntentSource: 'deep_link' }
|
||||
)
|
||||
// Team never goes through the personal checkout path
|
||||
expect(mockPerformSubscriptionCheckout).not.toHaveBeenCalled()
|
||||
|
||||
@@ -94,7 +94,9 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
|
||||
return
|
||||
}
|
||||
isTeamCheckout.value = true
|
||||
await performTeamSubscriptionCheckout(stopId, billingCycle)
|
||||
await performTeamSubscriptionCheckout(stopId, billingCycle, {
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -112,7 +114,10 @@ const runRedirect = wrapWithErrorHandlingAsync(async () => {
|
||||
if (isActiveSubscription.value) {
|
||||
await accessBillingPortal(undefined, false)
|
||||
} else {
|
||||
await performSubscriptionCheckout(tierKeyParam, billingCycle, false)
|
||||
await performSubscriptionCheckout(tierKeyParam, billingCycle, {
|
||||
openInNewTab: false,
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
}
|
||||
}, reportError)
|
||||
|
||||
|
||||
@@ -351,12 +351,12 @@ const handleRefresh = wrapWithErrorHandlingAsync(async () => {
|
||||
})
|
||||
|
||||
function handleAddCredits() {
|
||||
telemetry?.trackAddApiCreditButtonClicked()
|
||||
telemetry?.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
void dialogService.showTopUpCreditsDialog()
|
||||
}
|
||||
|
||||
function handleUpgradeToAddCredits() {
|
||||
showPricingTable()
|
||||
showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
}
|
||||
|
||||
async function handleWindowFocus() {
|
||||
|
||||
@@ -5,6 +5,8 @@ import { render, screen } from '@testing-library/vue'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
|
||||
import FreeTierDialogContent from './FreeTierDialogContent.vue'
|
||||
|
||||
const mockRenewalDate = vi.hoisted(() => ({ value: null as string | null }))
|
||||
@@ -15,7 +17,7 @@ vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
}))
|
||||
}))
|
||||
|
||||
function renderComponent() {
|
||||
function renderComponent(props?: { reason?: PaymentIntentSource }) {
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
@@ -23,6 +25,7 @@ function renderComponent() {
|
||||
})
|
||||
|
||||
return render(FreeTierDialogContent, {
|
||||
props,
|
||||
global: {
|
||||
plugins: [i18n]
|
||||
}
|
||||
@@ -43,4 +46,18 @@ describe('FreeTierDialogContent', () => {
|
||||
renderComponent()
|
||||
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('keeps the generic copy for intent reasons outside the credits variants', () => {
|
||||
mockRenewalDate.value = '2026-07-15T10:00:00Z'
|
||||
renderComponent({ reason: 'subscribe_to_run' })
|
||||
expect(
|
||||
screen.getByText('Your credits refresh on Jul 15, 2026.')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('swaps to the out-of-credits copy without the refresh line', () => {
|
||||
mockRenewalDate.value = '2026-07-15T10:00:00Z'
|
||||
renderComponent({ reason: 'out_of_credits' })
|
||||
expect(screen.queryByText(/credits refresh on/)).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
@@ -52,7 +52,7 @@
|
||||
</p>
|
||||
|
||||
<p
|
||||
v-if="!reason || reason === 'subscription_required'"
|
||||
v-if="!isCreditsBlockedVariant"
|
||||
class="m-0 text-sm text-text-secondary"
|
||||
>
|
||||
{{
|
||||
@@ -65,10 +65,7 @@
|
||||
</p>
|
||||
|
||||
<p
|
||||
v-if="
|
||||
(!reason || reason === 'subscription_required') &&
|
||||
formattedRenewalDate
|
||||
"
|
||||
v-if="!isCreditsBlockedVariant && formattedRenewalDate"
|
||||
class="m-0 text-sm text-text-secondary"
|
||||
>
|
||||
{{
|
||||
@@ -88,7 +85,7 @@
|
||||
@click="$emit('upgrade')"
|
||||
>
|
||||
{{
|
||||
reason === 'out_of_credits' || reason === 'top_up_blocked'
|
||||
isCreditsBlockedVariant
|
||||
? $t('subscription.freeTier.upgradeCta')
|
||||
: $t('subscription.freeTier.subscribeCta')
|
||||
}}
|
||||
@@ -103,12 +100,12 @@ import { computed } from 'vue'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import SubscriptionBenefits from '@/platform/cloud/subscription/components/SubscriptionBenefits.vue'
|
||||
import { getTierCredits } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
|
||||
defineProps<{
|
||||
reason?: SubscriptionDialogReason
|
||||
const { reason } = defineProps<{
|
||||
reason?: PaymentIntentSource
|
||||
}>()
|
||||
|
||||
defineEmits<{
|
||||
@@ -129,4 +126,10 @@ const formattedRenewalDate = computed(() => {
|
||||
})
|
||||
|
||||
const freeTierCredits = computed(() => getTierCredits('free'))
|
||||
|
||||
// Only these two variants replace the generic free-tier copy; any other
|
||||
// intent reason (subscribe_to_run, deep_link, ...) keeps the default pitch.
|
||||
const isCreditsBlockedVariant = computed(
|
||||
() => reason === 'out_of_credits' || reason === 'top_up_blocked'
|
||||
)
|
||||
</script>
|
||||
|
||||
@@ -261,6 +261,7 @@ describe('PricingTable', () => {
|
||||
tier: 'creator',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'change',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
previous_tier: 'standard'
|
||||
})
|
||||
expect(mockAccessBillingPortal).toHaveBeenCalledWith('creator-yearly')
|
||||
@@ -341,6 +342,7 @@ describe('PricingTable', () => {
|
||||
expect(
|
||||
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
|
||||
).toBeNull()
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should use the latest userId value when it changes after mount', async () => {
|
||||
@@ -366,6 +368,7 @@ describe('PricingTable', () => {
|
||||
tier: 'creator',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'change',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
previous_tier: 'standard'
|
||||
})
|
||||
})
|
||||
|
||||
@@ -277,13 +277,19 @@ import type {
|
||||
TierKey,
|
||||
TierPricing
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import {
|
||||
recordPendingSubscriptionCheckoutAttempt,
|
||||
withPendingCheckoutAttemptId
|
||||
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import { performSubscriptionCheckout } from '@/platform/cloud/subscription/utils/subscriptionCheckoutUtil'
|
||||
import { isPlanDowngrade } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import type {
|
||||
CheckoutAttributionMetadata,
|
||||
PaymentIntentSource
|
||||
} from '@/platform/telemetry/types'
|
||||
import { useAuthStore } from '@/stores/authStore'
|
||||
|
||||
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
|
||||
@@ -321,6 +327,10 @@ interface PricingTierConfig {
|
||||
isPopular?: boolean
|
||||
}
|
||||
|
||||
const { reason } = defineProps<{
|
||||
reason?: PaymentIntentSource
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
chooseTeamWorkspace: []
|
||||
}>()
|
||||
@@ -463,16 +473,17 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
|
||||
} as const
|
||||
const previousPlan = currentPlanDescriptor.value
|
||||
const checkoutAttribution = await getCheckoutAttributionForCloud()
|
||||
if (userId.value) {
|
||||
telemetry?.trackBeginCheckout({
|
||||
user_id: userId.value,
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change',
|
||||
...checkoutAttribution,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
|
||||
})
|
||||
}
|
||||
const beginCheckoutMetadata = userId.value
|
||||
? {
|
||||
user_id: userId.value,
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change' as const,
|
||||
...(reason ? { payment_intent_source: reason } : {}),
|
||||
...checkoutAttribution,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {})
|
||||
}
|
||||
: null
|
||||
// Pass the target tier to create a deep link to subscription update confirmation
|
||||
const checkoutTier = getCheckoutTier(
|
||||
targetPlan.tierKey,
|
||||
@@ -487,29 +498,39 @@ const handleSubscribe = wrapWithErrorHandlingAsync(
|
||||
|
||||
if (downgrade) {
|
||||
// TODO(COMFY-StripeProration): Remove once backend checkout creation mirrors portal proration ("change at billing end")
|
||||
await accessBillingPortal()
|
||||
const didOpenPortal = await accessBillingPortal()
|
||||
if (didOpenPortal && beginCheckoutMetadata) {
|
||||
telemetry?.trackBeginCheckout(beginCheckoutMetadata)
|
||||
}
|
||||
} else {
|
||||
const didOpenPortal = await accessBillingPortal(checkoutTier)
|
||||
if (!didOpenPortal) {
|
||||
return
|
||||
}
|
||||
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
const pendingAttempt = recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: targetPlan.tierKey,
|
||||
cycle: targetPlan.billingCycle,
|
||||
checkout_type: 'change',
|
||||
payment_intent_source: reason,
|
||||
...(previousPlan ? { previous_tier: previousPlan.tierKey } : {}),
|
||||
...(previousPlan
|
||||
? { previous_cycle: previousPlan.billingCycle }
|
||||
: {})
|
||||
})
|
||||
if (beginCheckoutMetadata) {
|
||||
telemetry?.trackBeginCheckout(
|
||||
withPendingCheckoutAttemptId(
|
||||
beginCheckoutMetadata,
|
||||
pendingAttempt
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
await performSubscriptionCheckout(
|
||||
tierKey,
|
||||
currentBillingCycle.value,
|
||||
true
|
||||
)
|
||||
await performSubscriptionCheckout(tierKey, currentBillingCycle.value, {
|
||||
paymentIntentSource: reason
|
||||
})
|
||||
}
|
||||
} finally {
|
||||
isLoading.value = false
|
||||
|
||||
@@ -56,7 +56,7 @@ const handleSubscribe = () => {
|
||||
current_tier: tier.value?.toLowerCase()
|
||||
})
|
||||
isAwaitingStripeSubscription.value = true
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_now_button' })
|
||||
}
|
||||
|
||||
onBeforeUnmount(() => {
|
||||
|
||||
@@ -54,6 +54,6 @@ function handleSubscribeToRun() {
|
||||
trackRunButton({ subscribe_to_run: true })
|
||||
}
|
||||
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscribe_to_run' })
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -48,7 +48,9 @@
|
||||
v-if="isActiveSubscription"
|
||||
variant="primary"
|
||||
class="rounded-lg px-4 py-2 text-sm font-normal text-text-primary"
|
||||
@click="showSubscriptionDialog"
|
||||
@click="
|
||||
showSubscriptionDialog({ reason: 'settings_billing_panel' })
|
||||
"
|
||||
>
|
||||
{{ $t('subscription.upgradePlan') }}
|
||||
</Button>
|
||||
|
||||
@@ -33,7 +33,11 @@
|
||||
</i18n-t>
|
||||
</div>
|
||||
|
||||
<PricingTable class="flex-1" @choose-team-workspace="handleChooseTeam" />
|
||||
<PricingTable
|
||||
:reason
|
||||
class="flex-1"
|
||||
@choose-team-workspace="handleChooseTeam"
|
||||
/>
|
||||
|
||||
<!-- Contact and Enterprise Links -->
|
||||
<div class="flex flex-col items-center gap-2">
|
||||
@@ -157,11 +161,11 @@ import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import { useCommandStore } from '@/stores/commandStore'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
|
||||
const { onClose, reason, onChooseTeam } = defineProps<{
|
||||
onClose: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
onChooseTeam?: () => void
|
||||
}>()
|
||||
|
||||
|
||||
@@ -24,7 +24,9 @@ export function useAccountPreconditionDialog() {
|
||||
)
|
||||
return
|
||||
case 'subscription':
|
||||
void dialogService.showSubscriptionRequiredDialog()
|
||||
void dialogService.showSubscriptionRequiredDialog({
|
||||
reason: 'subscription_required'
|
||||
})
|
||||
return
|
||||
case 'credits':
|
||||
void dialogService.showTopUpCreditsDialog({
|
||||
|
||||
@@ -55,12 +55,6 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
|
||||
})
|
||||
}))
|
||||
|
||||
const mockTrackSubscription = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
|
||||
}))
|
||||
|
||||
describe('usePricingTableUrlLoader', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -96,9 +90,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
reason: 'deep_link',
|
||||
planMode: undefined
|
||||
})
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
|
||||
reason: 'deep_link'
|
||||
})
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
|
||||
})
|
||||
|
||||
@@ -150,7 +141,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
await loadPricingTableFromUrl()
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('denies, strips, and clears together when the user is not eligible', async () => {
|
||||
@@ -161,7 +151,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
await loadPricingTableFromUrl()
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({
|
||||
query: { other: 'param' }
|
||||
})
|
||||
@@ -230,7 +219,6 @@ describe('usePricingTableUrlLoader', () => {
|
||||
)
|
||||
|
||||
expect(mockShowPricingTable).not.toHaveBeenCalled()
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
expect(mockRouterReplace).toHaveBeenCalledWith({ query: {} })
|
||||
expect(preservedQueryMocks.clearPreservedQuery).toHaveBeenCalledWith(
|
||||
'pricing'
|
||||
|
||||
@@ -7,7 +7,6 @@ import {
|
||||
mergePreservedQueryIntoQuery
|
||||
} from '@/platform/navigation/preservedQueryManager'
|
||||
import { PRESERVED_QUERY_NAMESPACES } from '@/platform/navigation/preservedQueryNamespaces'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
|
||||
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
|
||||
|
||||
@@ -62,7 +61,6 @@ export function usePricingTableUrlLoader() {
|
||||
const planMode =
|
||||
param === 'team' || param === 'personal' ? param : undefined
|
||||
|
||||
useTelemetry()?.trackSubscription('modal_opened', { reason: 'deep_link' })
|
||||
subscriptionDialog.showPricingTable({ reason: 'deep_link', planMode })
|
||||
}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ import { t } from '@/i18n'
|
||||
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
|
||||
import { useDialogService } from '@/services/dialogService'
|
||||
@@ -237,14 +237,7 @@ function useSubscriptionInternal() {
|
||||
})
|
||||
}, reportError)
|
||||
|
||||
const showSubscriptionDialog = (options?: {
|
||||
reason?: SubscriptionDialogReason
|
||||
}) => {
|
||||
useTelemetry()?.trackSubscription('modal_opened', {
|
||||
current_tier: subscriptionTier.value?.toLowerCase(),
|
||||
reason: options?.reason
|
||||
})
|
||||
|
||||
const showSubscriptionDialog = (options?: SubscriptionDialogOptions) => {
|
||||
void showSubscriptionRequiredDialog(options)
|
||||
}
|
||||
|
||||
@@ -277,7 +270,7 @@ function useSubscriptionInternal() {
|
||||
await fetchSubscriptionStatus()
|
||||
|
||||
if (!isSubscribedOrIsNotCloud.value) {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'subscription_required' })
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -39,15 +39,23 @@ vi.mock('@/stores/commandStore', () => ({
|
||||
}))
|
||||
|
||||
// useTelemetry() returns null in OSS, a dispatcher in cloud — toggle via mockIsCloud.
|
||||
const { mockIsCloud, mockTrackHelpResourceClicked } = vi.hoisted(() => ({
|
||||
const {
|
||||
mockIsCloud,
|
||||
mockTrackHelpResourceClicked,
|
||||
mockTrackAddApiCreditButtonClicked
|
||||
} = vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockTrackHelpResourceClicked: vi.fn()
|
||||
mockTrackHelpResourceClicked: vi.fn(),
|
||||
mockTrackAddApiCreditButtonClicked: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () =>
|
||||
mockIsCloud.value
|
||||
? { trackHelpResourceClicked: mockTrackHelpResourceClicked }
|
||||
? {
|
||||
trackHelpResourceClicked: mockTrackHelpResourceClicked,
|
||||
trackAddApiCreditButtonClicked: mockTrackAddApiCreditButtonClicked
|
||||
}
|
||||
: null
|
||||
}))
|
||||
|
||||
@@ -69,6 +77,9 @@ describe('useSubscriptionActions', () => {
|
||||
const { handleAddApiCredits } = useSubscriptionActions()
|
||||
handleAddApiCredits()
|
||||
expect(mockShowTopUpCreditsDialog).toHaveBeenCalledOnce()
|
||||
expect(mockTrackAddApiCreditButtonClicked).toHaveBeenCalledWith({
|
||||
source: 'settings_billing_panel'
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -21,6 +21,9 @@ export function useSubscriptionActions() {
|
||||
})
|
||||
|
||||
const handleAddApiCredits = () => {
|
||||
telemetry?.trackAddApiCreditButtonClicked({
|
||||
source: 'settings_billing_panel'
|
||||
})
|
||||
void dialogService.showTopUpCreditsDialog()
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@ import { useSubscriptionDialog } from './useSubscriptionDialog'
|
||||
const mockCloseDialog = vi.fn()
|
||||
const mockShowLayoutDialog = vi.fn()
|
||||
const mockShowTeamWorkspacesDialog = vi.fn()
|
||||
const mockTrackSubscription = vi.hoisted(() => vi.fn())
|
||||
const mockIsInPersonalWorkspace = vi.hoisted(() => ({ value: true }))
|
||||
const mockIsFreeTier = vi.hoisted(() => ({ value: false }))
|
||||
const mockTier = vi.hoisted(() => ({ value: 'FREE' as string | null }))
|
||||
const mockTeamWorkspacesEnabled = vi.hoisted(() => ({ value: false }))
|
||||
const mockIsCloud = vi.hoisted(() => ({ value: true }))
|
||||
const mockIsLegacyTeamPlan = vi.hoisted(() => ({ value: false }))
|
||||
@@ -60,10 +62,15 @@ vi.mock('@/platform/workspace/stores/teamWorkspaceStore', () => ({
|
||||
vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
useBillingContext: () => ({
|
||||
isFreeTier: mockIsFreeTier,
|
||||
isLegacyTeamPlan: mockIsLegacyTeamPlan
|
||||
isLegacyTeamPlan: mockIsLegacyTeamPlan,
|
||||
tier: mockTier
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackSubscription: mockTrackSubscription })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/workspace/composables/useWorkspaceUI', () => ({
|
||||
useWorkspaceUI: () => ({
|
||||
permissions: {
|
||||
@@ -80,6 +87,7 @@ describe('useSubscriptionDialog', () => {
|
||||
mockIsCloud.value = true
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
mockIsFreeTier.value = false
|
||||
mockTier.value = 'FREE'
|
||||
mockTeamWorkspacesEnabled.value = false
|
||||
mockIsLegacyTeamPlan.value = false
|
||||
mockCanManageSubscription.value = true
|
||||
@@ -198,6 +206,51 @@ describe('useSubscriptionDialog', () => {
|
||||
const props = mockShowLayoutDialog.mock.calls[0][0].props
|
||||
expect(props.initialPlanMode).toBe('team')
|
||||
})
|
||||
|
||||
it('tracks modal_opened with the caller reason and current tier', () => {
|
||||
mockTier.value = 'STANDARD'
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith('modal_opened', {
|
||||
current_tier: 'standard',
|
||||
reason: 'upgrade_to_add_credits'
|
||||
})
|
||||
})
|
||||
|
||||
it('tracks modal_opened on the workspace (unified) path too', () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith(
|
||||
'modal_opened',
|
||||
expect.objectContaining({ reason: 'subscribe_to_run' })
|
||||
)
|
||||
})
|
||||
|
||||
it('does not track modal_opened for the inactive member dialog', () => {
|
||||
mockTeamWorkspacesEnabled.value = true
|
||||
mockIsInPersonalWorkspace.value = false
|
||||
mockCanManageSubscription.value = false
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockShowLayoutDialog).toHaveBeenCalledTimes(1)
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not track on non-cloud', () => {
|
||||
mockIsCloud.value = false
|
||||
const { showPricingTable } = useSubscriptionDialog()
|
||||
|
||||
showPricingTable({ reason: 'subscribe_to_run' })
|
||||
|
||||
expect(mockTrackSubscription).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('show', () => {
|
||||
@@ -235,6 +288,20 @@ describe('useSubscriptionDialog', () => {
|
||||
expect.objectContaining({ key: 'subscription-required' })
|
||||
)
|
||||
})
|
||||
|
||||
it('tracks modal_opened with the reason for the free-tier dialog', () => {
|
||||
mockIsFreeTier.value = true
|
||||
mockIsInPersonalWorkspace.value = true
|
||||
const { show } = useSubscriptionDialog()
|
||||
|
||||
show({ reason: 'out_of_credits' })
|
||||
|
||||
expect(mockTrackSubscription).toHaveBeenCalledTimes(1)
|
||||
expect(mockTrackSubscription).toHaveBeenCalledWith(
|
||||
'modal_opened',
|
||||
expect.objectContaining({ reason: 'out_of_credits' })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('startTeamWorkspaceUpgradeFlow', () => {
|
||||
|
||||
@@ -4,6 +4,8 @@ import { useDialogStore } from '@/stores/dialogStore'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { useWorkspaceUI } from '@/platform/workspace/composables/useWorkspaceUI'
|
||||
import { useTeamWorkspaceStore } from '@/platform/workspace/stores/teamWorkspaceStore'
|
||||
|
||||
@@ -11,14 +13,8 @@ const DIALOG_KEY = 'subscription-required'
|
||||
const FREE_TIER_DIALOG_KEY = 'free-tier-info'
|
||||
const RESUME_PRICING_KEY = 'comfy:resume-team-pricing'
|
||||
|
||||
export type SubscriptionDialogReason =
|
||||
| 'subscription_required'
|
||||
| 'out_of_credits'
|
||||
| 'top_up_blocked'
|
||||
| 'deep_link'
|
||||
|
||||
interface SubscriptionDialogOptions {
|
||||
reason?: SubscriptionDialogReason
|
||||
export interface SubscriptionDialogOptions {
|
||||
reason?: PaymentIntentSource
|
||||
/**
|
||||
* Forces the unified pricing dialog to open on a specific plan tab,
|
||||
* overriding the workspace-derived default (e.g. an "Upgrade to Team" CTA
|
||||
@@ -38,6 +34,17 @@ export const useSubscriptionDialog = () => {
|
||||
dialogStore.closeDialog({ key: FREE_TIER_DIALOG_KEY })
|
||||
}
|
||||
|
||||
// Fired here — the choke point every paywall/pricing dialog variant passes
|
||||
// through — so both the legacy and workspace billing paths emit it.
|
||||
function trackModalOpened(reason?: PaymentIntentSource) {
|
||||
// Resolved lazily to avoid the useBillingContext import cycle (see below).
|
||||
const { tier } = useBillingContext()
|
||||
useTelemetry()?.trackSubscription('modal_opened', {
|
||||
current_tier: tier.value?.toLowerCase(),
|
||||
reason
|
||||
})
|
||||
}
|
||||
|
||||
function showPricingTable(options?: SubscriptionDialogOptions) {
|
||||
if (!isCloud) return
|
||||
|
||||
@@ -71,6 +78,8 @@ export const useSubscriptionDialog = () => {
|
||||
return
|
||||
}
|
||||
|
||||
trackModalOpened(options?.reason)
|
||||
|
||||
// Shared dialog shell styling for both variants.
|
||||
const dialogComponentProps = {
|
||||
style: 'width: min(1328px, 95vw); max-height: 958px;',
|
||||
@@ -167,6 +176,8 @@ export const useSubscriptionDialog = () => {
|
||||
// (not at composable setup) to avoid the useBillingContext import cycle.
|
||||
const { isFreeTier } = useBillingContext()
|
||||
if (isFreeTier.value && workspaceStore.isInPersonalWorkspace) {
|
||||
trackModalOpened(options?.reason)
|
||||
|
||||
const component = defineAsyncComponent(
|
||||
() =>
|
||||
import('@/platform/cloud/subscription/components/FreeTierDialogContent.vue')
|
||||
@@ -236,7 +247,7 @@ export const useSubscriptionDialog = () => {
|
||||
sessionStorage.removeItem(RESUME_PRICING_KEY)
|
||||
|
||||
if (!workspaceStore.isInPersonalWorkspace) {
|
||||
showPricingTable()
|
||||
showPricingTable({ reason: 'team_upgrade_resume' })
|
||||
}
|
||||
} catch {
|
||||
// sessionStorage may be unavailable
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import { beforeEach, describe, expect, it } from 'vitest'
|
||||
|
||||
import {
|
||||
clearPendingSubscriptionCheckoutAttempt,
|
||||
consumePendingSubscriptionCheckoutSuccess,
|
||||
recordPendingSubscriptionCheckoutAttempt
|
||||
} from './subscriptionCheckoutTracker'
|
||||
|
||||
const activeProStatus = {
|
||||
is_active: true,
|
||||
subscription_tier: 'PRO',
|
||||
subscription_duration: 'MONTHLY'
|
||||
} as const
|
||||
|
||||
describe('subscriptionCheckoutTracker', () => {
|
||||
beforeEach(() => {
|
||||
clearPendingSubscriptionCheckoutAttempt()
|
||||
})
|
||||
|
||||
it('round-trips payment_intent_source from attempt to success metadata', () => {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
|
||||
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
|
||||
|
||||
expect(metadata).toMatchObject({
|
||||
tier: 'pro',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
})
|
||||
|
||||
it('omits payment_intent_source when the attempt had none', () => {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new'
|
||||
})
|
||||
|
||||
const metadata = consumePendingSubscriptionCheckoutSuccess(activeProStatus)
|
||||
|
||||
expect(metadata).not.toBeNull()
|
||||
expect(metadata).not.toHaveProperty('payment_intent_source')
|
||||
})
|
||||
})
|
||||
@@ -7,7 +7,12 @@ import type {
|
||||
TierKey
|
||||
} from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { SubscriptionSuccessMetadata } from '@/platform/telemetry/types'
|
||||
import type {
|
||||
BeginCheckoutMetadata,
|
||||
PaymentIntentSource,
|
||||
SubscriptionCheckoutType,
|
||||
SubscriptionSuccessMetadata
|
||||
} from '@/platform/telemetry/types'
|
||||
|
||||
const PENDING_SUBSCRIPTION_CHECKOUT_MAX_AGE_MS = 6 * 60 * 60 * 1000
|
||||
const VALID_TIER_KEYS = new Set<TierKey>([
|
||||
@@ -23,7 +28,6 @@ export const PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY =
|
||||
export const PENDING_SUBSCRIPTION_CHECKOUT_EVENT =
|
||||
'comfy:subscription-checkout-attempt-changed'
|
||||
|
||||
type CheckoutType = 'new' | 'change'
|
||||
type SubscriptionDuration = 'MONTHLY' | 'ANNUAL'
|
||||
|
||||
interface SubscriptionStatusSnapshot {
|
||||
@@ -32,22 +36,24 @@ interface SubscriptionStatusSnapshot {
|
||||
subscription_duration?: SubscriptionDuration | null
|
||||
}
|
||||
|
||||
interface PendingSubscriptionCheckoutAttempt {
|
||||
export interface PendingSubscriptionCheckoutAttempt {
|
||||
attempt_id: string
|
||||
started_at_ms: number
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: CheckoutType
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
previous_cycle?: BillingCycle
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
interface RecordPendingSubscriptionCheckoutAttemptInput {
|
||||
interface PendingSubscriptionCheckoutAttemptInput {
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: CheckoutType
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
previous_cycle?: BillingCycle
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
const dispatchPendingCheckoutChangeEvent = () => {
|
||||
@@ -168,6 +174,9 @@ const normalizeAttempt = (
|
||||
...(candidate.previous_cycle === 'monthly' ||
|
||||
candidate.previous_cycle === 'yearly'
|
||||
? { previous_cycle: candidate.previous_cycle }
|
||||
: {}),
|
||||
...(typeof candidate.payment_intent_source === 'string'
|
||||
? { payment_intent_source: candidate.payment_intent_source }
|
||||
: {})
|
||||
}
|
||||
}
|
||||
@@ -224,20 +233,27 @@ const getPendingSubscriptionCheckoutAttempt =
|
||||
export const hasPendingSubscriptionCheckoutAttempt = (): boolean =>
|
||||
getPendingSubscriptionCheckoutAttempt() !== null
|
||||
|
||||
export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
input: RecordPendingSubscriptionCheckoutAttemptInput
|
||||
export const createPendingSubscriptionCheckoutAttempt = (
|
||||
input: PendingSubscriptionCheckoutAttemptInput
|
||||
): PendingSubscriptionCheckoutAttempt => {
|
||||
const storage = getStorage()
|
||||
const attempt: PendingSubscriptionCheckoutAttempt = {
|
||||
return {
|
||||
attempt_id: createAttemptId(),
|
||||
started_at_ms: Date.now(),
|
||||
tier: input.tier,
|
||||
cycle: input.cycle,
|
||||
checkout_type: input.checkout_type,
|
||||
...(input.previous_tier ? { previous_tier: input.previous_tier } : {}),
|
||||
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {})
|
||||
...(input.previous_cycle ? { previous_cycle: input.previous_cycle } : {}),
|
||||
...(input.payment_intent_source
|
||||
? { payment_intent_source: input.payment_intent_source }
|
||||
: {})
|
||||
}
|
||||
}
|
||||
|
||||
export const persistPendingSubscriptionCheckoutAttempt = (
|
||||
attempt: PendingSubscriptionCheckoutAttempt
|
||||
): PendingSubscriptionCheckoutAttempt => {
|
||||
const storage = getStorage()
|
||||
if (!storage) {
|
||||
return attempt
|
||||
}
|
||||
@@ -255,6 +271,21 @@ export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
return attempt
|
||||
}
|
||||
|
||||
export const recordPendingSubscriptionCheckoutAttempt = (
|
||||
input: PendingSubscriptionCheckoutAttemptInput
|
||||
): PendingSubscriptionCheckoutAttempt =>
|
||||
persistPendingSubscriptionCheckoutAttempt(
|
||||
createPendingSubscriptionCheckoutAttempt(input)
|
||||
)
|
||||
|
||||
export const withPendingCheckoutAttemptId = (
|
||||
metadata: BeginCheckoutMetadata,
|
||||
attempt: PendingSubscriptionCheckoutAttempt
|
||||
): BeginCheckoutMetadata => ({
|
||||
...metadata,
|
||||
checkout_attempt_id: attempt.attempt_id
|
||||
})
|
||||
|
||||
const didAttemptSucceed = (
|
||||
attempt: PendingSubscriptionCheckoutAttempt,
|
||||
status: SubscriptionStatusSnapshot
|
||||
@@ -287,6 +318,9 @@ export const consumePendingSubscriptionCheckoutSuccess = (
|
||||
cycle: attempt.cycle,
|
||||
checkout_type: attempt.checkout_type,
|
||||
...(attempt.previous_tier ? { previous_tier: attempt.previous_tier } : {}),
|
||||
...(attempt.payment_intent_source
|
||||
? { payment_intent_source: attempt.payment_intent_source }
|
||||
: {}),
|
||||
value,
|
||||
currency: 'USD',
|
||||
ecommerce: {
|
||||
|
||||
@@ -132,13 +132,14 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'yearly', true)
|
||||
await performSubscriptionCheckout('pro', 'yearly')
|
||||
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith({
|
||||
user_id: 'user-123',
|
||||
tier: 'pro',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String),
|
||||
ga_client_id: 'ga-client-id',
|
||||
ga_session_id: 'ga-session-id',
|
||||
ga_session_number: 'ga-session-number',
|
||||
@@ -150,6 +151,12 @@ describe('performSubscriptionCheckout', () => {
|
||||
gbraid: 'gbraid-456',
|
||||
wbraid: 'wbraid-789'
|
||||
})
|
||||
const beginCheckoutMetadata =
|
||||
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
|
||||
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
|
||||
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
|
||||
JSON.parse(storedAttempt).attempt_id
|
||||
)
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'/customers/cloud-subscription-checkout/pro-yearly'
|
||||
@@ -186,7 +193,7 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', true)
|
||||
await performSubscriptionCheckout('pro', 'monthly')
|
||||
|
||||
expect(warnSpy).toHaveBeenCalledWith(
|
||||
'[SubscriptionCheckout] Failed to collect checkout attribution',
|
||||
@@ -203,11 +210,43 @@ describe('performSubscriptionCheckout', () => {
|
||||
user_id: 'user-123',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new'
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
})
|
||||
|
||||
it('carries the payment intent source into begin_checkout and the pending attempt', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi
|
||||
.spyOn(window, 'open')
|
||||
.mockImplementation(() => window as unknown as Window)
|
||||
|
||||
vi.mocked(global.fetch).mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', {
|
||||
paymentIntentSource: 'out_of_credits'
|
||||
})
|
||||
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ payment_intent_source: 'out_of_credits' })
|
||||
)
|
||||
const beginCheckoutMetadata =
|
||||
mockTelemetry.trackBeginCheckout.mock.calls[0][0]
|
||||
const [, storedAttempt] = mockLocalStorage.setItem.mock.calls[0]
|
||||
const pendingAttempt = JSON.parse(storedAttempt)
|
||||
expect(pendingAttempt).toMatchObject({
|
||||
payment_intent_source: 'out_of_credits'
|
||||
})
|
||||
expect(beginCheckoutMetadata.checkout_attempt_id).toBe(
|
||||
pendingAttempt.attempt_id
|
||||
)
|
||||
openSpy.mockRestore()
|
||||
})
|
||||
|
||||
it('uses the latest userId when it changes after checkout starts', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi
|
||||
@@ -222,7 +261,7 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly', true)
|
||||
const checkoutPromise = performSubscriptionCheckout('pro', 'yearly')
|
||||
|
||||
mockUserId.value = 'user-late'
|
||||
authHeader.resolve({ Authorization: 'Bearer test-token' })
|
||||
@@ -235,13 +274,14 @@ describe('performSubscriptionCheckout', () => {
|
||||
user_id: 'user-late',
|
||||
tier: 'pro',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new'
|
||||
checkout_type: 'new',
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
)
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
})
|
||||
|
||||
it('does not persist a pending attempt when the checkout popup is blocked', async () => {
|
||||
it('does not persist the pending attempt when the checkout popup is blocked', async () => {
|
||||
const checkoutUrl = 'https://checkout.stripe.com/test'
|
||||
const openSpy = vi.spyOn(window, 'open').mockImplementation(() => null)
|
||||
|
||||
@@ -250,11 +290,18 @@ describe('performSubscriptionCheckout', () => {
|
||||
json: async () => ({ checkout_url: checkoutUrl })
|
||||
} as Response)
|
||||
|
||||
await performSubscriptionCheckout('pro', 'monthly', true)
|
||||
await performSubscriptionCheckout('pro', 'monthly')
|
||||
|
||||
expect(openSpy).toHaveBeenCalledWith(checkoutUrl, '_blank')
|
||||
expect(
|
||||
window.localStorage.getItem(PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY)
|
||||
).toBeNull()
|
||||
const storedAttempt = window.localStorage.getItem(
|
||||
PENDING_SUBSCRIPTION_CHECKOUT_STORAGE_KEY
|
||||
)
|
||||
expect(storedAttempt).toBeNull()
|
||||
expect(mockLocalStorage.setItem).not.toHaveBeenCalled()
|
||||
expect(mockTelemetry.trackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
checkout_attempt_id: expect.any(String)
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -4,12 +4,19 @@ import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { getComfyApiBaseUrl } from '@/config/comfyApi'
|
||||
import { t } from '@/i18n'
|
||||
import { fetchWithUnifiedRemint } from '@/platform/auth/unified/remintRetry'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import {
|
||||
createPendingSubscriptionCheckoutAttempt,
|
||||
persistPendingSubscriptionCheckoutAttempt,
|
||||
withPendingCheckoutAttemptId
|
||||
} from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type {
|
||||
CheckoutAttributionMetadata,
|
||||
PaymentIntentSource
|
||||
} from '@/platform/telemetry/types'
|
||||
import { AuthStoreError, useAuthStore } from '@/stores/authStore'
|
||||
import type { CheckoutAttributionMetadata } from '@/platform/telemetry/types'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import { recordPendingSubscriptionCheckoutAttempt } from '@/platform/cloud/subscription/utils/subscriptionCheckoutTracker'
|
||||
import type { BillingCycle } from './subscriptionTierRank'
|
||||
|
||||
type CheckoutTier = TierKey | `${TierKey}-yearly`
|
||||
@@ -31,6 +38,11 @@ const getCheckoutAttributionForCloud =
|
||||
return getCheckoutAttribution()
|
||||
}
|
||||
|
||||
interface PerformSubscriptionCheckoutOptions {
|
||||
openInNewTab?: boolean
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
}
|
||||
|
||||
/**
|
||||
* Core subscription checkout logic shared between PricingTable and
|
||||
* SubscriptionRedirectView. Handles:
|
||||
@@ -47,10 +59,12 @@ const getCheckoutAttributionForCloud =
|
||||
export async function performSubscriptionCheckout(
|
||||
tierKey: TierKey,
|
||||
currentBillingCycle: BillingCycle,
|
||||
openInNewTab: boolean = true
|
||||
options: PerformSubscriptionCheckoutOptions = {}
|
||||
): Promise<void> {
|
||||
if (!isCloud) return
|
||||
|
||||
const { openInNewTab = true, paymentIntentSource } = options
|
||||
|
||||
const authStore = useAuthStore()
|
||||
const { userId } = storeToRefs(authStore)
|
||||
const telemetry = useTelemetry()
|
||||
@@ -108,14 +122,29 @@ export async function performSubscriptionCheckout(
|
||||
const data = await response.json()
|
||||
|
||||
if (data.checkout_url) {
|
||||
const pendingAttempt = createPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: paymentIntentSource
|
||||
})
|
||||
|
||||
if (userId.value) {
|
||||
telemetry?.trackBeginCheckout({
|
||||
user_id: userId.value,
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
...checkoutAttribution
|
||||
})
|
||||
telemetry?.trackBeginCheckout(
|
||||
withPendingCheckoutAttemptId(
|
||||
{
|
||||
user_id: userId.value,
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new',
|
||||
...(paymentIntentSource
|
||||
? { payment_intent_source: paymentIntentSource }
|
||||
: {}),
|
||||
...checkoutAttribution
|
||||
},
|
||||
pendingAttempt
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
if (openInNewTab) {
|
||||
@@ -123,18 +152,9 @@ export async function performSubscriptionCheckout(
|
||||
if (!checkoutWindow) {
|
||||
return
|
||||
}
|
||||
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new'
|
||||
})
|
||||
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
|
||||
} else {
|
||||
recordPendingSubscriptionCheckoutAttempt({
|
||||
tier: tierKey,
|
||||
cycle: currentBillingCycle,
|
||||
checkout_type: 'new'
|
||||
})
|
||||
persistPendingSubscriptionCheckoutAttempt(pendingAttempt)
|
||||
globalThis.location.href = data.checkout_url
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { computed, reactive } from 'vue'
|
||||
|
||||
const { mockIsCloud, mockSubscribe } = vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockSubscribe: vi.fn()
|
||||
}))
|
||||
const { mockIsCloud, mockSubscribe, mockTrackBeginCheckout, mockUserId } =
|
||||
vi.hoisted(() => ({
|
||||
mockIsCloud: { value: true },
|
||||
mockSubscribe: vi.fn(),
|
||||
mockTrackBeginCheckout: vi.fn(),
|
||||
mockUserId: { value: 'user-1' as string | null }
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
get isCloud() {
|
||||
@@ -16,6 +20,12 @@ vi.mock('@/config/comfyApi', () => ({
|
||||
vi.mock('@/platform/workspace/api/workspaceApi', () => ({
|
||||
workspaceApi: { subscribe: mockSubscribe }
|
||||
}))
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackBeginCheckout: mockTrackBeginCheckout })
|
||||
}))
|
||||
vi.mock('@/stores/authStore', () => ({
|
||||
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
|
||||
}))
|
||||
|
||||
import { performTeamSubscriptionCheckout } from './teamSubscriptionCheckoutUtil'
|
||||
|
||||
@@ -43,7 +53,9 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
billing_op_id: 'op_1'
|
||||
})
|
||||
|
||||
await performTeamSubscriptionCheckout('team_700', 'yearly')
|
||||
await performTeamSubscriptionCheckout('team_700', 'yearly', {
|
||||
paymentIntentSource: 'deep_link'
|
||||
})
|
||||
|
||||
expect(mockSubscribe).toHaveBeenCalledWith('team_per_credit_annual', {
|
||||
returnUrl: 'https://app.test/payment/success',
|
||||
@@ -51,6 +63,14 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
teamCreditStopId: 'team_700'
|
||||
})
|
||||
expect(assignedHref).toBe('https://stripe.test/pay')
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
|
||||
user_id: 'user-1',
|
||||
tier: 'team',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new',
|
||||
billing_op_id: 'op_1',
|
||||
payment_intent_source: 'deep_link'
|
||||
})
|
||||
})
|
||||
|
||||
it('uses the monthly slug and lands in the app when no Stripe step is needed', async () => {
|
||||
@@ -82,6 +102,16 @@ describe('performTeamSubscriptionCheckout', () => {
|
||||
expect(assignedHref).toBeUndefined()
|
||||
})
|
||||
|
||||
it('does not track begin_checkout when subscribe fails', async () => {
|
||||
mockSubscribe.mockRejectedValueOnce(new Error('subscribe failed'))
|
||||
|
||||
await expect(
|
||||
performTeamSubscriptionCheckout('team_700', 'yearly')
|
||||
).rejects.toThrow('subscribe failed')
|
||||
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does nothing off cloud', async () => {
|
||||
mockIsCloud.value = false
|
||||
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
import { getComfyPlatformBaseUrl } from '@/config/comfyApi'
|
||||
import { getTeamPlanSlug } from '@/platform/cloud/subscription/constants/teamPlanCreditStops'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { workspaceApi } from '@/platform/workspace/api/workspaceApi'
|
||||
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
|
||||
|
||||
import type { BillingCycle } from './subscriptionTierRank'
|
||||
|
||||
interface PerformTeamSubscriptionCheckoutOptions {
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
}
|
||||
|
||||
/**
|
||||
* Direct team-plan checkout for the marketing `/cloud/subscribe?tier=team` deep
|
||||
* link: subscribes to the per-credit Team plan at the chosen slider stop and
|
||||
@@ -22,7 +28,8 @@ import type { BillingCycle } from './subscriptionTierRank'
|
||||
*/
|
||||
export async function performTeamSubscriptionCheckout(
|
||||
teamCreditStopId: string,
|
||||
billingCycle: BillingCycle
|
||||
billingCycle: BillingCycle,
|
||||
options: PerformTeamSubscriptionCheckoutOptions = {}
|
||||
): Promise<void> {
|
||||
if (!isCloud) return
|
||||
|
||||
@@ -33,6 +40,14 @@ export async function performTeamSubscriptionCheckout(
|
||||
teamCreditStopId
|
||||
})
|
||||
|
||||
trackWorkspaceCheckoutStarted({
|
||||
tier: 'team',
|
||||
cycle: billingCycle,
|
||||
checkoutType: 'new',
|
||||
billingOpId: response.billing_op_id,
|
||||
paymentIntentSource: options.paymentIntentSource
|
||||
})
|
||||
|
||||
if (response.status === 'needs_payment_method') {
|
||||
// A needs_payment_method response without a URL is unusable: surface it to
|
||||
// the caller's error handling rather than silently dropping the user home
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { DownloadApiError } from '@/platform/modelManager/types'
|
||||
|
||||
import {
|
||||
downloadModel,
|
||||
fetchModelMetadata,
|
||||
@@ -9,23 +7,13 @@ import {
|
||||
toBrowsableUrl
|
||||
} from './missingModelDownload'
|
||||
|
||||
const {
|
||||
fetchMock,
|
||||
mockIsDesktop,
|
||||
mockSidebarTabStore,
|
||||
mockStartDownload,
|
||||
mockEnqueue,
|
||||
mockToastAdd,
|
||||
mockFlags
|
||||
} = vi.hoisted(() => ({
|
||||
fetchMock: vi.fn(),
|
||||
mockIsDesktop: { value: false },
|
||||
mockSidebarTabStore: { activeSidebarTabId: null as string | null },
|
||||
mockStartDownload: vi.fn(),
|
||||
mockEnqueue: vi.fn(),
|
||||
mockToastAdd: vi.fn(),
|
||||
mockFlags: { serverSideModelDownloads: false }
|
||||
}))
|
||||
const { fetchMock, mockIsDesktop, mockSidebarTabStore, mockStartDownload } =
|
||||
vi.hoisted(() => ({
|
||||
fetchMock: vi.fn(),
|
||||
mockIsDesktop: { value: false },
|
||||
mockSidebarTabStore: { activeSidebarTabId: null as string | null },
|
||||
mockStartDownload: vi.fn()
|
||||
}))
|
||||
|
||||
vi.stubGlobal('fetch', fetchMock)
|
||||
|
||||
@@ -45,38 +33,12 @@ vi.mock('@/stores/workspace/sidebarTabStore', () => ({
|
||||
useSidebarTabStore: () => mockSidebarTabStore
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useFeatureFlags', () => ({
|
||||
useFeatureFlags: () => ({ flags: mockFlags })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/modelManager/stores/modelDownloadStore', () => ({
|
||||
useModelDownloadStore: () => ({ enqueue: mockEnqueue })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: () => ({ add: mockToastAdd })
|
||||
}))
|
||||
|
||||
const mockRefreshModelFolder = vi.fn()
|
||||
const mockRefreshMissingModels = vi.fn()
|
||||
|
||||
vi.mock('@/stores/modelStore', () => ({
|
||||
useModelStore: () => ({ refreshModelFolder: mockRefreshModelFolder })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/missingModel/missingModelStore', () => ({
|
||||
useMissingModelStore: () => ({
|
||||
refreshMissingModels: mockRefreshMissingModels
|
||||
})
|
||||
}))
|
||||
|
||||
let testId = 0
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.resetAllMocks()
|
||||
delete window.__comfyDesktop2
|
||||
mockFlags.serverSideModelDownloads = false
|
||||
})
|
||||
|
||||
describe('fetchModelMetadata', () => {
|
||||
@@ -155,26 +117,6 @@ describe('fetchModelMetadata', () => {
|
||||
expect(fetchMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('returns null metadata when the Civitai request throws', async () => {
|
||||
fetchMock.mockRejectedValueOnce(new Error('network down'))
|
||||
|
||||
const metadata = await fetchModelMetadata(
|
||||
`https://civitai.com/api/download/models/${testId}`
|
||||
)
|
||||
|
||||
expect(metadata).toEqual({ fileSize: null, gatedRepoUrl: null })
|
||||
})
|
||||
|
||||
it('returns null metadata when the HEAD request throws', async () => {
|
||||
fetchMock.mockRejectedValueOnce(new Error('network down'))
|
||||
|
||||
const metadata = await fetchModelMetadata(
|
||||
`https://huggingface.co/org/model/resolve/main/throws-${testId}.safetensors`
|
||||
)
|
||||
|
||||
expect(metadata).toEqual({ fileSize: null, gatedRepoUrl: null })
|
||||
})
|
||||
|
||||
it('returns cached metadata on second call', async () => {
|
||||
const url = `https://huggingface.co/org/model/resolve/main/cached-${testId}.safetensors`
|
||||
|
||||
@@ -298,16 +240,6 @@ describe('isModelDownloadable', () => {
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('allows explicitly whitelisted URLs from an otherwise disallowed host', () => {
|
||||
expect(
|
||||
isModelDownloadable({
|
||||
name: 'RealESRGAN_x4plus.pth',
|
||||
url: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
|
||||
directory: 'upscale_models'
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('downloadModel', () => {
|
||||
@@ -447,152 +379,6 @@ describe('downloadModel', () => {
|
||||
expect(anchorClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('uses the browser fallback on web when server-side downloads are disabled', () => {
|
||||
const anchorClick = vi
|
||||
.spyOn(HTMLAnchorElement.prototype, 'click')
|
||||
.mockImplementation(() => {})
|
||||
|
||||
downloadModel(
|
||||
{
|
||||
name: 'model.safetensors',
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
directory: 'checkpoints'
|
||||
},
|
||||
{ checkpoints: ['/models/checkpoints'] }
|
||||
)
|
||||
|
||||
expect(anchorClick).toHaveBeenCalledTimes(1)
|
||||
expect(mockEnqueue).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('enqueues a server-side download and reveals the manager when enabled', async () => {
|
||||
mockFlags.serverSideModelDownloads = true
|
||||
mockEnqueue.mockResolvedValue({ download_id: 'd1', accepted: true })
|
||||
const anchorClick = vi
|
||||
.spyOn(HTMLAnchorElement.prototype, 'click')
|
||||
.mockImplementation(() => {})
|
||||
|
||||
downloadModel(
|
||||
{
|
||||
name: 'model.safetensors',
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
directory: 'checkpoints'
|
||||
},
|
||||
{ checkpoints: ['/models/checkpoints'] }
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockSidebarTabStore.activeSidebarTabId).toBe('model-manager')
|
||||
})
|
||||
expect(mockEnqueue).toHaveBeenCalledWith({
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
model_id: 'checkpoints/model.safetensors'
|
||||
})
|
||||
expect(anchorClick).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows a toast when a server-side enqueue fails', async () => {
|
||||
mockFlags.serverSideModelDownloads = true
|
||||
mockEnqueue.mockRejectedValue(new Error('boom'))
|
||||
|
||||
downloadModel(
|
||||
{
|
||||
name: 'model.safetensors',
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
directory: 'checkpoints'
|
||||
},
|
||||
{ checkpoints: ['/models/checkpoints'] }
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error', detail: 'boom' })
|
||||
)
|
||||
})
|
||||
expect(mockSidebarTabStore.activeSidebarTabId).toBeNull()
|
||||
})
|
||||
|
||||
it('reveals the download manager and shows an info toast for an in-progress download', async () => {
|
||||
mockFlags.serverSideModelDownloads = true
|
||||
mockEnqueue.mockRejectedValue(
|
||||
new DownloadApiError('exists', 'ALREADY_DOWNLOADING', 409)
|
||||
)
|
||||
|
||||
downloadModel(
|
||||
{
|
||||
name: 'model.safetensors',
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
directory: 'checkpoints'
|
||||
},
|
||||
{ checkpoints: ['/models/checkpoints'] }
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockSidebarTabStore.activeSidebarTabId).toBe('model-manager')
|
||||
})
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
severity: 'info',
|
||||
detail: 'model.safetensors'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('refreshes the model folder and re-scans missing models when already available', async () => {
|
||||
mockFlags.serverSideModelDownloads = true
|
||||
mockEnqueue.mockRejectedValue(
|
||||
new DownloadApiError('already there', 'ALREADY_AVAILABLE', 409)
|
||||
)
|
||||
mockRefreshModelFolder.mockResolvedValue(undefined)
|
||||
|
||||
downloadModel(
|
||||
{
|
||||
name: 'model.safetensors',
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
directory: 'checkpoints'
|
||||
},
|
||||
{ checkpoints: ['/models/checkpoints'] }
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
severity: 'info',
|
||||
detail: 'model.safetensors'
|
||||
})
|
||||
)
|
||||
})
|
||||
await vi.waitFor(() => {
|
||||
expect(mockRefreshModelFolder).toHaveBeenCalledWith('checkpoints')
|
||||
})
|
||||
expect(mockRefreshMissingModels).toHaveBeenCalled()
|
||||
expect(mockSidebarTabStore.activeSidebarTabId).toBeNull()
|
||||
})
|
||||
|
||||
it('still re-scans missing models when the post-available folder refresh fails', async () => {
|
||||
mockFlags.serverSideModelDownloads = true
|
||||
mockEnqueue.mockRejectedValue(
|
||||
new DownloadApiError('already there', 'ALREADY_AVAILABLE', 409)
|
||||
)
|
||||
const consoleWarn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
mockRefreshModelFolder.mockRejectedValue(new Error('boom'))
|
||||
|
||||
downloadModel(
|
||||
{
|
||||
name: 'model.safetensors',
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
directory: 'checkpoints'
|
||||
},
|
||||
{ checkpoints: ['/models/checkpoints'] }
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockRefreshMissingModels).toHaveBeenCalled()
|
||||
})
|
||||
expect(consoleWarn).toHaveBeenCalled()
|
||||
consoleWarn.mockRestore()
|
||||
})
|
||||
|
||||
it('opens the model library sidebar before starting a desktop download', () => {
|
||||
mockIsDesktop.value = true
|
||||
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
import { downloadUrlToHfRepoUrl, isCivitaiModelUrl } from '@/utils/formatUtil'
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { t } from '@/i18n'
|
||||
import { isDesktop } from '@/platform/distribution/types'
|
||||
import { useModelDownloadStore } from '@/platform/modelManager/stores/modelDownloadStore'
|
||||
import { DownloadApiError } from '@/platform/modelManager/types'
|
||||
import { buildModelId } from '@/platform/modelManager/utils/modelId'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
import { useElectronDownloadStore } from '@/stores/electronDownloadStore'
|
||||
import { useSidebarTabStore } from '@/stores/workspace/sidebarTabStore'
|
||||
import type { ComfyDesktop2Bridge } from '@/types'
|
||||
@@ -35,7 +29,6 @@ const WHITE_LISTED_URLS: ReadonlySet<string> = new Set([
|
||||
])
|
||||
|
||||
const MODEL_LIBRARY_TAB_ID = 'model-library'
|
||||
const MODEL_MANAGER_TAB_ID = 'model-manager'
|
||||
|
||||
export interface ModelWithUrl {
|
||||
name: string
|
||||
@@ -54,96 +47,6 @@ async function startDesktop2ModelDownload(
|
||||
}
|
||||
}
|
||||
|
||||
function revealDownloadManager(): void {
|
||||
useSidebarTabStore().activeSidebarTabId = MODEL_MANAGER_TAB_ID
|
||||
}
|
||||
|
||||
/**
|
||||
* Already on disk: surface a confirmation, refresh the model folder, and
|
||||
* re-scan missing models so any node error for this model clears. Loaded
|
||||
* lazily to keep this module's import graph (and its unit tests) light.
|
||||
*/
|
||||
async function refreshAfterModelAvailable(model: ModelWithUrl): Promise<void> {
|
||||
try {
|
||||
const [{ useModelStore }, { useMissingModelStore }] = await Promise.all([
|
||||
import('@/stores/modelStore'),
|
||||
import('@/platform/missingModel/missingModelStore')
|
||||
])
|
||||
if (model.directory) {
|
||||
try {
|
||||
await useModelStore().refreshModelFolder(model.directory)
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
'[MissingModel] Failed to refresh model folder after model available',
|
||||
error
|
||||
)
|
||||
}
|
||||
}
|
||||
void useMissingModelStore().refreshMissingModels()
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
'[MissingModel] Failed to refresh after model available',
|
||||
error
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Enqueues a server-side download and reveals the Model Manager panel so the
|
||||
* user can watch live progress, status, and completion. The two benign `409`
|
||||
* cases get an info toast: `ALREADY_DOWNLOADING` links to the existing job,
|
||||
* `ALREADY_AVAILABLE` confirms it's installed and clears the node error. Any
|
||||
* other failure is reported via an error toast.
|
||||
*/
|
||||
async function startServerSideModelDownload(
|
||||
model: ModelWithUrl
|
||||
): Promise<void> {
|
||||
const toast = useToastStore()
|
||||
try {
|
||||
await useModelDownloadStore().enqueue({
|
||||
url: model.url,
|
||||
model_id: buildModelId(model.directory, model.name)
|
||||
})
|
||||
revealDownloadManager()
|
||||
} catch (error: unknown) {
|
||||
if (error instanceof DownloadApiError && error.is('ALREADY_DOWNLOADING')) {
|
||||
revealDownloadManager()
|
||||
toast.add({
|
||||
severity: 'info',
|
||||
summary: t('modelManager.alreadyDownloading'),
|
||||
detail: model.name,
|
||||
life: 4000
|
||||
})
|
||||
return
|
||||
}
|
||||
if (error instanceof DownloadApiError && error.is('ALREADY_AVAILABLE')) {
|
||||
toast.add({
|
||||
severity: 'info',
|
||||
summary: t('modelManager.alreadyInstalled'),
|
||||
detail: model.name,
|
||||
life: 4000
|
||||
})
|
||||
void refreshAfterModelAvailable(model)
|
||||
return
|
||||
}
|
||||
toast.add({
|
||||
severity: 'error',
|
||||
summary: t('modelManager.actionFailed'),
|
||||
detail: error instanceof Error ? error.message : String(error),
|
||||
life: 5000
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
function startBrowserModelDownload(model: ModelWithUrl): void {
|
||||
const link = document.createElement('a')
|
||||
link.href = model.url
|
||||
link.download = model.name
|
||||
link.target = '_blank'
|
||||
link.rel = 'noopener noreferrer'
|
||||
link.click()
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a model download URL to a browsable page URL.
|
||||
* - HuggingFace: `/resolve/` → `/blob/` (file page with model info)
|
||||
@@ -179,11 +82,12 @@ export function downloadModel(
|
||||
}
|
||||
|
||||
if (!isDesktop) {
|
||||
if (useFeatureFlags().flags.serverSideModelDownloads) {
|
||||
void startServerSideModelDownload(model)
|
||||
} else {
|
||||
startBrowserModelDownload(model)
|
||||
}
|
||||
const link = document.createElement('a')
|
||||
link.href = model.url
|
||||
link.download = model.name
|
||||
link.target = '_blank'
|
||||
link.rel = 'noopener noreferrer'
|
||||
link.click()
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { api } from '@/scripts/api'
|
||||
|
||||
import { DownloadApiError } from '../types'
|
||||
import {
|
||||
cancelDownload,
|
||||
checkAvailability,
|
||||
clearDownloads,
|
||||
deleteCredential,
|
||||
deleteDownload,
|
||||
enqueueDownload,
|
||||
listCredentials,
|
||||
listDownloads,
|
||||
pauseDownload,
|
||||
resumeDownload,
|
||||
setDownloadPriority,
|
||||
upsertCredential
|
||||
} from './modelDownloadApi'
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
fetchApi: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
const fetchApi = vi.mocked(api.fetchApi)
|
||||
|
||||
function jsonResponse(status: number, body: unknown): Response {
|
||||
return {
|
||||
ok: status >= 200 && status < 300,
|
||||
status,
|
||||
statusText: `status ${status}`,
|
||||
json: () => Promise.resolve(body)
|
||||
} as unknown as Response
|
||||
}
|
||||
|
||||
function nonJsonErrorResponse(status: number, statusText: string): Response {
|
||||
return {
|
||||
ok: false,
|
||||
status,
|
||||
statusText,
|
||||
json: () => Promise.reject(new Error('not json'))
|
||||
} as unknown as Response
|
||||
}
|
||||
|
||||
describe('modelDownloadApi', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks()
|
||||
})
|
||||
|
||||
describe('enqueueDownload', () => {
|
||||
it('returns the enqueue response on 202', async () => {
|
||||
fetchApi.mockResolvedValue(
|
||||
jsonResponse(202, { download_id: 'd1', accepted: true })
|
||||
)
|
||||
|
||||
const result = await enqueueDownload({
|
||||
url: 'https://huggingface.co/x.safetensors',
|
||||
model_id: 'loras/x.safetensors'
|
||||
})
|
||||
|
||||
expect(result).toEqual({ download_id: 'd1', accepted: true })
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/enqueue',
|
||||
expect.objectContaining({ method: 'POST' })
|
||||
)
|
||||
})
|
||||
|
||||
it('throws a DownloadApiError carrying the error code on 409', async () => {
|
||||
fetchApi.mockResolvedValue(
|
||||
jsonResponse(409, {
|
||||
error: { code: 'ALREADY_DOWNLOADING', message: 'exists' }
|
||||
})
|
||||
)
|
||||
|
||||
await expect(
|
||||
enqueueDownload({ url: 'u', model_id: 'loras/x.safetensors' })
|
||||
).rejects.toMatchObject({
|
||||
code: 'ALREADY_DOWNLOADING',
|
||||
status: 409,
|
||||
message: 'exists'
|
||||
})
|
||||
})
|
||||
|
||||
it('falls back to statusText and an UNKNOWN code for a non-JSON error body', async () => {
|
||||
fetchApi.mockResolvedValue(nonJsonErrorResponse(502, 'Bad Gateway'))
|
||||
|
||||
await expect(
|
||||
enqueueDownload({ url: 'u', model_id: 'loras/x.safetensors' })
|
||||
).rejects.toMatchObject({
|
||||
message: 'Bad Gateway',
|
||||
code: 'UNKNOWN',
|
||||
status: 502
|
||||
})
|
||||
})
|
||||
|
||||
it('exposes the code through DownloadApiError.is()', async () => {
|
||||
fetchApi.mockResolvedValue(
|
||||
jsonResponse(400, {
|
||||
error: { code: 'URL_NOT_ALLOWED', message: 'nope' }
|
||||
})
|
||||
)
|
||||
|
||||
const error = await enqueueDownload({
|
||||
url: 'u',
|
||||
model_id: 'loras/x.safetensors'
|
||||
}).catch((e) => e)
|
||||
|
||||
expect(error).toBeInstanceOf(DownloadApiError)
|
||||
expect(error.is('URL_NOT_ALLOWED')).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('listDownloads', () => {
|
||||
it('unwraps the downloads array', async () => {
|
||||
fetchApi.mockResolvedValue(
|
||||
jsonResponse(200, { downloads: [{ download_id: 'd1' }] })
|
||||
)
|
||||
|
||||
const result = await listDownloads()
|
||||
|
||||
expect(result).toEqual([{ download_id: 'd1' }])
|
||||
})
|
||||
})
|
||||
|
||||
describe('actions', () => {
|
||||
it('posts to the pause route', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
|
||||
|
||||
await pauseDownload('d1')
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/d1/pause',
|
||||
expect.objectContaining({ method: 'POST' })
|
||||
)
|
||||
})
|
||||
|
||||
it('posts to the resume route', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
|
||||
|
||||
await resumeDownload('d1')
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/d1/resume',
|
||||
expect.objectContaining({ method: 'POST' })
|
||||
)
|
||||
})
|
||||
|
||||
it('posts to the cancel route', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
|
||||
|
||||
await cancelDownload('d1')
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/d1/cancel',
|
||||
expect.objectContaining({ method: 'POST' })
|
||||
)
|
||||
})
|
||||
|
||||
it('sends the priority in the body', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
|
||||
|
||||
await setDownloadPriority('d1', 5)
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/d1/priority',
|
||||
expect.objectContaining({ body: JSON.stringify({ priority: 5 }) })
|
||||
)
|
||||
})
|
||||
|
||||
it('sends a DELETE to the download route', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(200, { deleted: true }))
|
||||
|
||||
await deleteDownload('d1')
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/d1',
|
||||
expect.objectContaining({ method: 'DELETE' })
|
||||
)
|
||||
})
|
||||
|
||||
it('posts to the clear route and returns the deleted count', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(200, { deleted: 3 }))
|
||||
|
||||
const result = await clearDownloads()
|
||||
|
||||
expect(result).toBe(3)
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/clear',
|
||||
expect.objectContaining({ method: 'POST' })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('checkAvailability', () => {
|
||||
it('wraps the models map in the request body', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(200, { models: {} }))
|
||||
|
||||
await checkAvailability({ 'loras/x.safetensors': 'https://h.co/x' })
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/availability',
|
||||
expect.objectContaining({
|
||||
body: JSON.stringify({
|
||||
models: { 'loras/x.safetensors': 'https://h.co/x' }
|
||||
})
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('credentials', () => {
|
||||
it('unwraps the credentials list', async () => {
|
||||
fetchApi.mockResolvedValue(
|
||||
jsonResponse(200, { credentials: [{ id: 'c1' }] })
|
||||
)
|
||||
|
||||
expect(await listCredentials()).toEqual([{ id: 'c1' }])
|
||||
})
|
||||
|
||||
it('returns the created credential view', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(201, { id: 'c1', host: 'h' }))
|
||||
|
||||
const result = await upsertCredential({ host: 'h', secret: 's' })
|
||||
|
||||
expect(result).toEqual({ id: 'c1', host: 'h' })
|
||||
})
|
||||
|
||||
it('throws on a failed delete', async () => {
|
||||
fetchApi.mockResolvedValue(
|
||||
jsonResponse(404, { error: { code: 'NOT_FOUND', message: 'gone' } })
|
||||
)
|
||||
|
||||
await expect(deleteCredential('c1')).rejects.toMatchObject({
|
||||
code: 'NOT_FOUND'
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,121 +0,0 @@
|
||||
import { api } from '@/scripts/api'
|
||||
|
||||
import type {
|
||||
AvailabilityResponse,
|
||||
DownloadStatus,
|
||||
EnqueueRequest,
|
||||
EnqueueResponse,
|
||||
HostCredentialUpsert,
|
||||
HostCredentialView
|
||||
} from '../types'
|
||||
import { DownloadApiError } from '../types'
|
||||
|
||||
const BASE = '/download'
|
||||
|
||||
interface ErrorEnvelope {
|
||||
error?: {
|
||||
code?: string
|
||||
message?: string
|
||||
details?: Record<string, unknown>
|
||||
}
|
||||
}
|
||||
|
||||
async function throwFromResponse(response: Response): Promise<never> {
|
||||
let body: ErrorEnvelope = {}
|
||||
try {
|
||||
body = await response.json()
|
||||
} catch {
|
||||
// Non-JSON error body
|
||||
}
|
||||
const error = body.error
|
||||
throw new DownloadApiError(
|
||||
error?.message ?? response.statusText,
|
||||
error?.code ?? 'UNKNOWN',
|
||||
response.status,
|
||||
error?.details
|
||||
)
|
||||
}
|
||||
|
||||
async function parseJson<T>(response: Response): Promise<T> {
|
||||
if (!response.ok) return throwFromResponse(response)
|
||||
return response.json() as Promise<T>
|
||||
}
|
||||
|
||||
function postJson(route: string, body: unknown): Promise<Response> {
|
||||
return api.fetchApi(route, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(body)
|
||||
})
|
||||
}
|
||||
|
||||
export async function enqueueDownload(
|
||||
body: EnqueueRequest
|
||||
): Promise<EnqueueResponse> {
|
||||
const response = await postJson(`${BASE}/enqueue`, body)
|
||||
if (response.status === 202) {
|
||||
return response.json() as Promise<EnqueueResponse>
|
||||
}
|
||||
return throwFromResponse(response)
|
||||
}
|
||||
|
||||
export async function listDownloads(): Promise<DownloadStatus[]> {
|
||||
const response = await api.fetchApi(BASE)
|
||||
const data = await parseJson<{ downloads: DownloadStatus[] }>(response)
|
||||
return data.downloads
|
||||
}
|
||||
|
||||
async function postAction(id: string, action: string): Promise<void> {
|
||||
const response = await postJson(`${BASE}/${id}/${action}`, undefined)
|
||||
await parseJson<{ ok: boolean }>(response)
|
||||
}
|
||||
|
||||
export const pauseDownload = (id: string) => postAction(id, 'pause')
|
||||
export const resumeDownload = (id: string) => postAction(id, 'resume')
|
||||
export const cancelDownload = (id: string) => postAction(id, 'cancel')
|
||||
|
||||
export async function deleteDownload(id: string): Promise<void> {
|
||||
const response = await api.fetchApi(`${BASE}/${id}`, { method: 'DELETE' })
|
||||
await parseJson<{ deleted: boolean }>(response)
|
||||
}
|
||||
|
||||
export async function clearDownloads(): Promise<number> {
|
||||
const response = await postJson(`${BASE}/clear`, undefined)
|
||||
const data = await parseJson<{ deleted: number }>(response)
|
||||
return data.deleted
|
||||
}
|
||||
|
||||
export async function setDownloadPriority(
|
||||
id: string,
|
||||
priority: number
|
||||
): Promise<void> {
|
||||
const response = await postJson(`${BASE}/${id}/priority`, { priority })
|
||||
await parseJson<{ ok: boolean }>(response)
|
||||
}
|
||||
|
||||
export async function checkAvailability(
|
||||
models: Record<string, string>
|
||||
): Promise<AvailabilityResponse> {
|
||||
const response = await postJson(`${BASE}/availability`, { models })
|
||||
return parseJson<AvailabilityResponse>(response)
|
||||
}
|
||||
|
||||
export async function listCredentials(): Promise<HostCredentialView[]> {
|
||||
const response = await api.fetchApi(`${BASE}/credentials`)
|
||||
const data = await parseJson<{ credentials: HostCredentialView[] }>(response)
|
||||
return data.credentials
|
||||
}
|
||||
|
||||
export async function upsertCredential(
|
||||
body: HostCredentialUpsert
|
||||
): Promise<HostCredentialView> {
|
||||
const response = await postJson(`${BASE}/credentials`, body)
|
||||
return parseJson<HostCredentialView>(response)
|
||||
}
|
||||
|
||||
export async function deleteCredential(id: string): Promise<void> {
|
||||
const response = await api.fetchApi(`${BASE}/credentials/${id}`, {
|
||||
method: 'DELETE'
|
||||
})
|
||||
await parseJson<{ deleted: boolean }>(response)
|
||||
}
|
||||
@@ -1,253 +0,0 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { defineComponent, ref } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import { DownloadApiError } from '../types'
|
||||
import AddModelByUrlDialog from './AddModelByUrlDialog.vue'
|
||||
|
||||
const mockEnqueue = vi.fn()
|
||||
const mockToastAdd = vi.fn()
|
||||
const mockFetchModelTypes = vi.fn()
|
||||
const mockModelTypes = ref([
|
||||
{ name: 'Checkpoints', value: 'checkpoints' },
|
||||
{ name: 'LoRA', value: 'loras' }
|
||||
])
|
||||
|
||||
vi.mock('../stores/modelDownloadStore', () => ({
|
||||
useModelDownloadStore: () => ({ enqueue: mockEnqueue })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/assets/composables/useModelTypes', () => ({
|
||||
useModelTypes: () => ({
|
||||
modelTypes: mockModelTypes,
|
||||
isLoading: ref(false),
|
||||
fetchModelTypes: mockFetchModelTypes
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: () => ({ add: mockToastAdd })
|
||||
}))
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages },
|
||||
missingWarn: false,
|
||||
fallbackWarn: false
|
||||
})
|
||||
|
||||
const stubs = {
|
||||
Dialog: { template: '<div><slot /></div>' },
|
||||
DialogPortal: { template: '<div><slot /></div>' },
|
||||
DialogOverlay: { template: '<div />' },
|
||||
DialogContent: defineComponent({
|
||||
emits: ['open-auto-focus'],
|
||||
mounted() {
|
||||
this.$emit('open-auto-focus')
|
||||
},
|
||||
template: '<div><slot /></div>'
|
||||
}),
|
||||
DialogHeader: { template: '<div><slot /></div>' },
|
||||
DialogTitle: { template: '<div><slot /></div>' },
|
||||
DialogDescription: { template: '<div><slot /></div>' },
|
||||
DialogFooter: { template: '<div><slot /></div>' },
|
||||
SingleSelect: {
|
||||
props: ['modelValue', 'options', 'label', 'loading'],
|
||||
emits: ['update:modelValue'],
|
||||
template: `
|
||||
<select
|
||||
data-testid="folder-select"
|
||||
:value="modelValue ?? ''"
|
||||
@change="$emit('update:modelValue', $event.target.value)"
|
||||
>
|
||||
<option value="" disabled>{{ label }}</option>
|
||||
<option v-for="opt in options" :key="opt.value" :value="opt.value">
|
||||
{{ opt.name }}
|
||||
</option>
|
||||
</select>
|
||||
`
|
||||
}
|
||||
}
|
||||
|
||||
function mountDialog(open = true) {
|
||||
return render(AddModelByUrlDialog, {
|
||||
props: { open },
|
||||
global: { plugins: [i18n], stubs }
|
||||
})
|
||||
}
|
||||
|
||||
async function fillValidForm() {
|
||||
await userEvent.type(
|
||||
screen.getByLabelText('URL'),
|
||||
'https://huggingface.co/org/model/resolve/main/model.safetensors'
|
||||
)
|
||||
await userEvent.selectOptions(
|
||||
screen.getByTestId('folder-select'),
|
||||
'checkpoints'
|
||||
)
|
||||
}
|
||||
|
||||
describe('AddModelByUrlDialog', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('auto-fills the filename from the url until edited manually', async () => {
|
||||
mountDialog()
|
||||
const urlInput = screen.getByLabelText('URL')
|
||||
const filenameInput = screen.getByLabelText('Filename')
|
||||
|
||||
await userEvent.type(
|
||||
urlInput,
|
||||
'https://huggingface.co/org/model/resolve/main/model.safetensors'
|
||||
)
|
||||
expect(filenameInput).toHaveValue('model.safetensors')
|
||||
|
||||
await userEvent.clear(filenameInput)
|
||||
await userEvent.type(filenameInput, 'custom.safetensors')
|
||||
await userEvent.type(urlInput, '?download=true')
|
||||
|
||||
expect(filenameInput).toHaveValue('custom.safetensors')
|
||||
})
|
||||
|
||||
it('shows a hint for a url on a non-allowlisted host', async () => {
|
||||
mountDialog()
|
||||
|
||||
await userEvent.type(
|
||||
screen.getByLabelText('URL'),
|
||||
'https://example.com/model.safetensors'
|
||||
)
|
||||
|
||||
expect(
|
||||
screen.getByText('This host may not be on the download allowlist.')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('disables submit until url, folder, and a valid filename are set', async () => {
|
||||
mountDialog()
|
||||
const submit = screen.getByRole('button', { name: 'Download' })
|
||||
expect(submit).toBeDisabled()
|
||||
|
||||
await userEvent.type(
|
||||
screen.getByLabelText('URL'),
|
||||
'https://huggingface.co/org/model/resolve/main/model.safetensors'
|
||||
)
|
||||
expect(submit).toBeDisabled()
|
||||
|
||||
await userEvent.selectOptions(
|
||||
screen.getByTestId('folder-select'),
|
||||
'checkpoints'
|
||||
)
|
||||
expect(submit).toBeEnabled()
|
||||
})
|
||||
|
||||
it('requires a known model extension', async () => {
|
||||
mountDialog()
|
||||
await userEvent.type(
|
||||
screen.getByLabelText('URL'),
|
||||
'https://huggingface.co/org/model/resolve/main/readme.txt'
|
||||
)
|
||||
await userEvent.selectOptions(
|
||||
screen.getByTestId('folder-select'),
|
||||
'checkpoints'
|
||||
)
|
||||
const submit = screen.getByRole('button', { name: 'Download' })
|
||||
expect(submit).toBeDisabled()
|
||||
})
|
||||
|
||||
it('closes without submitting when cancel is clicked', async () => {
|
||||
const { emitted } = mountDialog()
|
||||
await fillValidForm()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Cancel' }))
|
||||
|
||||
expect(mockEnqueue).not.toHaveBeenCalled()
|
||||
expect(emitted('update:open')?.at(-1)).toEqual([false])
|
||||
})
|
||||
|
||||
it('enqueues the download and closes on success', async () => {
|
||||
mockEnqueue.mockResolvedValue({ download_id: 'd1', accepted: true })
|
||||
const { emitted } = mountDialog()
|
||||
await fillValidForm()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
|
||||
|
||||
expect(mockEnqueue).toHaveBeenCalledWith({
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
model_id: 'checkpoints/model.safetensors',
|
||||
allow_any_extension: false
|
||||
})
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'success' })
|
||||
)
|
||||
})
|
||||
expect(emitted('update:open')?.at(-1)).toEqual([false])
|
||||
})
|
||||
|
||||
it('shows an info toast and closes when the model is already downloading', async () => {
|
||||
mockEnqueue.mockRejectedValue(
|
||||
new DownloadApiError('exists', 'ALREADY_DOWNLOADING', 409)
|
||||
)
|
||||
const { emitted } = mountDialog()
|
||||
await fillValidForm()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'info' })
|
||||
)
|
||||
})
|
||||
expect(emitted('update:open')?.at(-1)).toEqual([false])
|
||||
})
|
||||
|
||||
it('shows an info toast and closes when the model is already installed', async () => {
|
||||
mockEnqueue.mockRejectedValue(
|
||||
new DownloadApiError('already there', 'ALREADY_AVAILABLE', 409)
|
||||
)
|
||||
const { emitted } = mountDialog()
|
||||
await fillValidForm()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'info' })
|
||||
)
|
||||
})
|
||||
expect(emitted('update:open')?.at(-1)).toEqual([false])
|
||||
})
|
||||
|
||||
it('shows an inline error for other API failures and stays open', async () => {
|
||||
mockEnqueue.mockRejectedValue(
|
||||
new DownloadApiError('url not allowed', 'URL_NOT_ALLOWED', 400)
|
||||
)
|
||||
const { emitted } = mountDialog()
|
||||
await fillValidForm()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(screen.getByText('url not allowed')).toBeInTheDocument()
|
||||
})
|
||||
expect(emitted('update:open')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('shows a generic error message for a non-api error', async () => {
|
||||
mockEnqueue.mockRejectedValue(new Error('network down'))
|
||||
mountDialog()
|
||||
await fillValidForm()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(screen.getByText('network down')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,237 +0,0 @@
|
||||
<template>
|
||||
<Dialog v-model:open="isOpen">
|
||||
<DialogPortal>
|
||||
<DialogOverlay class="bg-black/70" />
|
||||
<DialogContent
|
||||
size="md"
|
||||
class="flex flex-col gap-4 p-6"
|
||||
@open-auto-focus="onOpen"
|
||||
>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{{ $t('modelManager.addModel') }}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{{ $t('modelManager.addModelDescription') }}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground" for="download-url">
|
||||
{{ $t('modelManager.url') }}
|
||||
</label>
|
||||
<Input
|
||||
id="download-url"
|
||||
v-model="url"
|
||||
:placeholder="$t('modelManager.urlPlaceholder')"
|
||||
@update:model-value="onUrlChanged"
|
||||
/>
|
||||
<p v-if="hostHint" class="text-xs text-amber-400">{{ hostHint }}</p>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground">
|
||||
{{ $t('modelManager.folder') }}
|
||||
</label>
|
||||
<SingleSelect
|
||||
v-model="directory"
|
||||
:options="folderOptions"
|
||||
:label="$t('modelManager.selectFolder')"
|
||||
:loading="isLoadingFolders"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground" for="download-filename">
|
||||
{{ $t('modelManager.filename') }}
|
||||
</label>
|
||||
<Input
|
||||
id="download-filename"
|
||||
v-model="filename"
|
||||
:placeholder="$t('modelManager.filenamePlaceholder')"
|
||||
@update:model-value="onFilenameEdited"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- TODO: re-enable once we think we'd want to allow any extension
|
||||
<label class="flex w-fit items-center gap-2 text-xs text-muted-foreground">
|
||||
<input v-model="allowAnyExtension" type="checkbox" class="size-4" />
|
||||
{{ $t('modelManager.allowAnyExtension') }}
|
||||
</label>
|
||||
-->
|
||||
|
||||
<p
|
||||
v-if="modelId"
|
||||
class="truncate rounded-md bg-secondary-background px-2 py-1 text-xs text-muted-foreground"
|
||||
>
|
||||
{{ modelId }}
|
||||
</p>
|
||||
|
||||
<p v-if="errorMessage" class="text-xs text-red-400">
|
||||
{{ errorMessage }}
|
||||
</p>
|
||||
|
||||
<DialogFooter>
|
||||
<Button variant="secondary" @click="isOpen = false">
|
||||
{{ $t('g.cancel') }}
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
:disabled="!canSubmit"
|
||||
:loading="isSubmitting"
|
||||
@click="submit"
|
||||
>
|
||||
{{ $t('modelManager.download') }}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</DialogPortal>
|
||||
</Dialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import Dialog from '@/components/ui/dialog/Dialog.vue'
|
||||
import DialogContent from '@/components/ui/dialog/DialogContent.vue'
|
||||
import DialogDescription from '@/components/ui/dialog/DialogDescription.vue'
|
||||
import DialogFooter from '@/components/ui/dialog/DialogFooter.vue'
|
||||
import DialogHeader from '@/components/ui/dialog/DialogHeader.vue'
|
||||
import DialogOverlay from '@/components/ui/dialog/DialogOverlay.vue'
|
||||
import DialogPortal from '@/components/ui/dialog/DialogPortal.vue'
|
||||
import DialogTitle from '@/components/ui/dialog/DialogTitle.vue'
|
||||
import Input from '@/components/ui/input/Input.vue'
|
||||
import SingleSelect from '@/components/ui/single-select/SingleSelect.vue'
|
||||
import { useModelTypes } from '@/platform/assets/composables/useModelTypes'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
|
||||
import { useModelDownloadStore } from '../stores/modelDownloadStore'
|
||||
import { DownloadApiError } from '../types'
|
||||
import {
|
||||
buildModelId,
|
||||
filenameFromUrl,
|
||||
hasModelExtension,
|
||||
isLikelyAllowedHost,
|
||||
isValidPathSegment
|
||||
} from '../utils/modelId'
|
||||
|
||||
const isOpen = defineModel<boolean>('open', { required: true })
|
||||
|
||||
const { t } = useI18n()
|
||||
const store = useModelDownloadStore()
|
||||
const {
|
||||
modelTypes,
|
||||
isLoading: isLoadingFolders,
|
||||
fetchModelTypes
|
||||
} = useModelTypes()
|
||||
|
||||
const url = ref('')
|
||||
const directory = ref<string | undefined>(undefined)
|
||||
const filename = ref('')
|
||||
const isFilenameUserEdited = ref(false)
|
||||
const allowAnyExtension = ref(false)
|
||||
const isSubmitting = ref(false)
|
||||
const errorMessage = ref('')
|
||||
|
||||
const folderOptions = computed(() => modelTypes.value)
|
||||
|
||||
const modelId = computed(() =>
|
||||
directory.value && filename.value
|
||||
? buildModelId(directory.value, filename.value)
|
||||
: ''
|
||||
)
|
||||
|
||||
const hostHint = computed(() =>
|
||||
url.value && !isLikelyAllowedHost(url.value)
|
||||
? t('modelManager.hostNotAllowedHint')
|
||||
: ''
|
||||
)
|
||||
|
||||
const canSubmit = computed(
|
||||
() =>
|
||||
!!url.value &&
|
||||
!!directory.value &&
|
||||
isValidPathSegment(filename.value) &&
|
||||
(allowAnyExtension.value || hasModelExtension(filename.value))
|
||||
)
|
||||
|
||||
function onOpen() {
|
||||
void fetchModelTypes()
|
||||
errorMessage.value = ''
|
||||
}
|
||||
|
||||
function onUrlChanged() {
|
||||
if (!isFilenameUserEdited.value) {
|
||||
filename.value = filenameFromUrl(url.value)
|
||||
}
|
||||
}
|
||||
|
||||
function onFilenameEdited() {
|
||||
isFilenameUserEdited.value = true
|
||||
}
|
||||
|
||||
function reset() {
|
||||
url.value = ''
|
||||
directory.value = undefined
|
||||
filename.value = ''
|
||||
isFilenameUserEdited.value = false
|
||||
allowAnyExtension.value = false
|
||||
errorMessage.value = ''
|
||||
}
|
||||
|
||||
async function submit() {
|
||||
if (!canSubmit.value) return
|
||||
isSubmitting.value = true
|
||||
errorMessage.value = ''
|
||||
try {
|
||||
await store.enqueue({
|
||||
url: url.value,
|
||||
model_id: modelId.value,
|
||||
allow_any_extension: allowAnyExtension.value
|
||||
})
|
||||
useToastStore().add({
|
||||
severity: 'success',
|
||||
summary: t('modelManager.downloadQueued'),
|
||||
detail: modelId.value,
|
||||
life: 4000
|
||||
})
|
||||
reset()
|
||||
isOpen.value = false
|
||||
} catch (error) {
|
||||
handleEnqueueError(error)
|
||||
} finally {
|
||||
isSubmitting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function handleEnqueueError(error: unknown) {
|
||||
if (error instanceof DownloadApiError) {
|
||||
if (error.is('ALREADY_AVAILABLE')) {
|
||||
useToastStore().add({
|
||||
severity: 'info',
|
||||
summary: t('modelManager.alreadyInstalled'),
|
||||
detail: modelId.value,
|
||||
life: 4000
|
||||
})
|
||||
reset()
|
||||
isOpen.value = false
|
||||
return
|
||||
}
|
||||
if (error.is('ALREADY_DOWNLOADING')) {
|
||||
useToastStore().add({
|
||||
severity: 'info',
|
||||
summary: t('modelManager.alreadyDownloading'),
|
||||
detail: modelId.value,
|
||||
life: 4000
|
||||
})
|
||||
reset()
|
||||
isOpen.value = false
|
||||
return
|
||||
}
|
||||
errorMessage.value = error.message
|
||||
return
|
||||
}
|
||||
errorMessage.value =
|
||||
error instanceof Error ? error.message : t('modelManager.actionFailed')
|
||||
}
|
||||
</script>
|
||||
@@ -1,313 +0,0 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { defineComponent, nextTick, reactive, ref } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import type { HostCredentialView } from '../types'
|
||||
import HostCredentialsDialog from './HostCredentialsDialog.vue'
|
||||
|
||||
const mockCloseDialog = vi.fn()
|
||||
const mockToastAdd = vi.fn()
|
||||
|
||||
const mockCredentialsStore = reactive({
|
||||
credentials: ref<HostCredentialView[]>([]),
|
||||
isLoading: ref(false),
|
||||
fetchCredentials: vi.fn(),
|
||||
upsert: vi.fn(),
|
||||
remove: vi.fn()
|
||||
})
|
||||
|
||||
vi.mock('../stores/hostCredentialsStore', () => ({
|
||||
useHostCredentialsStore: () => mockCredentialsStore
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/dialogStore', () => ({
|
||||
useDialogStore: () => ({ closeDialog: mockCloseDialog })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: () => ({ add: mockToastAdd })
|
||||
}))
|
||||
|
||||
vi.mock('@/components/dialog/confirm/confirmDialog')
|
||||
|
||||
const mockShowConfirmDialog = vi.mocked(showConfirmDialog)
|
||||
|
||||
interface CapturedConfirmOptions {
|
||||
footerProps?: {
|
||||
onConfirm?: () => void | Promise<void>
|
||||
onCancel?: () => void
|
||||
}
|
||||
}
|
||||
|
||||
function capturedOptions(): CapturedConfirmOptions {
|
||||
return mockShowConfirmDialog.mock
|
||||
.calls[0][0] as unknown as CapturedConfirmOptions
|
||||
}
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages },
|
||||
missingWarn: false,
|
||||
fallbackWarn: false
|
||||
})
|
||||
|
||||
const stubs = {
|
||||
Dialog: { template: '<div><slot /></div>' },
|
||||
DialogPortal: { template: '<div><slot /></div>' },
|
||||
DialogOverlay: { template: '<div />' },
|
||||
DialogContent: defineComponent({
|
||||
emits: ['open-auto-focus'],
|
||||
mounted() {
|
||||
this.$emit('open-auto-focus')
|
||||
},
|
||||
template: '<div><slot /></div>'
|
||||
}),
|
||||
DialogHeader: { template: '<div><slot /></div>' },
|
||||
DialogTitle: { template: '<div><slot /></div>' },
|
||||
DialogDescription: { template: '<div><slot /></div>' },
|
||||
SingleSelect: {
|
||||
props: ['modelValue', 'options', 'label'],
|
||||
emits: ['update:modelValue'],
|
||||
template: `
|
||||
<select
|
||||
data-testid="scheme-select"
|
||||
:value="modelValue ?? ''"
|
||||
@change="$emit('update:modelValue', $event.target.value)"
|
||||
>
|
||||
<option v-for="opt in options" :key="opt.value" :value="opt.value">
|
||||
{{ opt.name }}
|
||||
</option>
|
||||
</select>
|
||||
`
|
||||
}
|
||||
}
|
||||
|
||||
function createCredential(
|
||||
overrides: Partial<HostCredentialView> = {}
|
||||
): HostCredentialView {
|
||||
return {
|
||||
id: 'c1',
|
||||
host: 'huggingface.co',
|
||||
auth_scheme: 'bearer',
|
||||
header_name: null,
|
||||
query_param: null,
|
||||
label: null,
|
||||
match_subdomains: false,
|
||||
enabled: true,
|
||||
secret_last4: '1234',
|
||||
created_at: 0,
|
||||
updated_at: 0,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function mountDialog(props: { open?: boolean; prefillHost?: string } = {}) {
|
||||
return render(HostCredentialsDialog, {
|
||||
props: { open: true, ...props },
|
||||
global: { plugins: [i18n], stubs }
|
||||
})
|
||||
}
|
||||
|
||||
describe('HostCredentialsDialog', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockCredentialsStore.credentials = []
|
||||
mockCredentialsStore.fetchCredentials.mockResolvedValue(undefined)
|
||||
mockShowConfirmDialog.mockReturnValue(
|
||||
{} as ReturnType<typeof showConfirmDialog>
|
||||
)
|
||||
})
|
||||
|
||||
it('renders existing credentials with host, scheme, and last 4 of the secret', () => {
|
||||
mockCredentialsStore.credentials = [
|
||||
createCredential({
|
||||
host: 'huggingface.co',
|
||||
auth_scheme: 'bearer',
|
||||
secret_last4: '9f21'
|
||||
})
|
||||
]
|
||||
mountDialog()
|
||||
|
||||
expect(
|
||||
screen.getByText('huggingface.co · Bearer token · ••••9f21')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows a disabled marker for a disabled credential', () => {
|
||||
mockCredentialsStore.credentials = [createCredential({ enabled: false })]
|
||||
mountDialog()
|
||||
|
||||
expect(screen.getByText(/Disabled/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('prefills the host from the prefillHost prop on open', async () => {
|
||||
mountDialog({ prefillHost: 'civitai.com' })
|
||||
await nextTick()
|
||||
|
||||
expect(screen.getByLabelText('Host')).toHaveValue('civitai.com')
|
||||
})
|
||||
|
||||
it('populates the form when editing an existing credential', async () => {
|
||||
mockCredentialsStore.credentials = [
|
||||
createCredential({
|
||||
id: 'c1',
|
||||
host: 'civitai.com',
|
||||
auth_scheme: 'header',
|
||||
header_name: 'X-Api-Key',
|
||||
label: 'My Civitai key'
|
||||
})
|
||||
]
|
||||
mountDialog()
|
||||
|
||||
await userEvent.click(screen.getByTitle('Edit'))
|
||||
|
||||
expect(screen.getByLabelText('Host')).toHaveValue('civitai.com')
|
||||
expect(screen.getByLabelText('Label')).toHaveValue('My Civitai key')
|
||||
expect(screen.getByText('Update credential')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('disables submit until host and secret are filled', async () => {
|
||||
mountDialog()
|
||||
const submit = screen.getByRole('button', { name: 'Save' })
|
||||
expect(submit).toBeDisabled()
|
||||
|
||||
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
|
||||
expect(submit).toBeDisabled()
|
||||
|
||||
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
|
||||
expect(submit).toBeEnabled()
|
||||
})
|
||||
|
||||
it('requires a query parameter name when the scheme is query', async () => {
|
||||
mountDialog()
|
||||
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
|
||||
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
|
||||
await userEvent.selectOptions(screen.getByTestId('scheme-select'), 'query')
|
||||
|
||||
const submit = screen.getByRole('button', { name: 'Save' })
|
||||
expect(submit).toBeDisabled()
|
||||
|
||||
await userEvent.type(screen.getByLabelText('Query parameter'), 'token')
|
||||
expect(submit).toBeEnabled()
|
||||
})
|
||||
|
||||
it('shows the header name field and a subdomain warning for the header scheme', async () => {
|
||||
mountDialog()
|
||||
|
||||
await userEvent.selectOptions(screen.getByTestId('scheme-select'), 'header')
|
||||
expect(screen.getByLabelText('Header name')).toBeInTheDocument()
|
||||
|
||||
await userEvent.click(screen.getByLabelText('Match subdomains'))
|
||||
expect(
|
||||
screen.getByText(
|
||||
'Not recommended: hubs redirect to sibling CDN hosts that must not receive your key.'
|
||||
)
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('cancels an in-progress edit and resets the form', async () => {
|
||||
mockCredentialsStore.credentials = [
|
||||
createCredential({ id: 'c1', host: 'civitai.com' })
|
||||
]
|
||||
mountDialog()
|
||||
await userEvent.click(screen.getByTitle('Edit'))
|
||||
expect(screen.getByText('Update credential')).toBeInTheDocument()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Cancel' }))
|
||||
|
||||
expect(screen.getByText('Add credential')).toBeInTheDocument()
|
||||
expect(screen.getByLabelText('Host')).toHaveValue('')
|
||||
})
|
||||
|
||||
it('submits the form payload and resets on success', async () => {
|
||||
mockCredentialsStore.upsert.mockResolvedValue(createCredential())
|
||||
mountDialog()
|
||||
|
||||
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
|
||||
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Save' }))
|
||||
|
||||
expect(mockCredentialsStore.upsert).toHaveBeenCalledWith({
|
||||
host: 'huggingface.co',
|
||||
secret: 's3cret',
|
||||
auth_scheme: 'bearer',
|
||||
header_name: null,
|
||||
query_param: null,
|
||||
label: null,
|
||||
match_subdomains: false
|
||||
})
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'success' })
|
||||
)
|
||||
})
|
||||
expect(screen.getByLabelText('Host')).toHaveValue('')
|
||||
})
|
||||
|
||||
it('shows an inline error message when saving fails', async () => {
|
||||
mockCredentialsStore.upsert.mockRejectedValue(new Error('save failed'))
|
||||
mountDialog()
|
||||
|
||||
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
|
||||
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Save' }))
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(screen.getByText('save failed')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('shows an inline error when the initial credentials fetch fails', async () => {
|
||||
mockCredentialsStore.fetchCredentials.mockRejectedValue(
|
||||
new Error('offline')
|
||||
)
|
||||
mountDialog()
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(screen.getByText('offline')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('deletes a credential through a confirm dialog', async () => {
|
||||
mockCredentialsStore.remove.mockResolvedValue(undefined)
|
||||
mockCredentialsStore.credentials = [createCredential({ id: 'c1' })]
|
||||
mountDialog()
|
||||
|
||||
await userEvent.click(screen.getByTitle('Delete'))
|
||||
await capturedOptions().footerProps?.onConfirm?.()
|
||||
|
||||
expect(mockCredentialsStore.remove).toHaveBeenCalledWith('c1')
|
||||
expect(mockCloseDialog).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows an error toast when deletion fails', async () => {
|
||||
mockCredentialsStore.remove.mockRejectedValue(new Error('boom'))
|
||||
mockCredentialsStore.credentials = [createCredential({ id: 'c1' })]
|
||||
mountDialog()
|
||||
|
||||
await userEvent.click(screen.getByTitle('Delete'))
|
||||
await capturedOptions().footerProps?.onConfirm?.()
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error' })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('does not delete when the confirm dialog is dismissed', async () => {
|
||||
mockCredentialsStore.credentials = [createCredential({ id: 'c1' })]
|
||||
mountDialog()
|
||||
|
||||
await userEvent.click(screen.getByTitle('Delete'))
|
||||
capturedOptions().footerProps?.onCancel?.()
|
||||
|
||||
expect(mockCredentialsStore.remove).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@@ -1,333 +0,0 @@
|
||||
<template>
|
||||
<Dialog v-model:open="isOpen">
|
||||
<DialogPortal>
|
||||
<DialogOverlay class="bg-black/70" />
|
||||
<DialogContent
|
||||
size="md"
|
||||
class="flex flex-col gap-4 p-6"
|
||||
@open-auto-focus="onOpen"
|
||||
>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{{ $t('modelManager.credentials.title') }}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{{ $t('modelManager.credentials.description') }}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div v-if="credentials.length" class="flex flex-col gap-2">
|
||||
<div
|
||||
v-for="credential in credentials"
|
||||
:key="credential.id"
|
||||
class="flex items-center gap-2 rounded-lg border border-border-default bg-secondary-background px-3 py-2"
|
||||
>
|
||||
<div class="flex min-w-0 flex-1 flex-col">
|
||||
<span class="truncate text-sm font-medium text-base-foreground">
|
||||
{{ credential.label || credential.host }}
|
||||
</span>
|
||||
<span class="truncate text-xs text-muted-foreground">
|
||||
{{ credential.host }} ·
|
||||
{{
|
||||
$t(
|
||||
`modelManager.credentials.scheme.${credential.auth_scheme}`
|
||||
)
|
||||
}}
|
||||
<template v-if="credential.secret_last4">
|
||||
· ••••{{ credential.secret_last4 }}
|
||||
</template>
|
||||
<template v-if="!credential.enabled">
|
||||
· {{ $t('modelManager.credentials.disabled') }}
|
||||
</template>
|
||||
</span>
|
||||
</div>
|
||||
<Button
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('modelManager.credentials.edit')"
|
||||
@click="editCredential(credential)"
|
||||
>
|
||||
<i class="icon-[lucide--pencil] size-4" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('g.delete')"
|
||||
@click="confirmDelete(credential)"
|
||||
>
|
||||
<i class="icon-[lucide--trash-2] size-4 text-red-400" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-3 border-t border-border-default pt-3">
|
||||
<span class="text-sm font-medium text-base-foreground">
|
||||
{{
|
||||
form.id
|
||||
? $t('modelManager.credentials.update')
|
||||
: $t('modelManager.credentials.add')
|
||||
}}
|
||||
</span>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground" for="cred-host">
|
||||
{{ $t('modelManager.credentials.host') }}
|
||||
</label>
|
||||
<Input
|
||||
id="cred-host"
|
||||
v-model="form.host"
|
||||
:placeholder="HOST_EXAMPLE"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground" for="cred-secret">
|
||||
{{ $t('modelManager.credentials.secret') }}
|
||||
</label>
|
||||
<Input
|
||||
id="cred-secret"
|
||||
v-model="form.secret"
|
||||
type="password"
|
||||
:placeholder="$t('modelManager.credentials.secretPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground">
|
||||
{{ $t('modelManager.credentials.authScheme') }}
|
||||
</label>
|
||||
<SingleSelect v-model="form.auth_scheme" :options="schemeOptions" />
|
||||
</div>
|
||||
|
||||
<div v-if="form.auth_scheme === 'header'" class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground" for="cred-header">
|
||||
{{ $t('modelManager.credentials.headerName') }}
|
||||
</label>
|
||||
<Input
|
||||
id="cred-header"
|
||||
v-model="form.header_name"
|
||||
:placeholder="HEADER_NAME_EXAMPLE"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div v-if="form.auth_scheme === 'query'" class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground" for="cred-query">
|
||||
{{ $t('modelManager.credentials.queryParam') }}
|
||||
</label>
|
||||
<Input
|
||||
id="cred-query"
|
||||
v-model="form.query_param"
|
||||
:placeholder="QUERY_PARAM_EXAMPLE"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground" for="cred-label">
|
||||
{{ $t('modelManager.credentials.label') }}
|
||||
</label>
|
||||
<Input id="cred-label" v-model="form.label" />
|
||||
</div>
|
||||
|
||||
<label class="flex items-center gap-2 text-xs text-muted-foreground">
|
||||
<input
|
||||
v-model="form.match_subdomains"
|
||||
type="checkbox"
|
||||
class="size-4"
|
||||
/>
|
||||
{{ $t('modelManager.credentials.matchSubdomains') }}
|
||||
</label>
|
||||
<p v-if="form.match_subdomains" class="text-xs text-amber-400">
|
||||
{{ $t('modelManager.credentials.matchSubdomainsWarning') }}
|
||||
</p>
|
||||
|
||||
<p v-if="errorMessage" class="text-xs text-red-400">
|
||||
{{ errorMessage }}
|
||||
</p>
|
||||
|
||||
<div class="flex justify-end gap-2">
|
||||
<Button v-if="form.id" variant="secondary" @click="resetForm">
|
||||
{{ $t('g.cancel') }}
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
:disabled="!canSubmit"
|
||||
:loading="isSubmitting"
|
||||
@click="submit"
|
||||
>
|
||||
{{ $t('modelManager.credentials.save') }}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</DialogPortal>
|
||||
</Dialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, reactive, ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import Dialog from '@/components/ui/dialog/Dialog.vue'
|
||||
import DialogContent from '@/components/ui/dialog/DialogContent.vue'
|
||||
import DialogDescription from '@/components/ui/dialog/DialogDescription.vue'
|
||||
import DialogHeader from '@/components/ui/dialog/DialogHeader.vue'
|
||||
import DialogOverlay from '@/components/ui/dialog/DialogOverlay.vue'
|
||||
import DialogPortal from '@/components/ui/dialog/DialogPortal.vue'
|
||||
import DialogTitle from '@/components/ui/dialog/DialogTitle.vue'
|
||||
import Input from '@/components/ui/input/Input.vue'
|
||||
import SingleSelect from '@/components/ui/single-select/SingleSelect.vue'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
import { storeToRefs } from 'pinia'
|
||||
import { useDialogStore } from '@/stores/dialogStore'
|
||||
|
||||
import { useHostCredentialsStore } from '../stores/hostCredentialsStore'
|
||||
import { AUTH_SCHEMES } from '../types'
|
||||
import type { AuthScheme, HostCredentialView } from '../types'
|
||||
|
||||
const HOST_EXAMPLE = 'huggingface.co'
|
||||
const HEADER_NAME_EXAMPLE = 'Authorization'
|
||||
const QUERY_PARAM_EXAMPLE = 'token'
|
||||
|
||||
const { prefillHost = '' } = defineProps<{ prefillHost?: string }>()
|
||||
|
||||
const isOpen = defineModel<boolean>('open', { required: true })
|
||||
|
||||
const { t } = useI18n()
|
||||
const store = useHostCredentialsStore()
|
||||
const { credentials } = storeToRefs(store)
|
||||
const dialogStore = useDialogStore()
|
||||
|
||||
const isSubmitting = ref(false)
|
||||
const errorMessage = ref('')
|
||||
|
||||
interface CredentialForm {
|
||||
id: string | null
|
||||
host: string
|
||||
secret: string
|
||||
auth_scheme: AuthScheme
|
||||
header_name: string
|
||||
query_param: string
|
||||
label: string
|
||||
match_subdomains: boolean
|
||||
}
|
||||
|
||||
function emptyForm(): CredentialForm {
|
||||
return {
|
||||
id: null,
|
||||
host: '',
|
||||
secret: '',
|
||||
auth_scheme: 'bearer',
|
||||
header_name: '',
|
||||
query_param: '',
|
||||
label: '',
|
||||
match_subdomains: false
|
||||
}
|
||||
}
|
||||
|
||||
const form = reactive<CredentialForm>(emptyForm())
|
||||
|
||||
const schemeOptions = computed(() =>
|
||||
AUTH_SCHEMES.map((scheme) => ({
|
||||
value: scheme,
|
||||
name: t(`modelManager.credentials.scheme.${scheme}`)
|
||||
}))
|
||||
)
|
||||
|
||||
const canSubmit = computed(
|
||||
() =>
|
||||
!!form.host.trim() &&
|
||||
!!form.secret &&
|
||||
(form.auth_scheme !== 'query' || !!form.query_param.trim()) &&
|
||||
(form.auth_scheme !== 'header' || !!form.header_name.trim())
|
||||
)
|
||||
|
||||
async function onOpen() {
|
||||
resetForm()
|
||||
if (prefillHost) {
|
||||
form.host = prefillHost
|
||||
}
|
||||
try {
|
||||
await store.fetchCredentials()
|
||||
} catch (error) {
|
||||
errorMessage.value =
|
||||
error instanceof Error ? error.message : t('modelManager.actionFailed')
|
||||
}
|
||||
}
|
||||
|
||||
function resetForm() {
|
||||
Object.assign(form, emptyForm())
|
||||
errorMessage.value = ''
|
||||
}
|
||||
|
||||
function editCredential(credential: HostCredentialView) {
|
||||
Object.assign(form, {
|
||||
id: credential.id,
|
||||
host: credential.host,
|
||||
secret: '',
|
||||
auth_scheme: credential.auth_scheme,
|
||||
header_name: credential.header_name ?? '',
|
||||
query_param: credential.query_param ?? '',
|
||||
label: credential.label ?? '',
|
||||
match_subdomains: credential.match_subdomains
|
||||
})
|
||||
errorMessage.value = ''
|
||||
}
|
||||
|
||||
async function submit() {
|
||||
if (!canSubmit.value) return
|
||||
isSubmitting.value = true
|
||||
errorMessage.value = ''
|
||||
try {
|
||||
await store.upsert({
|
||||
host: form.host,
|
||||
secret: form.secret,
|
||||
auth_scheme: form.auth_scheme,
|
||||
header_name: form.auth_scheme === 'header' ? form.header_name : null,
|
||||
query_param: form.auth_scheme === 'query' ? form.query_param : null,
|
||||
label: form.label || null,
|
||||
match_subdomains: form.match_subdomains
|
||||
})
|
||||
useToastStore().add({
|
||||
severity: 'success',
|
||||
summary: t('modelManager.credentials.saved'),
|
||||
detail: form.host,
|
||||
life: 4000
|
||||
})
|
||||
resetForm()
|
||||
} catch (error) {
|
||||
errorMessage.value =
|
||||
error instanceof Error ? error.message : t('modelManager.actionFailed')
|
||||
} finally {
|
||||
isSubmitting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function confirmDelete(credential: HostCredentialView) {
|
||||
const dialog = showConfirmDialog({
|
||||
headerProps: { title: t('modelManager.credentials.deleteTitle') },
|
||||
props: {
|
||||
promptText: t('modelManager.credentials.deleteMessage', {
|
||||
host: credential.host
|
||||
})
|
||||
},
|
||||
footerProps: {
|
||||
confirmText: t('g.delete'),
|
||||
confirmVariant: 'destructive' as const,
|
||||
onCancel: () => dialogStore.closeDialog(dialog),
|
||||
onConfirm: async () => {
|
||||
dialogStore.closeDialog(dialog)
|
||||
try {
|
||||
await store.remove(credential.id)
|
||||
} catch (error) {
|
||||
useToastStore().add({
|
||||
severity: 'error',
|
||||
summary: t('modelManager.actionFailed'),
|
||||
detail: error instanceof Error ? error.message : String(error),
|
||||
life: 5000
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
</script>
|
||||
@@ -1,336 +0,0 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import type { DownloadStatus } from '../types'
|
||||
import ModelDownloadRow from './ModelDownloadRow.vue'
|
||||
|
||||
const mockPause = vi.fn()
|
||||
const mockResume = vi.fn()
|
||||
const mockCancel = vi.fn()
|
||||
const mockRaisePriority = vi.fn()
|
||||
const mockRemove = vi.fn()
|
||||
|
||||
vi.mock('../composables/useModelDownloadActions', () => ({
|
||||
useModelDownloadActions: () => ({
|
||||
pause: mockPause,
|
||||
resume: mockResume,
|
||||
cancel: mockCancel,
|
||||
raisePriority: mockRaisePriority,
|
||||
remove: mockRemove,
|
||||
toastError: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages },
|
||||
missingWarn: false,
|
||||
fallbackWarn: false
|
||||
})
|
||||
|
||||
function createDownload(
|
||||
overrides: Partial<DownloadStatus> = {}
|
||||
): DownloadStatus {
|
||||
return {
|
||||
download_id: 'd1',
|
||||
model_id: 'loras/x.safetensors',
|
||||
url: 'https://huggingface.co/org/x.safetensors',
|
||||
status: 'active',
|
||||
priority: 0,
|
||||
total_bytes: 2048,
|
||||
bytes_done: 1024,
|
||||
progress: 0.5,
|
||||
speed_bps: 512,
|
||||
eta_seconds: 125,
|
||||
segments: null,
|
||||
error: null,
|
||||
created_at: 0,
|
||||
updated_at: 0,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function mountRow(
|
||||
download: DownloadStatus,
|
||||
onOpenCredentials?: (host: string) => void
|
||||
) {
|
||||
return render(ModelDownloadRow, {
|
||||
props: {
|
||||
download,
|
||||
...(onOpenCredentials ? { onOpenCredentials } : {})
|
||||
},
|
||||
global: { plugins: [i18n] }
|
||||
})
|
||||
}
|
||||
|
||||
describe('ModelDownloadRow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('splits the model id into directory and filename', () => {
|
||||
mountRow(createDownload({ model_id: 'loras/x.safetensors' }))
|
||||
|
||||
expect(screen.getByText('x.safetensors')).toBeInTheDocument()
|
||||
expect(screen.getByText('loras')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders an empty directory when the model id has no folder', () => {
|
||||
mountRow(createDownload({ model_id: 'x.safetensors' }))
|
||||
|
||||
expect(screen.getByText('x.safetensors')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('formats the meta line with percent, size, speed, and eta while active', () => {
|
||||
mountRow(createDownload({ status: 'active' }))
|
||||
|
||||
expect(
|
||||
screen.getByText('50% · 1 KB / 2 KB · 512 B/s · 2:05')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders an empty meta line when no progress metrics are known yet', () => {
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'queued',
|
||||
progress: null,
|
||||
total_bytes: null,
|
||||
bytes_done: 0,
|
||||
speed_bps: null,
|
||||
eta_seconds: null
|
||||
})
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('meta-line')).toBeEmptyDOMElement()
|
||||
})
|
||||
|
||||
it('omits the eta once a download is no longer active', () => {
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'paused',
|
||||
progress: 0.5,
|
||||
eta_seconds: 125
|
||||
})
|
||||
)
|
||||
|
||||
expect(screen.getByText('50% · 1 KB / 2 KB · 512 B/s')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
describe('progress bar visibility', () => {
|
||||
it('shows a progress bar for an active download with known progress', () => {
|
||||
mountRow(createDownload({ status: 'active', progress: 0.5 }))
|
||||
|
||||
expect(screen.getByTestId('progress-bar')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('hides the progress bar for a cancelled download', () => {
|
||||
mountRow(createDownload({ status: 'cancelled', progress: 0.5 }))
|
||||
|
||||
expect(screen.queryByTestId('progress-bar')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('Cancelled')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('action buttons by status', () => {
|
||||
it('shows pause and cancel for a queued download, plus raise priority', () => {
|
||||
mountRow(createDownload({ status: 'queued' }))
|
||||
|
||||
expect(screen.getByTitle('Raise priority')).toBeInTheDocument()
|
||||
expect(screen.getByTitle('Pause')).toBeInTheDocument()
|
||||
expect(screen.getByTitle('Cancel')).toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Resume')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Remove from list')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows only resume for a failed download without an auth error', () => {
|
||||
mountRow(createDownload({ status: 'failed', error: 'disk full' }))
|
||||
|
||||
expect(screen.getByTitle('Resume')).toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Pause')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Cancel')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Remove from list')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Add credentials')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows the remove action for terminal downloads', () => {
|
||||
mountRow(createDownload({ status: 'completed' }))
|
||||
|
||||
expect(screen.getByTitle('Remove from list')).toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Cancel')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Resume')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('auth errors', () => {
|
||||
it('shows the credentials button and a host-specific hint', async () => {
|
||||
const onOpenCredentials = vi.fn()
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'failed',
|
||||
url: 'https://huggingface.co/org/x.safetensors',
|
||||
error: '401 Unauthorized'
|
||||
}),
|
||||
onOpenCredentials
|
||||
)
|
||||
|
||||
expect(
|
||||
screen.getByText(
|
||||
'huggingface.co needs an API key. Add one in the Credentials Manager, then resume.'
|
||||
)
|
||||
).toBeInTheDocument()
|
||||
|
||||
await userEvent.click(screen.getByTitle('Add credentials'))
|
||||
expect(onOpenCredentials).toHaveBeenCalledWith('huggingface.co')
|
||||
})
|
||||
|
||||
it('falls back to a hostless hint when the url cannot be parsed', () => {
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'failed',
|
||||
url: 'not-a-url',
|
||||
error: '403 forbidden'
|
||||
})
|
||||
)
|
||||
|
||||
expect(
|
||||
screen.getByText(
|
||||
'This host/model needs an API key. Add one in the Credentials Manager, then resume.'
|
||||
)
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows the raw error message for a non-auth failure', () => {
|
||||
mountRow(createDownload({ status: 'failed', error: 'disk full' }))
|
||||
|
||||
expect(screen.getByText('disk full')).toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Add credentials')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('hides a leftover error while the download is not yet terminal', () => {
|
||||
mountRow(createDownload({ status: 'active', error: 'disk full' }))
|
||||
|
||||
expect(screen.queryByText('disk full')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('gated models', () => {
|
||||
const gatedError =
|
||||
'https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/vae/diffusion_pytorch_model.safetensors is a gated model — Access to model black-forest-labs/FLUX.2-dev is restricted. You must have access to it and be authenticated to access it.'
|
||||
|
||||
it('shows the gated hint, credentials button, and a link to accept the license on the model page', () => {
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'failed',
|
||||
url: 'https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/vae/diffusion_pytorch_model.safetensors',
|
||||
error: gatedError
|
||||
})
|
||||
)
|
||||
|
||||
expect(
|
||||
screen.getByText(
|
||||
"This model is gated. Accept its license on the model's page, then add an API key and resume."
|
||||
)
|
||||
).toBeInTheDocument()
|
||||
expect(screen.getByTitle('Add credentials')).toBeInTheDocument()
|
||||
|
||||
const link = screen.getByRole('link', { name: 'Accept license' })
|
||||
expect(link).toHaveAttribute(
|
||||
'href',
|
||||
'https://huggingface.co/black-forest-labs/FLUX.2-dev'
|
||||
)
|
||||
})
|
||||
|
||||
it('derives the model page from the error when the download url is a cdn link', () => {
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'failed',
|
||||
url: 'https://cas-bridge.xethub.hf.co/xet-bridge-us/abc/def',
|
||||
error: gatedError
|
||||
})
|
||||
)
|
||||
|
||||
expect(
|
||||
screen.getByRole('link', { name: 'Accept license' })
|
||||
).toHaveAttribute(
|
||||
'href',
|
||||
'https://huggingface.co/black-forest-labs/FLUX.2-dev'
|
||||
)
|
||||
})
|
||||
|
||||
it('hides the raw backend error text for gated failures', () => {
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'failed',
|
||||
url: 'https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/model.safetensors',
|
||||
error: gatedError
|
||||
})
|
||||
)
|
||||
|
||||
expect(screen.queryByText(gatedError)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('omits the accept-license link when no huggingface url is present', () => {
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'failed',
|
||||
url: 'https://example.com/model.safetensors',
|
||||
error: 'Access to this model is restricted, request access.'
|
||||
})
|
||||
)
|
||||
|
||||
expect(
|
||||
screen.queryByRole('link', { name: 'Accept license' })
|
||||
).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('user actions', () => {
|
||||
it('pauses on click', async () => {
|
||||
const download = createDownload({ status: 'active' })
|
||||
mountRow(download)
|
||||
|
||||
await userEvent.click(screen.getByTitle('Pause'))
|
||||
expect(mockPause).toHaveBeenCalledWith(download)
|
||||
})
|
||||
|
||||
it('resumes on click', async () => {
|
||||
const download = createDownload({ status: 'paused' })
|
||||
mountRow(download)
|
||||
|
||||
await userEvent.click(screen.getByTitle('Resume'))
|
||||
expect(mockResume).toHaveBeenCalledWith(download)
|
||||
})
|
||||
|
||||
it('cancels on click', async () => {
|
||||
const download = createDownload({ status: 'active' })
|
||||
mountRow(download)
|
||||
|
||||
await userEvent.click(screen.getByTitle('Cancel'))
|
||||
expect(mockCancel).toHaveBeenCalledWith(download)
|
||||
})
|
||||
|
||||
it('raises priority by 1 on click', async () => {
|
||||
const download = createDownload({ status: 'queued', priority: 2 })
|
||||
mountRow(download)
|
||||
|
||||
await userEvent.click(screen.getByTitle('Raise priority'))
|
||||
expect(mockRaisePriority).toHaveBeenCalledWith(download, 1)
|
||||
})
|
||||
|
||||
it('removes the download on click', async () => {
|
||||
const download = createDownload({
|
||||
download_id: 'd1',
|
||||
status: 'completed'
|
||||
})
|
||||
mountRow(download)
|
||||
|
||||
await userEvent.click(screen.getByTitle('Remove from list'))
|
||||
expect(mockRemove).toHaveBeenCalledWith(download)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,245 +0,0 @@
|
||||
<template>
|
||||
<div
|
||||
:class="
|
||||
cn(
|
||||
'relative flex flex-col gap-1 overflow-hidden rounded-lg border border-border-default bg-secondary-background px-3 py-2',
|
||||
isCancelled && 'opacity-60'
|
||||
)
|
||||
"
|
||||
>
|
||||
<div
|
||||
v-if="showProgressBar"
|
||||
:class="progressBarContainerClass"
|
||||
data-testid="progress-bar"
|
||||
>
|
||||
<div :class="progressBarPrimaryClass" :style="barStyle" />
|
||||
</div>
|
||||
|
||||
<div class="relative flex items-center gap-2">
|
||||
<div class="flex min-w-0 flex-1 flex-col">
|
||||
<span class="truncate text-sm font-medium text-base-foreground">
|
||||
{{ filename }}
|
||||
</span>
|
||||
<span class="truncate text-xs text-muted-foreground">
|
||||
{{ directory }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div class="flex shrink-0 items-center gap-0.5">
|
||||
<template v-if="canRaisePriority">
|
||||
<Button
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('modelManager.raisePriority')"
|
||||
@click="actions.raisePriority(download, 1)"
|
||||
>
|
||||
<i class="icon-[lucide--chevron-up] size-4" />
|
||||
</Button>
|
||||
</template>
|
||||
<Button
|
||||
v-if="canPause"
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('g.pause')"
|
||||
@click="actions.pause(download)"
|
||||
>
|
||||
<i class="icon-[lucide--pause] size-4" />
|
||||
</Button>
|
||||
<Button
|
||||
v-if="isAuthError"
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('modelManager.addCredentials')"
|
||||
@click="emit('openCredentials', host)"
|
||||
>
|
||||
<i class="icon-[lucide--key-round] size-4" />
|
||||
</Button>
|
||||
<Button
|
||||
v-if="canResume"
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('modelManager.resume')"
|
||||
@click="actions.resume(download)"
|
||||
>
|
||||
<i class="icon-[lucide--play] size-4" />
|
||||
</Button>
|
||||
<Button
|
||||
v-if="canCancel"
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('g.cancel')"
|
||||
@click="actions.cancel(download)"
|
||||
>
|
||||
<i class="icon-[lucide--x] size-4 text-red-400" />
|
||||
</Button>
|
||||
<Button
|
||||
v-if="isTerminal"
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('modelManager.removeFromList')"
|
||||
@click="actions.remove(download)"
|
||||
>
|
||||
<i class="icon-[lucide--x] size-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
class="relative flex items-center justify-between gap-2 text-xs text-muted-foreground"
|
||||
>
|
||||
<span class="flex items-center gap-1">
|
||||
<i v-if="isCancelled" class="icon-[lucide--ban] size-3.5" />
|
||||
{{ statusLabel }}
|
||||
</span>
|
||||
<span class="truncate" data-testid="meta-line">{{ metaLine }}</span>
|
||||
</div>
|
||||
|
||||
<p
|
||||
v-if="isGatedModel"
|
||||
class="relative text-xs wrap-break-word text-amber-400"
|
||||
>
|
||||
{{ $t('modelManager.gatedModelHint') }}
|
||||
<a
|
||||
v-if="modelPageUrl"
|
||||
:href="modelPageUrl"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="underline"
|
||||
>
|
||||
{{ $t('modelManager.openModelPage') }}
|
||||
</a>
|
||||
</p>
|
||||
<p
|
||||
v-else-if="isAuthError"
|
||||
class="relative text-xs wrap-break-word text-amber-400"
|
||||
>
|
||||
{{
|
||||
host
|
||||
? $t('modelManager.authErrorHint', { host })
|
||||
: $t('modelManager.authErrorHintNoHost')
|
||||
}}
|
||||
</p>
|
||||
<p
|
||||
v-else-if="isFailed && download.error"
|
||||
class="relative text-xs wrap-break-word text-red-400"
|
||||
>
|
||||
{{ download.error }}
|
||||
</p>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
import { computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { useProgressBarBackground } from '@/composables/useProgressBarBackground'
|
||||
import { formatSize } from '@/utils/formatUtil'
|
||||
|
||||
import { useModelDownloadActions } from '../composables/useModelDownloadActions'
|
||||
import { downloadProgressFraction } from '../stores/modelDownloadStore'
|
||||
import type { DownloadStatus } from '../types'
|
||||
import { directoryOf, filenameOf } from '../utils/modelId'
|
||||
|
||||
const { download } = defineProps<{ download: DownloadStatus }>()
|
||||
|
||||
const emit = defineEmits<{ openCredentials: [host: string] }>()
|
||||
|
||||
const { t } = useI18n()
|
||||
const actions = useModelDownloadActions()
|
||||
const {
|
||||
progressBarContainerClass,
|
||||
progressBarPrimaryClass,
|
||||
progressPercentStyle
|
||||
} = useProgressBarBackground()
|
||||
|
||||
const directory = computed(() => directoryOf(download.model_id))
|
||||
const filename = computed(() => filenameOf(download.model_id))
|
||||
|
||||
const host = computed(() => {
|
||||
try {
|
||||
return new URL(download.url).hostname
|
||||
} catch {
|
||||
return ''
|
||||
}
|
||||
})
|
||||
|
||||
const percent = computed(() => {
|
||||
const fraction = downloadProgressFraction(download)
|
||||
return fraction == null ? undefined : Math.round(fraction * 100)
|
||||
})
|
||||
|
||||
const barStyle = computed(() => progressPercentStyle(percent.value))
|
||||
|
||||
const isTerminal = computed(() =>
|
||||
['completed', 'cancelled'].includes(download.status)
|
||||
)
|
||||
const isCancelled = computed(() => download.status === 'cancelled')
|
||||
const showProgressBar = computed(
|
||||
() =>
|
||||
download.status !== 'failed' &&
|
||||
!isCancelled.value &&
|
||||
percent.value !== undefined
|
||||
)
|
||||
const canPause = computed(() => ['queued', 'active'].includes(download.status))
|
||||
const canResume = computed(() => ['paused', 'failed'].includes(download.status))
|
||||
const canCancel = computed(
|
||||
() => !isTerminal.value && download.status !== 'failed'
|
||||
)
|
||||
const canRaisePriority = computed(() => download.status === 'queued')
|
||||
|
||||
const GATED_MODEL_PATTERN = /gated|restricted|request access|must have access/i
|
||||
const AUTH_ERROR_PATTERN =
|
||||
/api key|credential|token|unauthor|forbidden|\b401\b|\b403\b/i
|
||||
|
||||
const isFailed = computed(() => download.status === 'failed')
|
||||
const isGatedModel = computed(
|
||||
() => isFailed.value && !!download.error?.match(GATED_MODEL_PATTERN)
|
||||
)
|
||||
const isAuthError = computed(
|
||||
() =>
|
||||
isFailed.value &&
|
||||
(isGatedModel.value || !!download.error?.match(AUTH_ERROR_PATTERN))
|
||||
)
|
||||
|
||||
const HF_MODEL_URL_PATTERN =
|
||||
/https?:\/\/huggingface\.co\/([^/\s?#]+)\/([^/\s?#]+)/i
|
||||
const modelPageUrl = computed(() => {
|
||||
const match = `${download.url} ${download.error ?? ''}`.match(
|
||||
HF_MODEL_URL_PATTERN
|
||||
)
|
||||
if (!match) return ''
|
||||
const [, owner, repo] = match
|
||||
return `https://huggingface.co/${owner}/${repo}`
|
||||
})
|
||||
|
||||
const statusLabel = computed(() => t(`modelManager.status.${download.status}`))
|
||||
|
||||
const metaLine = computed(() => {
|
||||
const parts: string[] = []
|
||||
if (percent.value !== undefined) parts.push(`${percent.value}%`)
|
||||
if (download.total_bytes != null) {
|
||||
parts.push(
|
||||
`${formatSize(download.bytes_done)} / ${formatSize(download.total_bytes)}`
|
||||
)
|
||||
}
|
||||
if (download.speed_bps) parts.push(`${formatSize(download.speed_bps)}/s`)
|
||||
if (download.eta_seconds != null && download.status === 'active') {
|
||||
parts.push(formatEta(download.eta_seconds))
|
||||
}
|
||||
return parts.join(' · ')
|
||||
})
|
||||
|
||||
function formatEta(seconds: number): string {
|
||||
const total = Math.max(0, Math.round(seconds))
|
||||
const hours = Math.floor(total / 3600)
|
||||
const minutes = Math.floor((total % 3600) / 60)
|
||||
const secs = total % 60
|
||||
const paddedSecs = secs.toString().padStart(2, '0')
|
||||
if (hours > 0) {
|
||||
return `${hours}:${minutes.toString().padStart(2, '0')}:${paddedSecs}`
|
||||
}
|
||||
return `${minutes}:${paddedSecs}`
|
||||
}
|
||||
</script>
|
||||
@@ -1,178 +0,0 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { reactive, ref } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import type { DownloadStatus } from '../types'
|
||||
import ModelManagerSidebarTab from './ModelManagerSidebarTab.vue'
|
||||
|
||||
const mockHydrate = vi.fn()
|
||||
const mockClearHistory = vi.fn()
|
||||
|
||||
const mockStore = reactive({
|
||||
downloadList: ref<DownloadStatus[]>([]),
|
||||
activeDownloads: ref<DownloadStatus[]>([]),
|
||||
historyDownloads: ref<DownloadStatus[]>([]),
|
||||
hydrate: mockHydrate
|
||||
})
|
||||
|
||||
vi.mock('../stores/modelDownloadStore', () => ({
|
||||
useModelDownloadStore: () => mockStore
|
||||
}))
|
||||
|
||||
vi.mock('../composables/useModelDownloadActions', () => ({
|
||||
useModelDownloadActions: () => ({ clearHistory: mockClearHistory })
|
||||
}))
|
||||
|
||||
vi.mock('./ModelDownloadRow.vue', () => ({
|
||||
default: {
|
||||
name: 'ModelDownloadRow',
|
||||
props: ['download'],
|
||||
emits: ['openCredentials'],
|
||||
template:
|
||||
'<div><span>{{ download.download_id }}</span>' +
|
||||
'<button @click="$emit(\'openCredentials\', download.model_id)">open-credentials</button></div>'
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('./AddModelByUrlDialog.vue', () => ({
|
||||
default: {
|
||||
name: 'AddModelByUrlDialog',
|
||||
props: ['open'],
|
||||
template: '<div data-testid="add-model-dialog">{{ open }}</div>'
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('./HostCredentialsDialog.vue', () => ({
|
||||
default: {
|
||||
name: 'HostCredentialsDialog',
|
||||
props: ['open', 'prefillHost'],
|
||||
template:
|
||||
'<div data-testid="credentials-dialog">{{ open }}:{{ prefillHost }}</div>'
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/components/sidebar/tabs/SidebarTabTemplate.vue', () => ({
|
||||
default: {
|
||||
name: 'SidebarTabTemplate',
|
||||
template: '<div><slot name="tool-buttons" /><slot name="body" /></div>'
|
||||
}
|
||||
}))
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages },
|
||||
missingWarn: false,
|
||||
fallbackWarn: false
|
||||
})
|
||||
|
||||
function createDownload(
|
||||
overrides: Partial<DownloadStatus> = {}
|
||||
): DownloadStatus {
|
||||
return {
|
||||
download_id: 'd1',
|
||||
model_id: 'loras/x.safetensors',
|
||||
url: 'https://huggingface.co/org/x.safetensors',
|
||||
status: 'active',
|
||||
priority: 0,
|
||||
total_bytes: null,
|
||||
bytes_done: 0,
|
||||
progress: null,
|
||||
speed_bps: null,
|
||||
eta_seconds: null,
|
||||
segments: null,
|
||||
error: null,
|
||||
created_at: 0,
|
||||
updated_at: 0,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function mountTab() {
|
||||
return render(ModelManagerSidebarTab, { global: { plugins: [i18n] } })
|
||||
}
|
||||
|
||||
describe('ModelManagerSidebarTab', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockStore.downloadList = []
|
||||
mockStore.activeDownloads = []
|
||||
mockStore.historyDownloads = []
|
||||
})
|
||||
|
||||
it('hydrates the store on mount', () => {
|
||||
mountTab()
|
||||
expect(mockHydrate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows the empty state when there are no downloads', () => {
|
||||
mountTab()
|
||||
expect(screen.getByText('No downloads yet')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders active downloads under the Active section', () => {
|
||||
const download = createDownload({ download_id: 'd1' })
|
||||
mockStore.activeDownloads = [download]
|
||||
mockStore.downloadList = [download]
|
||||
mountTab()
|
||||
|
||||
expect(screen.getByText('Active')).toBeInTheDocument()
|
||||
expect(screen.getByText('d1')).toBeInTheDocument()
|
||||
expect(screen.queryByText('No downloads yet')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders history downloads under the History section with a clear action', async () => {
|
||||
const download = createDownload({ download_id: 'd2' })
|
||||
mockStore.historyDownloads = [download]
|
||||
mockStore.downloadList = [download]
|
||||
mountTab()
|
||||
|
||||
expect(screen.getByText('History')).toBeInTheDocument()
|
||||
expect(screen.getByText('d2')).toBeInTheDocument()
|
||||
|
||||
await userEvent.click(screen.getByText('Clear history'))
|
||||
expect(mockClearHistory).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('opens the add-model dialog from the toolbar button', async () => {
|
||||
mountTab()
|
||||
|
||||
expect(screen.getByTestId('add-model-dialog')).toHaveTextContent('false')
|
||||
await userEvent.click(screen.getByTitle('Add model'))
|
||||
expect(screen.getByTestId('add-model-dialog')).toHaveTextContent('true')
|
||||
})
|
||||
|
||||
it('opens the add-model dialog from the empty state button', async () => {
|
||||
mountTab()
|
||||
|
||||
await userEvent.click(screen.getByText('Add model'))
|
||||
expect(screen.getByTestId('add-model-dialog')).toHaveTextContent('true')
|
||||
})
|
||||
|
||||
it('opens the credentials dialog with an empty prefill from the toolbar button', async () => {
|
||||
mountTab()
|
||||
|
||||
await userEvent.click(screen.getByTitle('Credentials Manager'))
|
||||
expect(screen.getByTestId('credentials-dialog')).toHaveTextContent('true:')
|
||||
})
|
||||
|
||||
it('opens the credentials dialog prefilled with the row host', async () => {
|
||||
const download = createDownload({
|
||||
download_id: 'd1',
|
||||
model_id: 'loras/x.safetensors'
|
||||
})
|
||||
mockStore.activeDownloads = [download]
|
||||
mockStore.downloadList = [download]
|
||||
mountTab()
|
||||
|
||||
await userEvent.click(screen.getByText('open-credentials'))
|
||||
|
||||
expect(screen.getByTestId('credentials-dialog')).toHaveTextContent(
|
||||
'true:loras/x.safetensors'
|
||||
)
|
||||
})
|
||||
})
|
||||
@@ -1,103 +0,0 @@
|
||||
<template>
|
||||
<SidebarTabTemplate :title="$t('modelManager.title')">
|
||||
<template #tool-buttons>
|
||||
<Button
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('modelManager.credentials.title')"
|
||||
@click="openCredentials('')"
|
||||
>
|
||||
<i class="icon-[lucide--key-round] size-4" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('modelManager.addModel')"
|
||||
@click="addOpen = true"
|
||||
>
|
||||
<i class="icon-[lucide--plus] size-4" />
|
||||
</Button>
|
||||
</template>
|
||||
|
||||
<template #body>
|
||||
<div class="flex flex-col gap-4 p-3">
|
||||
<section v-if="activeDownloads.length" class="flex flex-col gap-2">
|
||||
<h3 class="text-xs font-semibold text-muted-foreground uppercase">
|
||||
{{ $t('modelManager.active') }}
|
||||
</h3>
|
||||
<ModelDownloadRow
|
||||
v-for="download in activeDownloads"
|
||||
:key="download.download_id"
|
||||
:download
|
||||
@open-credentials="openCredentials"
|
||||
/>
|
||||
</section>
|
||||
|
||||
<section v-if="historyDownloads.length" class="flex flex-col gap-2">
|
||||
<div class="flex items-center justify-between">
|
||||
<h3 class="text-xs font-semibold text-muted-foreground uppercase">
|
||||
{{ $t('modelManager.history') }}
|
||||
</h3>
|
||||
<Button variant="link" size="sm" @click="actions.clearHistory()">
|
||||
{{ $t('modelManager.clearHistory') }}
|
||||
</Button>
|
||||
</div>
|
||||
<ModelDownloadRow
|
||||
v-for="download in historyDownloads"
|
||||
:key="download.download_id"
|
||||
:download
|
||||
@open-credentials="openCredentials"
|
||||
/>
|
||||
</section>
|
||||
|
||||
<div
|
||||
v-if="!store.downloadList.length"
|
||||
class="flex flex-col items-center gap-3 py-10 text-center text-muted-foreground"
|
||||
>
|
||||
<i class="icon-[lucide--download] size-8" />
|
||||
<p class="text-sm">{{ $t('modelManager.empty') }}</p>
|
||||
<Button variant="primary" size="sm" @click="addOpen = true">
|
||||
{{ $t('modelManager.addModel') }}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</SidebarTabTemplate>
|
||||
|
||||
<AddModelByUrlDialog v-model:open="addOpen" />
|
||||
<HostCredentialsDialog
|
||||
v-model:open="credentialsOpen"
|
||||
:prefill-host="prefillHost"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { storeToRefs } from 'pinia'
|
||||
import { onMounted, ref } from 'vue'
|
||||
|
||||
import SidebarTabTemplate from '@/components/sidebar/tabs/SidebarTabTemplate.vue'
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
|
||||
import AddModelByUrlDialog from './AddModelByUrlDialog.vue'
|
||||
import HostCredentialsDialog from './HostCredentialsDialog.vue'
|
||||
import ModelDownloadRow from './ModelDownloadRow.vue'
|
||||
import { useModelDownloadActions } from '../composables/useModelDownloadActions'
|
||||
import { useModelDownloadStore } from '../stores/modelDownloadStore'
|
||||
|
||||
const store = useModelDownloadStore()
|
||||
const actions = useModelDownloadActions()
|
||||
const { activeDownloads, historyDownloads } = storeToRefs(store)
|
||||
|
||||
const addOpen = ref(false)
|
||||
const credentialsOpen = ref(false)
|
||||
const prefillHost = ref('')
|
||||
|
||||
function openCredentials(host: string) {
|
||||
prefillHost.value = host
|
||||
credentialsOpen.value = true
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
void store.hydrate()
|
||||
})
|
||||
</script>
|
||||
@@ -1,51 +0,0 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import * as downloadApi from '../api/modelDownloadApi'
|
||||
import { useModelAvailability } from './useModelAvailability'
|
||||
|
||||
vi.mock('../api/modelDownloadApi', () => ({
|
||||
checkAvailability: vi.fn()
|
||||
}))
|
||||
|
||||
describe('useModelAvailability', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks()
|
||||
})
|
||||
|
||||
it('short-circuits without calling the API for an empty map', async () => {
|
||||
const { check, results } = useModelAvailability()
|
||||
|
||||
await check({})
|
||||
|
||||
expect(downloadApi.checkAvailability).not.toHaveBeenCalled()
|
||||
expect(results.value).toEqual({})
|
||||
})
|
||||
|
||||
it('stores the availability map from the response', async () => {
|
||||
vi.mocked(downloadApi.checkAvailability).mockResolvedValue({
|
||||
models: {
|
||||
'loras/x.safetensors': { state: 'available', url_allowed: true }
|
||||
}
|
||||
})
|
||||
const { check, results, isChecking } = useModelAvailability()
|
||||
|
||||
await check({ 'loras/x.safetensors': 'https://h.co/x' })
|
||||
|
||||
expect(results.value['loras/x.safetensors']).toEqual({
|
||||
state: 'available',
|
||||
url_allowed: true
|
||||
})
|
||||
expect(isChecking.value).toBe(false)
|
||||
})
|
||||
|
||||
it('captures errors and resets the checking flag', async () => {
|
||||
vi.mocked(downloadApi.checkAvailability).mockRejectedValue(
|
||||
new Error('boom')
|
||||
)
|
||||
const { check, error, isChecking } = useModelAvailability()
|
||||
|
||||
await expect(check({ 'loras/x.safetensors': 'u' })).rejects.toThrow('boom')
|
||||
expect(error.value).toBeInstanceOf(Error)
|
||||
expect(isChecking.value).toBe(false)
|
||||
})
|
||||
})
|
||||
@@ -1,41 +0,0 @@
|
||||
import { ref, shallowRef } from 'vue'
|
||||
|
||||
import { checkAvailability } from '../api/modelDownloadApi'
|
||||
import type { AvailabilityEntry } from '../types'
|
||||
|
||||
/**
|
||||
* Bulk-checks whether the models declared by a workflow are present,
|
||||
* downloading, or missing. Pass a `model_id -> source URL` map and badge
|
||||
* each model from the returned entries.
|
||||
*/
|
||||
export function useModelAvailability() {
|
||||
const results = shallowRef<Record<string, AvailabilityEntry>>({})
|
||||
const isChecking = ref(false)
|
||||
const error = ref<unknown>(null)
|
||||
|
||||
async function check(models: Record<string, string>) {
|
||||
if (Object.keys(models).length === 0) {
|
||||
results.value = {}
|
||||
return results.value
|
||||
}
|
||||
isChecking.value = true
|
||||
error.value = null
|
||||
try {
|
||||
const response = await checkAvailability(models)
|
||||
results.value = response.models
|
||||
return response.models
|
||||
} catch (e) {
|
||||
error.value = e
|
||||
throw e
|
||||
} finally {
|
||||
isChecking.value = false
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
results,
|
||||
isChecking,
|
||||
error,
|
||||
check
|
||||
}
|
||||
}
|
||||
@@ -1,168 +0,0 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
|
||||
|
||||
import type { DownloadStatus } from '../types'
|
||||
import { DownloadApiError } from '../types'
|
||||
import { useModelDownloadActions } from './useModelDownloadActions'
|
||||
|
||||
const mockStore = {
|
||||
pause: vi.fn(),
|
||||
resume: vi.fn(),
|
||||
cancel: vi.fn(),
|
||||
setPriority: vi.fn(),
|
||||
hydrate: vi.fn()
|
||||
}
|
||||
const mockToastAdd = vi.fn()
|
||||
const mockCloseDialog = vi.fn()
|
||||
|
||||
vi.mock('vue-i18n', () => ({
|
||||
useI18n: () => ({ t: (key: string) => key })
|
||||
}))
|
||||
|
||||
vi.mock('../stores/modelDownloadStore', () => ({
|
||||
useModelDownloadStore: () => mockStore
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: () => ({ add: mockToastAdd })
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/dialogStore', () => ({
|
||||
useDialogStore: () => ({ closeDialog: mockCloseDialog })
|
||||
}))
|
||||
|
||||
vi.mock('@/components/dialog/confirm/confirmDialog')
|
||||
|
||||
const mockShowConfirmDialog = vi.mocked(showConfirmDialog)
|
||||
|
||||
interface CapturedConfirmOptions {
|
||||
footerProps?: {
|
||||
confirmVariant?: string
|
||||
onConfirm?: () => void | Promise<void>
|
||||
onCancel?: () => void
|
||||
}
|
||||
}
|
||||
|
||||
function capturedOptions(): CapturedConfirmOptions {
|
||||
return mockShowConfirmDialog.mock
|
||||
.calls[0][0] as unknown as CapturedConfirmOptions
|
||||
}
|
||||
|
||||
function createDownload(
|
||||
overrides: Partial<DownloadStatus> = {}
|
||||
): DownloadStatus {
|
||||
return {
|
||||
download_id: 'd1',
|
||||
model_id: 'loras/x.safetensors',
|
||||
url: 'https://huggingface.co/x.safetensors',
|
||||
status: 'active',
|
||||
priority: 0,
|
||||
total_bytes: null,
|
||||
bytes_done: 0,
|
||||
progress: null,
|
||||
speed_bps: null,
|
||||
eta_seconds: null,
|
||||
segments: null,
|
||||
error: null,
|
||||
created_at: 0,
|
||||
updated_at: 0,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
describe('useModelDownloadActions', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockShowConfirmDialog.mockReturnValue(
|
||||
{} as ReturnType<typeof showConfirmDialog>
|
||||
)
|
||||
})
|
||||
|
||||
it('pauses a download by id', async () => {
|
||||
mockStore.pause.mockResolvedValue(undefined)
|
||||
const { pause } = useModelDownloadActions()
|
||||
|
||||
await pause(createDownload({ download_id: 'd1' }))
|
||||
|
||||
expect(mockStore.pause).toHaveBeenCalledWith('d1')
|
||||
})
|
||||
|
||||
it('resumes a download by id', async () => {
|
||||
mockStore.resume.mockResolvedValue(undefined)
|
||||
const { resume } = useModelDownloadActions()
|
||||
|
||||
await resume(createDownload({ download_id: 'd1' }))
|
||||
|
||||
expect(mockStore.resume).toHaveBeenCalledWith('d1')
|
||||
})
|
||||
|
||||
it('raises priority relative to the current value', async () => {
|
||||
mockStore.setPriority.mockResolvedValue(undefined)
|
||||
const { raisePriority } = useModelDownloadActions()
|
||||
|
||||
await raisePriority(createDownload({ download_id: 'd1', priority: 2 }), 1)
|
||||
|
||||
expect(mockStore.setPriority).toHaveBeenCalledWith('d1', 3)
|
||||
})
|
||||
|
||||
it('shows an error toast and re-hydrates the store when an action fails', async () => {
|
||||
mockStore.pause.mockRejectedValue(new Error('offline'))
|
||||
mockStore.hydrate.mockResolvedValue(undefined)
|
||||
const { pause } = useModelDownloadActions()
|
||||
|
||||
await pause(createDownload())
|
||||
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error', detail: 'offline' })
|
||||
)
|
||||
expect(mockStore.hydrate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses the DownloadApiError message in the toast', async () => {
|
||||
mockStore.pause.mockRejectedValue(
|
||||
new DownloadApiError('nope', 'URL_NOT_ALLOWED', 400)
|
||||
)
|
||||
mockStore.hydrate.mockResolvedValue(undefined)
|
||||
const { pause } = useModelDownloadActions()
|
||||
|
||||
await pause(createDownload())
|
||||
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: 'nope' })
|
||||
)
|
||||
})
|
||||
|
||||
it('swallows a failed re-hydrate after an action error', async () => {
|
||||
mockStore.pause.mockRejectedValue(new Error('offline'))
|
||||
mockStore.hydrate.mockRejectedValue(new Error('still offline'))
|
||||
const { pause } = useModelDownloadActions()
|
||||
|
||||
await expect(pause(createDownload())).resolves.toBeUndefined()
|
||||
})
|
||||
|
||||
describe('cancel', () => {
|
||||
it('opens a destructive confirm dialog and only cancels on confirm', async () => {
|
||||
mockStore.cancel.mockResolvedValue(undefined)
|
||||
const { cancel } = useModelDownloadActions()
|
||||
|
||||
cancel(createDownload({ download_id: 'd1' }))
|
||||
expect(capturedOptions().footerProps?.confirmVariant).toBe('destructive')
|
||||
|
||||
await capturedOptions().footerProps?.onConfirm?.()
|
||||
|
||||
expect(mockStore.cancel).toHaveBeenCalledWith('d1')
|
||||
expect(mockCloseDialog).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not cancel when the confirm dialog is dismissed', () => {
|
||||
const { cancel } = useModelDownloadActions()
|
||||
|
||||
cancel(createDownload({ download_id: 'd1' }))
|
||||
capturedOptions().footerProps?.onCancel?.()
|
||||
|
||||
expect(mockStore.cancel).not.toHaveBeenCalled()
|
||||
expect(mockCloseDialog).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,97 +0,0 @@
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
import { useDialogStore } from '@/stores/dialogStore'
|
||||
|
||||
import { useModelDownloadStore } from '../stores/modelDownloadStore'
|
||||
import type { DownloadStatus } from '../types'
|
||||
import { DownloadApiError } from '../types'
|
||||
|
||||
/**
|
||||
* Wraps download store mutations with user feedback: optimistic-friendly error
|
||||
* toasts and a confirmation prompt for the destructive cancel action (which
|
||||
* deletes the partial file).
|
||||
*/
|
||||
export function useModelDownloadActions() {
|
||||
const store = useModelDownloadStore()
|
||||
const dialogStore = useDialogStore()
|
||||
const { t } = useI18n()
|
||||
|
||||
function toastError(error: unknown) {
|
||||
const detail =
|
||||
error instanceof DownloadApiError || error instanceof Error
|
||||
? error.message
|
||||
: String(error)
|
||||
useToastStore().add({
|
||||
severity: 'error',
|
||||
summary: t('modelManager.actionFailed'),
|
||||
detail,
|
||||
life: 5000
|
||||
})
|
||||
}
|
||||
|
||||
async function run(action: () => Promise<void>) {
|
||||
try {
|
||||
await action()
|
||||
} catch (error) {
|
||||
toastError(error)
|
||||
// The store applies optimistic status/priority patches before the API
|
||||
// call; re-fetch the authoritative state so a failed mutation doesn't
|
||||
// leave the row stuck in the wrong local state.
|
||||
try {
|
||||
await store.hydrate()
|
||||
} catch {
|
||||
// Server unreachable; the next poll will reconcile.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const pause = (download: DownloadStatus) =>
|
||||
run(() => store.pause(download.download_id))
|
||||
|
||||
const resume = (download: DownloadStatus) =>
|
||||
run(() => store.resume(download.download_id))
|
||||
|
||||
const raisePriority = (download: DownloadStatus, delta: number) =>
|
||||
run(() =>
|
||||
store.setPriority(download.download_id, download.priority + delta)
|
||||
)
|
||||
|
||||
const remove = (download: DownloadStatus) =>
|
||||
run(() => store.remove(download.download_id))
|
||||
|
||||
const clearHistory = () => run(() => store.clearHistory())
|
||||
|
||||
function cancel(download: DownloadStatus) {
|
||||
const dialog = showConfirmDialog({
|
||||
headerProps: { title: t('modelManager.cancelConfirmTitle') },
|
||||
props: {
|
||||
promptText: t(
|
||||
'modelManager.cancelConfirmMessage',
|
||||
{ name: download.model_id },
|
||||
{ escapeParameter: false }
|
||||
)
|
||||
},
|
||||
footerProps: {
|
||||
confirmText: t('modelManager.cancelConfirm'),
|
||||
confirmVariant: 'destructive' as const,
|
||||
onCancel: () => dialogStore.closeDialog(dialog),
|
||||
onConfirm: async () => {
|
||||
dialogStore.closeDialog(dialog)
|
||||
await run(() => store.cancel(download.download_id))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
pause,
|
||||
resume,
|
||||
cancel,
|
||||
raisePriority,
|
||||
remove,
|
||||
clearHistory,
|
||||
toastError
|
||||
}
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
import { ref } from 'vue'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { CompletedDownload } from '../stores/modelDownloadStore'
|
||||
import { useModelDownloadEffects } from './useModelDownloadEffects'
|
||||
|
||||
const mockLastCompletedDownload = ref<CompletedDownload | null>(null)
|
||||
const mockRefreshModelFolder = vi.fn()
|
||||
const mockRefreshMissingModels = vi.fn()
|
||||
|
||||
vi.mock('../stores/modelDownloadStore', () => ({
|
||||
useModelDownloadStore: () => ({
|
||||
get lastCompletedDownload() {
|
||||
return mockLastCompletedDownload.value
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/modelStore', () => ({
|
||||
useModelStore: () => ({ refreshModelFolder: mockRefreshModelFolder })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/missingModel/missingModelStore', () => ({
|
||||
useMissingModelStore: () => ({
|
||||
refreshMissingModels: mockRefreshMissingModels
|
||||
})
|
||||
}))
|
||||
|
||||
describe('useModelDownloadEffects', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockLastCompletedDownload.value = null
|
||||
mockRefreshModelFolder.mockResolvedValue(undefined)
|
||||
})
|
||||
|
||||
it('does nothing until a download completes', () => {
|
||||
useModelDownloadEffects()
|
||||
|
||||
expect(mockRefreshModelFolder).not.toHaveBeenCalled()
|
||||
expect(mockRefreshMissingModels).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('refreshes the model folder and re-scans missing models on completion', async () => {
|
||||
useModelDownloadEffects()
|
||||
|
||||
mockLastCompletedDownload.value = {
|
||||
downloadId: 'd1',
|
||||
modelId: 'loras/x.safetensors',
|
||||
directory: 'loras',
|
||||
timestamp: 1
|
||||
}
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockRefreshModelFolder).toHaveBeenCalledWith('loras')
|
||||
})
|
||||
expect(mockRefreshMissingModels).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('skips the folder refresh when the directory is unknown', async () => {
|
||||
useModelDownloadEffects()
|
||||
|
||||
mockLastCompletedDownload.value = {
|
||||
downloadId: 'd1',
|
||||
modelId: 'unknown.safetensors',
|
||||
directory: '',
|
||||
timestamp: 1
|
||||
}
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockRefreshMissingModels).toHaveBeenCalled()
|
||||
})
|
||||
expect(mockRefreshModelFolder).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('still re-scans missing models when the folder refresh fails', async () => {
|
||||
const consoleWarn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
mockRefreshModelFolder.mockRejectedValue(new Error('boom'))
|
||||
useModelDownloadEffects()
|
||||
|
||||
mockLastCompletedDownload.value = {
|
||||
downloadId: 'd1',
|
||||
modelId: 'loras/x.safetensors',
|
||||
directory: 'loras',
|
||||
timestamp: 1
|
||||
}
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockRefreshMissingModels).toHaveBeenCalled()
|
||||
})
|
||||
expect(consoleWarn).toHaveBeenCalled()
|
||||
consoleWarn.mockRestore()
|
||||
})
|
||||
})
|
||||
@@ -1,45 +0,0 @@
|
||||
import { watch } from 'vue'
|
||||
|
||||
import { useMissingModelStore } from '@/platform/missingModel/missingModelStore'
|
||||
import { useModelStore } from '@/stores/modelStore'
|
||||
|
||||
import { useModelDownloadStore } from '../stores/modelDownloadStore'
|
||||
|
||||
/**
|
||||
* Side effects that run when a server-side download finishes: refresh the
|
||||
* affected model folder so the file is recognized, and re-scan missing models
|
||||
* so node errors for the now-present model clear automatically.
|
||||
*
|
||||
* Mounted once at the app root so it runs regardless of whether the Model
|
||||
* Manager panel is open.
|
||||
*/
|
||||
export function useModelDownloadEffects() {
|
||||
const store = useModelDownloadStore()
|
||||
|
||||
watch(
|
||||
() => store.lastCompletedDownload,
|
||||
async (completed) => {
|
||||
if (!completed) return
|
||||
|
||||
if (completed.directory) {
|
||||
try {
|
||||
await useModelStore().refreshModelFolder(completed.directory)
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
'[ModelManager] Failed to refresh model folder after download',
|
||||
error
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
await useMissingModelStore().refreshMissingModels()
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
'[ModelManager] Failed to re-scan missing models after download',
|
||||
error
|
||||
)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const mockActiveDownloadCount = { value: 0 }
|
||||
|
||||
vi.mock('../stores/modelDownloadStore', () => ({
|
||||
useModelDownloadStore: () => ({
|
||||
get activeDownloadCount() {
|
||||
return mockActiveDownloadCount.value
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('../components/ModelManagerSidebarTab.vue', () => ({
|
||||
default: { name: 'ModelManagerSidebarTab' }
|
||||
}))
|
||||
|
||||
import { useModelManagerSidebarTab } from './useModelManagerSidebarTab'
|
||||
|
||||
describe('useModelManagerSidebarTab', () => {
|
||||
it('returns the expected sidebar tab extension shape', () => {
|
||||
const tab = useModelManagerSidebarTab()
|
||||
|
||||
expect(tab.id).toBe('model-manager')
|
||||
expect(tab.type).toBe('vue')
|
||||
expect(tab.title).toBe('modelManager.title')
|
||||
expect(tab.tooltip).toBe('modelManager.title')
|
||||
expect(tab.label).toBe('modelManager.title')
|
||||
})
|
||||
|
||||
it('shows no badge when there are no active downloads', () => {
|
||||
mockActiveDownloadCount.value = 0
|
||||
const tab = useModelManagerSidebarTab()
|
||||
|
||||
expect(typeof tab.iconBadge).toBe('function')
|
||||
expect((tab.iconBadge as () => string | null)()).toBeNull()
|
||||
})
|
||||
|
||||
it('shows the active download count as a badge', () => {
|
||||
mockActiveDownloadCount.value = 3
|
||||
const tab = useModelManagerSidebarTab()
|
||||
|
||||
expect((tab.iconBadge as () => string | null)()).toBe('3')
|
||||
})
|
||||
})
|
||||
@@ -1,22 +0,0 @@
|
||||
import { markRaw } from 'vue'
|
||||
|
||||
import type { SidebarTabExtension } from '@/types/extensionTypes'
|
||||
|
||||
import ModelManagerSidebarTab from '../components/ModelManagerSidebarTab.vue'
|
||||
import { useModelDownloadStore } from '../stores/modelDownloadStore'
|
||||
|
||||
export function useModelManagerSidebarTab(): SidebarTabExtension {
|
||||
return {
|
||||
id: 'model-manager',
|
||||
icon: 'icon-[lucide--download]',
|
||||
title: 'modelManager.title',
|
||||
tooltip: 'modelManager.title',
|
||||
label: 'modelManager.title',
|
||||
component: markRaw(ModelManagerSidebarTab),
|
||||
type: 'vue',
|
||||
iconBadge: () => {
|
||||
const count = useModelDownloadStore().activeDownloadCount
|
||||
return count > 0 ? count.toString() : null
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,191 +0,0 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import * as downloadApi from '../api/modelDownloadApi'
|
||||
import type { HostCredentialUpsert, HostCredentialView } from '../types'
|
||||
import { useHostCredentialsStore } from './hostCredentialsStore'
|
||||
|
||||
vi.mock('../api/modelDownloadApi', () => ({
|
||||
listCredentials: vi.fn(),
|
||||
upsertCredential: vi.fn(),
|
||||
deleteCredential: vi.fn()
|
||||
}))
|
||||
|
||||
function createCredential(
|
||||
overrides: Partial<HostCredentialView> = {}
|
||||
): HostCredentialView {
|
||||
return {
|
||||
id: 'c1',
|
||||
host: 'huggingface.co',
|
||||
auth_scheme: 'bearer',
|
||||
header_name: null,
|
||||
query_param: null,
|
||||
label: null,
|
||||
match_subdomains: false,
|
||||
enabled: true,
|
||||
secret_last4: '1234',
|
||||
created_at: 0,
|
||||
updated_at: 0,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
describe('useHostCredentialsStore', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
vi.resetAllMocks()
|
||||
})
|
||||
|
||||
describe('fetchCredentials', () => {
|
||||
it('toggles isLoading and stores the result', async () => {
|
||||
const credential = createCredential()
|
||||
vi.mocked(downloadApi.listCredentials).mockResolvedValue([credential])
|
||||
const store = useHostCredentialsStore()
|
||||
|
||||
const promise = store.fetchCredentials()
|
||||
expect(store.isLoading).toBe(true)
|
||||
await promise
|
||||
|
||||
expect(store.isLoading).toBe(false)
|
||||
expect(store.credentials).toEqual([credential])
|
||||
})
|
||||
|
||||
it('clears isLoading even when the request fails', async () => {
|
||||
vi.mocked(downloadApi.listCredentials).mockRejectedValue(
|
||||
new Error('boom')
|
||||
)
|
||||
const store = useHostCredentialsStore()
|
||||
|
||||
await expect(store.fetchCredentials()).rejects.toThrow('boom')
|
||||
expect(store.isLoading).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('upsert', () => {
|
||||
it('normalizes the host and appends a new credential', async () => {
|
||||
const created = createCredential({ id: 'new', host: 'huggingface.co' })
|
||||
vi.mocked(downloadApi.upsertCredential).mockResolvedValue(created)
|
||||
const store = useHostCredentialsStore()
|
||||
const body: HostCredentialUpsert = {
|
||||
host: ' HuggingFace.co ',
|
||||
secret: 's3cret'
|
||||
}
|
||||
|
||||
const result = await store.upsert(body)
|
||||
|
||||
expect(downloadApi.upsertCredential).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ host: 'huggingface.co', secret: 's3cret' })
|
||||
)
|
||||
expect(result).toEqual(created)
|
||||
expect(store.credentials).toEqual([created])
|
||||
})
|
||||
|
||||
it('replaces an existing credential by id', async () => {
|
||||
const original = createCredential({ id: 'c1', label: 'Old' })
|
||||
const updated = createCredential({ id: 'c1', label: 'New' })
|
||||
vi.mocked(downloadApi.upsertCredential).mockResolvedValue(updated)
|
||||
const store = useHostCredentialsStore()
|
||||
store.credentials.push(original)
|
||||
|
||||
await store.upsert({ host: 'huggingface.co', secret: 's' })
|
||||
|
||||
expect(store.credentials).toEqual([updated])
|
||||
})
|
||||
})
|
||||
|
||||
describe('remove', () => {
|
||||
it('deletes and filters out the credential by id', async () => {
|
||||
vi.mocked(downloadApi.deleteCredential).mockResolvedValue(undefined)
|
||||
const store = useHostCredentialsStore()
|
||||
store.credentials.push(
|
||||
createCredential({ id: 'c1' }),
|
||||
createCredential({ id: 'c2' })
|
||||
)
|
||||
|
||||
await store.remove('c1')
|
||||
|
||||
expect(downloadApi.deleteCredential).toHaveBeenCalledWith('c1')
|
||||
expect(store.credentials.map((c) => c.id)).toEqual(['c2'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('enabledCredentialForHost', () => {
|
||||
it('matches an exact enabled host case-insensitively', () => {
|
||||
const store = useHostCredentialsStore()
|
||||
store.credentials.push(
|
||||
createCredential({ host: 'HuggingFace.co', enabled: true })
|
||||
)
|
||||
|
||||
expect(store.enabledCredentialForHost('huggingface.co')?.host).toBe(
|
||||
'HuggingFace.co'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns undefined for a disabled exact match', () => {
|
||||
const store = useHostCredentialsStore()
|
||||
store.credentials.push(
|
||||
createCredential({ host: 'huggingface.co', enabled: false })
|
||||
)
|
||||
|
||||
expect(store.enabledCredentialForHost('huggingface.co')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('falls back to a subdomain match only when match_subdomains is set', () => {
|
||||
const store = useHostCredentialsStore()
|
||||
store.credentials.push(
|
||||
createCredential({
|
||||
host: 'huggingface.co',
|
||||
match_subdomains: true,
|
||||
enabled: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(store.enabledCredentialForHost('cdn.huggingface.co')?.host).toBe(
|
||||
'huggingface.co'
|
||||
)
|
||||
})
|
||||
|
||||
it('skips a disabled subdomain credential and returns a later enabled match', () => {
|
||||
const store = useHostCredentialsStore()
|
||||
store.credentials.push(
|
||||
createCredential({
|
||||
id: 'disabled',
|
||||
host: 'huggingface.co',
|
||||
match_subdomains: true,
|
||||
enabled: false
|
||||
}),
|
||||
createCredential({
|
||||
id: 'enabled',
|
||||
host: 'huggingface.co',
|
||||
match_subdomains: true,
|
||||
enabled: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(store.enabledCredentialForHost('cdn.huggingface.co')?.id).toBe(
|
||||
'enabled'
|
||||
)
|
||||
})
|
||||
|
||||
it('does not subdomain-match when match_subdomains is false', () => {
|
||||
const store = useHostCredentialsStore()
|
||||
store.credentials.push(
|
||||
createCredential({
|
||||
host: 'huggingface.co',
|
||||
match_subdomains: false,
|
||||
enabled: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(
|
||||
store.enabledCredentialForHost('cdn.huggingface.co')
|
||||
).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined when no host matches', () => {
|
||||
const store = useHostCredentialsStore()
|
||||
expect(store.enabledCredentialForHost('example.com')).toBeUndefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,75 +0,0 @@
|
||||
import { defineStore } from 'pinia'
|
||||
import { computed, ref } from 'vue'
|
||||
|
||||
import {
|
||||
deleteCredential,
|
||||
listCredentials,
|
||||
upsertCredential
|
||||
} from '../api/modelDownloadApi'
|
||||
import type { HostCredentialUpsert, HostCredentialView } from '../types'
|
||||
|
||||
function normalizeHost(host: string): string {
|
||||
return host.trim().toLowerCase()
|
||||
}
|
||||
|
||||
export const useHostCredentialsStore = defineStore('hostCredentials', () => {
|
||||
const credentials = ref<HostCredentialView[]>([])
|
||||
const isLoading = ref(false)
|
||||
|
||||
const credentialsByHost = computed(
|
||||
() => new Map(credentials.value.map((c) => [normalizeHost(c.host), c]))
|
||||
)
|
||||
|
||||
async function fetchCredentials() {
|
||||
isLoading.value = true
|
||||
try {
|
||||
credentials.value = await listCredentials()
|
||||
} finally {
|
||||
isLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function upsert(
|
||||
body: HostCredentialUpsert
|
||||
): Promise<HostCredentialView> {
|
||||
const view = await upsertCredential({
|
||||
...body,
|
||||
host: normalizeHost(body.host)
|
||||
})
|
||||
const index = credentials.value.findIndex((c) => c.id === view.id)
|
||||
credentials.value =
|
||||
index === -1
|
||||
? [...credentials.value, view]
|
||||
: credentials.value.map((c) => (c.id === view.id ? view : c))
|
||||
return view
|
||||
}
|
||||
|
||||
async function remove(id: string) {
|
||||
await deleteCredential(id)
|
||||
credentials.value = credentials.value.filter((c) => c.id !== id)
|
||||
}
|
||||
|
||||
function enabledCredentialForHost(
|
||||
host: string
|
||||
): HostCredentialView | undefined {
|
||||
const normalized = normalizeHost(host)
|
||||
const exact = credentialsByHost.value.get(normalized)
|
||||
if (exact) return exact.enabled ? exact : undefined
|
||||
|
||||
return credentials.value.find(
|
||||
(c) =>
|
||||
c.enabled &&
|
||||
c.match_subdomains &&
|
||||
normalized.endsWith(`.${normalizeHost(c.host)}`)
|
||||
)
|
||||
}
|
||||
|
||||
return {
|
||||
credentials,
|
||||
isLoading,
|
||||
fetchCredentials,
|
||||
upsert,
|
||||
remove,
|
||||
enabledCredentialForHost
|
||||
}
|
||||
})
|
||||
@@ -1,388 +0,0 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const flushPromises = () => vi.advanceTimersByTimeAsync(0)
|
||||
|
||||
import * as downloadApi from '../api/modelDownloadApi'
|
||||
import type { DownloadState, DownloadStatus } from '../types'
|
||||
import {
|
||||
downloadProgressFraction,
|
||||
useModelDownloadStore
|
||||
} from './modelDownloadStore'
|
||||
|
||||
type ProgressHandler = (e: CustomEvent<DownloadStatus>) => void
|
||||
|
||||
const eventHandler = vi.hoisted(() => {
|
||||
const state: { current: ProgressHandler | null } = { current: null }
|
||||
return state
|
||||
})
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
addEventListener: vi.fn((_event: string, handler: ProgressHandler) => {
|
||||
eventHandler.current = handler
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../api/modelDownloadApi', () => ({
|
||||
enqueueDownload: vi.fn(),
|
||||
listDownloads: vi.fn(),
|
||||
pauseDownload: vi.fn(),
|
||||
resumeDownload: vi.fn(),
|
||||
cancelDownload: vi.fn(),
|
||||
setDownloadPriority: vi.fn(),
|
||||
deleteDownload: vi.fn(),
|
||||
clearDownloads: vi.fn()
|
||||
}))
|
||||
|
||||
function createStatus(overrides: Partial<DownloadStatus> = {}): DownloadStatus {
|
||||
return {
|
||||
download_id: 'd1',
|
||||
model_id: 'loras/x.safetensors',
|
||||
url: 'https://huggingface.co/x.safetensors',
|
||||
status: 'active',
|
||||
priority: 0,
|
||||
total_bytes: 1000,
|
||||
bytes_done: 250,
|
||||
progress: 0.25,
|
||||
speed_bps: 100,
|
||||
eta_seconds: 10,
|
||||
segments: null,
|
||||
error: null,
|
||||
created_at: 1,
|
||||
updated_at: 2,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function dispatch(status: DownloadStatus) {
|
||||
if (!eventHandler.current) throw new Error('handler not registered')
|
||||
eventHandler.current(new CustomEvent('download_progress', { detail: status }))
|
||||
}
|
||||
|
||||
describe('useModelDownloadStore', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
vi.useFakeTimers({ shouldAdvanceTime: false })
|
||||
vi.resetAllMocks()
|
||||
eventHandler.current = null
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('upserts rows from download_progress events', () => {
|
||||
const store = useModelDownloadStore()
|
||||
|
||||
dispatch(createStatus({ download_id: 'd1', bytes_done: 100 }))
|
||||
dispatch(createStatus({ download_id: 'd1', bytes_done: 900 }))
|
||||
dispatch(createStatus({ download_id: 'd2', status: 'queued' }))
|
||||
|
||||
expect(store.downloadList).toHaveLength(2)
|
||||
expect(
|
||||
store.downloadList.find((d) => d.download_id === 'd1')?.bytes_done
|
||||
).toBe(900)
|
||||
})
|
||||
|
||||
it('splits active and terminal downloads', () => {
|
||||
const store = useModelDownloadStore()
|
||||
const states: DownloadState[] = [
|
||||
'queued',
|
||||
'active',
|
||||
'paused',
|
||||
'verifying',
|
||||
'completed',
|
||||
'failed',
|
||||
'cancelled'
|
||||
]
|
||||
states.forEach((status, idx) =>
|
||||
dispatch(createStatus({ download_id: `d${idx}`, status }))
|
||||
)
|
||||
|
||||
expect(store.activeDownloads.map((d) => d.status)).toEqual([
|
||||
'queued',
|
||||
'active',
|
||||
'paused',
|
||||
'verifying'
|
||||
])
|
||||
expect(store.historyDownloads.map((d) => d.status)).toEqual([
|
||||
'completed',
|
||||
'failed',
|
||||
'cancelled'
|
||||
])
|
||||
expect(store.activeDownloadCount).toBe(4)
|
||||
})
|
||||
|
||||
it('inserts an optimistic queued row on enqueue', async () => {
|
||||
vi.mocked(downloadApi.enqueueDownload).mockResolvedValue({
|
||||
download_id: 'new-id',
|
||||
accepted: true
|
||||
})
|
||||
const store = useModelDownloadStore()
|
||||
|
||||
const result = await store.enqueue({
|
||||
url: 'https://huggingface.co/x.safetensors',
|
||||
model_id: 'loras/x.safetensors'
|
||||
})
|
||||
|
||||
expect(result.download_id).toBe('new-id')
|
||||
const row = store.downloadList.find((d) => d.download_id === 'new-id')
|
||||
expect(row?.status).toBe('queued')
|
||||
expect(row?.model_id).toBe('loras/x.safetensors')
|
||||
})
|
||||
|
||||
it('optimistically updates status when pausing', async () => {
|
||||
vi.mocked(downloadApi.pauseDownload).mockResolvedValue()
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
|
||||
await store.pause('d1')
|
||||
|
||||
expect(store.downloadList.find((d) => d.download_id === 'd1')?.status).toBe(
|
||||
'paused'
|
||||
)
|
||||
expect(downloadApi.pauseDownload).toHaveBeenCalledWith('d1')
|
||||
})
|
||||
|
||||
it('optimistically updates status when resuming', async () => {
|
||||
vi.mocked(downloadApi.resumeDownload).mockResolvedValue()
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'paused' }))
|
||||
|
||||
await store.resume('d1')
|
||||
|
||||
expect(store.downloadList.find((d) => d.download_id === 'd1')?.status).toBe(
|
||||
'queued'
|
||||
)
|
||||
expect(downloadApi.resumeDownload).toHaveBeenCalledWith('d1')
|
||||
})
|
||||
|
||||
it('marks a download cancelled after the API call resolves', async () => {
|
||||
vi.mocked(downloadApi.cancelDownload).mockResolvedValue()
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
|
||||
await store.cancel('d1')
|
||||
|
||||
expect(store.downloadList.find((d) => d.download_id === 'd1')?.status).toBe(
|
||||
'cancelled'
|
||||
)
|
||||
expect(downloadApi.cancelDownload).toHaveBeenCalledWith('d1')
|
||||
})
|
||||
|
||||
it('optimistically updates priority and calls the API', async () => {
|
||||
vi.mocked(downloadApi.setDownloadPriority).mockResolvedValue()
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'queued', priority: 0 }))
|
||||
|
||||
await store.setPriority('d1', 5)
|
||||
|
||||
expect(
|
||||
store.downloadList.find((d) => d.download_id === 'd1')?.priority
|
||||
).toBe(5)
|
||||
expect(downloadApi.setDownloadPriority).toHaveBeenCalledWith('d1', 5)
|
||||
})
|
||||
|
||||
it('is a no-op when patching priority for an unknown id', async () => {
|
||||
vi.mocked(downloadApi.setDownloadPriority).mockResolvedValue()
|
||||
const store = useModelDownloadStore()
|
||||
|
||||
await store.setPriority('missing', 5)
|
||||
|
||||
expect(store.downloadList).toHaveLength(0)
|
||||
expect(downloadApi.setDownloadPriority).toHaveBeenCalledWith('missing', 5)
|
||||
})
|
||||
|
||||
it('finds a download by model id', () => {
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(
|
||||
createStatus({ download_id: 'd1', model_id: 'loras/x.safetensors' })
|
||||
)
|
||||
|
||||
expect(store.findByModelId('loras/x.safetensors')?.download_id).toBe('d1')
|
||||
expect(store.findByModelId('loras/missing.safetensors')).toBeUndefined()
|
||||
})
|
||||
|
||||
describe('hydrate', () => {
|
||||
it('replaces the download map with the fetched list', async () => {
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'stale', status: 'active' }))
|
||||
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
|
||||
createStatus({ download_id: 'fresh', status: 'active' })
|
||||
])
|
||||
|
||||
await store.hydrate()
|
||||
|
||||
expect(store.downloadList.map((d) => d.download_id)).toEqual(['fresh'])
|
||||
})
|
||||
|
||||
it('records a completion when a refreshed row transitions to completed', async () => {
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
|
||||
createStatus({
|
||||
download_id: 'd1',
|
||||
model_id: 'loras/x.safetensors',
|
||||
status: 'completed'
|
||||
})
|
||||
])
|
||||
|
||||
await store.hydrate()
|
||||
|
||||
expect(store.lastCompletedDownload).toMatchObject({
|
||||
downloadId: 'd1',
|
||||
modelId: 'loras/x.safetensors',
|
||||
directory: 'loras'
|
||||
})
|
||||
})
|
||||
|
||||
it('does not re-record a completion for a row that was already completed', async () => {
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'completed' }))
|
||||
const firstTimestamp = store.lastCompletedDownload?.timestamp
|
||||
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
|
||||
createStatus({ download_id: 'd1', status: 'completed' })
|
||||
])
|
||||
|
||||
await store.hydrate()
|
||||
|
||||
expect(store.lastCompletedDownload?.timestamp).toBe(firstTimestamp)
|
||||
})
|
||||
})
|
||||
|
||||
it('records the last completed download once on the completing transition', () => {
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
expect(store.lastCompletedDownload).toBeNull()
|
||||
|
||||
dispatch(
|
||||
createStatus({
|
||||
download_id: 'd1',
|
||||
model_id: 'loras/x.safetensors',
|
||||
status: 'completed'
|
||||
})
|
||||
)
|
||||
|
||||
expect(store.lastCompletedDownload).toMatchObject({
|
||||
downloadId: 'd1',
|
||||
modelId: 'loras/x.safetensors',
|
||||
directory: 'loras'
|
||||
})
|
||||
|
||||
const firstTimestamp = store.lastCompletedDownload?.timestamp
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'completed' }))
|
||||
expect(store.lastCompletedDownload?.timestamp).toBe(firstTimestamp)
|
||||
})
|
||||
|
||||
it('refetches when a download fails without an error message', async () => {
|
||||
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
|
||||
createStatus({ download_id: 'd1', status: 'failed', error: 'gated' })
|
||||
])
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'failed', error: null }))
|
||||
await flushPromises()
|
||||
|
||||
expect(downloadApi.listDownloads).toHaveBeenCalledOnce()
|
||||
expect(store.downloadList.find((d) => d.download_id === 'd1')?.error).toBe(
|
||||
'gated'
|
||||
)
|
||||
})
|
||||
|
||||
it('does not refetch when the failure event already carries an error', async () => {
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
|
||||
dispatch(
|
||||
createStatus({ download_id: 'd1', status: 'failed', error: 'disk full' })
|
||||
)
|
||||
await flushPromises()
|
||||
|
||||
expect(downloadApi.listDownloads).not.toHaveBeenCalled()
|
||||
expect(store.downloadList.find((d) => d.download_id === 'd1')?.error).toBe(
|
||||
'disk full'
|
||||
)
|
||||
})
|
||||
|
||||
it('deletes a row through the backend so it stays gone', async () => {
|
||||
vi.mocked(downloadApi.deleteDownload).mockResolvedValue()
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
|
||||
|
||||
await store.remove('d2')
|
||||
|
||||
expect(downloadApi.deleteDownload).toHaveBeenCalledWith('d2')
|
||||
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d1'])
|
||||
})
|
||||
|
||||
it('clears every history row in one backend call, leaving active downloads', async () => {
|
||||
vi.mocked(downloadApi.clearDownloads).mockResolvedValue(2)
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
|
||||
dispatch(createStatus({ download_id: 'd3', status: 'failed' }))
|
||||
|
||||
await store.clearHistory()
|
||||
|
||||
expect(downloadApi.clearDownloads).toHaveBeenCalledOnce()
|
||||
expect(downloadApi.deleteDownload).not.toHaveBeenCalled()
|
||||
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d1'])
|
||||
})
|
||||
|
||||
it('keeps history rows locally when the bulk clear fails', async () => {
|
||||
vi.mocked(downloadApi.clearDownloads).mockRejectedValue(new Error('boom'))
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
|
||||
|
||||
await expect(store.clearHistory()).rejects.toThrow('boom')
|
||||
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d2'])
|
||||
})
|
||||
|
||||
it('keeps the row locally when the backend delete fails', async () => {
|
||||
vi.mocked(downloadApi.deleteDownload).mockRejectedValue(new Error('boom'))
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
|
||||
|
||||
await expect(store.remove('d2')).rejects.toThrow('boom')
|
||||
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d2'])
|
||||
})
|
||||
|
||||
it('polls the list when active downloads go stale', async () => {
|
||||
vi.mocked(downloadApi.listDownloads).mockResolvedValue([])
|
||||
useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
|
||||
await vi.advanceTimersByTimeAsync(15_000)
|
||||
|
||||
expect(downloadApi.listDownloads).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
describe('downloadProgressFraction', () => {
|
||||
it('uses live progress when present', () => {
|
||||
expect(downloadProgressFraction(createStatus({ progress: 0.4 }))).toBe(
|
||||
0.4
|
||||
)
|
||||
})
|
||||
|
||||
it('derives from bytes when progress is null', () => {
|
||||
expect(
|
||||
downloadProgressFraction(
|
||||
createStatus({ progress: null, bytes_done: 500, total_bytes: 1000 })
|
||||
)
|
||||
).toBe(0.5)
|
||||
})
|
||||
|
||||
it('returns null when total is unknown', () => {
|
||||
expect(
|
||||
downloadProgressFraction(
|
||||
createStatus({ progress: null, total_bytes: null })
|
||||
)
|
||||
).toBeNull()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,229 +0,0 @@
|
||||
import { useIntervalFn } from '@vueuse/core'
|
||||
import { defineStore } from 'pinia'
|
||||
import { computed, ref } from 'vue'
|
||||
|
||||
import { api } from '@/scripts/api'
|
||||
|
||||
import {
|
||||
cancelDownload,
|
||||
clearDownloads,
|
||||
deleteDownload,
|
||||
enqueueDownload,
|
||||
listDownloads,
|
||||
pauseDownload,
|
||||
resumeDownload,
|
||||
setDownloadPriority
|
||||
} from '../api/modelDownloadApi'
|
||||
import type {
|
||||
DownloadState,
|
||||
DownloadStatus,
|
||||
EnqueueRequest,
|
||||
EnqueueResponse
|
||||
} from '../types'
|
||||
import { directoryOf } from '../utils/modelId'
|
||||
|
||||
const ACTIVE_STATES: ReadonlySet<DownloadState> = new Set([
|
||||
'queued',
|
||||
'active',
|
||||
'paused',
|
||||
'verifying'
|
||||
])
|
||||
|
||||
const POLL_INTERVAL_MS = 5_000
|
||||
const STALE_THRESHOLD_MS = 8_000
|
||||
|
||||
/**
|
||||
* Static completion fraction (0..1) for a download. Live values (`progress`)
|
||||
* are only populated while a worker is running; for paused/terminal rows we
|
||||
* derive it from `bytes_done / total_bytes`.
|
||||
*/
|
||||
export function downloadProgressFraction(
|
||||
download: DownloadStatus
|
||||
): number | null {
|
||||
if (download.progress != null) return download.progress
|
||||
if (download.total_bytes && download.total_bytes > 0) {
|
||||
return download.bytes_done / download.total_bytes
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
function optimisticRow(
|
||||
downloadId: string,
|
||||
request: EnqueueRequest
|
||||
): DownloadStatus {
|
||||
const now = Math.floor(Date.now() / 1000)
|
||||
return {
|
||||
download_id: downloadId,
|
||||
model_id: request.model_id,
|
||||
url: request.url,
|
||||
status: 'queued',
|
||||
priority: request.priority ?? 0,
|
||||
total_bytes: null,
|
||||
bytes_done: 0,
|
||||
progress: null,
|
||||
speed_bps: null,
|
||||
eta_seconds: null,
|
||||
segments: null,
|
||||
error: null,
|
||||
created_at: now,
|
||||
updated_at: now
|
||||
}
|
||||
}
|
||||
|
||||
export interface CompletedDownload {
|
||||
downloadId: string
|
||||
modelId: string
|
||||
directory: string
|
||||
timestamp: number
|
||||
}
|
||||
|
||||
export const useModelDownloadStore = defineStore('modelDownload', () => {
|
||||
const downloads = ref<Map<string, DownloadStatus>>(new Map())
|
||||
const lastWsUpdate = ref(0)
|
||||
const lastCompletedDownload = ref<CompletedDownload | null>(null)
|
||||
|
||||
const downloadList = computed(() => Array.from(downloads.value.values()))
|
||||
const activeDownloads = computed(() =>
|
||||
downloadList.value.filter((d) => ACTIVE_STATES.has(d.status))
|
||||
)
|
||||
const historyDownloads = computed(() =>
|
||||
downloadList.value.filter((d) => !ACTIVE_STATES.has(d.status))
|
||||
)
|
||||
const hasActiveDownloads = computed(() => activeDownloads.value.length > 0)
|
||||
const activeDownloadCount = computed(() => activeDownloads.value.length)
|
||||
|
||||
function upsert(status: DownloadStatus) {
|
||||
downloads.value.set(status.download_id, status)
|
||||
}
|
||||
|
||||
function findByModelId(modelId: string): DownloadStatus | undefined {
|
||||
return downloadList.value.find((d) => d.model_id === modelId)
|
||||
}
|
||||
|
||||
function handleProgress(e: CustomEvent<DownloadStatus>) {
|
||||
lastWsUpdate.value = Date.now()
|
||||
const previous = downloads.value.get(e.detail.download_id)
|
||||
upsert(e.detail)
|
||||
if (e.detail.status === 'completed' && previous?.status !== 'completed') {
|
||||
lastCompletedDownload.value = {
|
||||
downloadId: e.detail.download_id,
|
||||
modelId: e.detail.model_id,
|
||||
directory: directoryOf(e.detail.model_id),
|
||||
timestamp: Date.now()
|
||||
}
|
||||
}
|
||||
// The live failure event carries the terminal status but not the error
|
||||
// text, and polling stops once nothing is active. Refetch so the failure
|
||||
// reason surfaces without the user reopening the panel.
|
||||
if (
|
||||
e.detail.status === 'failed' &&
|
||||
previous?.status !== 'failed' &&
|
||||
!e.detail.error
|
||||
) {
|
||||
void hydrate().catch(() => {})
|
||||
}
|
||||
}
|
||||
|
||||
async function hydrate() {
|
||||
const previous = downloads.value
|
||||
const list = await listDownloads()
|
||||
downloads.value = new Map(list.map((d) => [d.download_id, d]))
|
||||
for (const download of list) {
|
||||
const prior = previous.get(download.download_id)
|
||||
if (
|
||||
download.status === 'completed' &&
|
||||
prior &&
|
||||
prior.status !== 'completed'
|
||||
) {
|
||||
lastCompletedDownload.value = {
|
||||
downloadId: download.download_id,
|
||||
modelId: download.model_id,
|
||||
directory: directoryOf(download.model_id),
|
||||
timestamp: Date.now()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function enqueue(request: EnqueueRequest): Promise<EnqueueResponse> {
|
||||
const response = await enqueueDownload(request)
|
||||
upsert(optimisticRow(response.download_id, request))
|
||||
return response
|
||||
}
|
||||
|
||||
function patchStatus(id: string, status: DownloadState) {
|
||||
const existing = downloads.value.get(id)
|
||||
if (existing) {
|
||||
downloads.value.set(id, { ...existing, status })
|
||||
}
|
||||
}
|
||||
|
||||
async function pause(id: string) {
|
||||
patchStatus(id, 'paused')
|
||||
await pauseDownload(id)
|
||||
}
|
||||
|
||||
async function resume(id: string) {
|
||||
patchStatus(id, 'queued')
|
||||
await resumeDownload(id)
|
||||
}
|
||||
|
||||
async function cancel(id: string) {
|
||||
await cancelDownload(id)
|
||||
patchStatus(id, 'cancelled')
|
||||
}
|
||||
|
||||
async function setPriority(id: string, priority: number) {
|
||||
const existing = downloads.value.get(id)
|
||||
if (existing) {
|
||||
downloads.value.set(id, { ...existing, priority })
|
||||
}
|
||||
await setDownloadPriority(id, priority)
|
||||
}
|
||||
|
||||
async function remove(id: string) {
|
||||
await deleteDownload(id)
|
||||
downloads.value.delete(id)
|
||||
}
|
||||
|
||||
async function clearHistory() {
|
||||
const clearedIds = historyDownloads.value.map((d) => d.download_id)
|
||||
await clearDownloads()
|
||||
for (const id of clearedIds) {
|
||||
downloads.value.delete(id)
|
||||
}
|
||||
}
|
||||
|
||||
async function pollIfStale() {
|
||||
if (!hasActiveDownloads.value) return
|
||||
if (Date.now() - lastWsUpdate.value < STALE_THRESHOLD_MS) return
|
||||
try {
|
||||
await hydrate()
|
||||
} catch {
|
||||
// Server unreachable; retry on next interval
|
||||
}
|
||||
}
|
||||
|
||||
useIntervalFn(() => void pollIfStale(), POLL_INTERVAL_MS)
|
||||
|
||||
api.addEventListener('download_progress', handleProgress)
|
||||
|
||||
return {
|
||||
downloadList,
|
||||
activeDownloads,
|
||||
historyDownloads,
|
||||
hasActiveDownloads,
|
||||
activeDownloadCount,
|
||||
lastCompletedDownload,
|
||||
upsert,
|
||||
findByModelId,
|
||||
hydrate,
|
||||
enqueue,
|
||||
pause,
|
||||
resume,
|
||||
cancel,
|
||||
setPriority,
|
||||
remove,
|
||||
clearHistory
|
||||
}
|
||||
})
|
||||
@@ -1,130 +0,0 @@
|
||||
import { z } from 'zod'
|
||||
|
||||
import type { DownloadState, DownloadStatus } from '@/schemas/apiSchema'
|
||||
|
||||
export type { DownloadState, DownloadStatus }
|
||||
|
||||
/**
|
||||
* Known model file extensions accepted by the backend without
|
||||
* `allow_any_extension`. Mirrors the server's extension allowlist and is
|
||||
* used for instant client-side validation only — the server stays the source
|
||||
* of truth.
|
||||
*/
|
||||
export const MODEL_EXTENSIONS = [
|
||||
'.safetensors',
|
||||
'.sft',
|
||||
'.ckpt',
|
||||
'.pth',
|
||||
'.pt',
|
||||
'.gguf',
|
||||
'.bin'
|
||||
] as const
|
||||
|
||||
/**
|
||||
* Hosts the backend allows out of the box. Admins can extend this
|
||||
* server-side, so this list is only for optimistic client-side hints; rely on
|
||||
* the server's `URL_NOT_ALLOWED` / `url_allowed` as the source of truth.
|
||||
*/
|
||||
export const DEFAULT_ALLOWED_HOSTS = [
|
||||
'huggingface.co',
|
||||
'civitai.com',
|
||||
'localhost',
|
||||
'127.0.0.1'
|
||||
] as const
|
||||
|
||||
export interface EnqueueRequest {
|
||||
url: string
|
||||
model_id: string
|
||||
priority?: number
|
||||
expected_sha256?: string | null
|
||||
allow_any_extension?: boolean
|
||||
credential_id?: string | null
|
||||
}
|
||||
|
||||
export interface EnqueueResponse {
|
||||
download_id: string
|
||||
accepted: boolean
|
||||
}
|
||||
|
||||
export const AUTH_SCHEMES = ['bearer', 'header', 'query'] as const
|
||||
export type AuthScheme = (typeof AUTH_SCHEMES)[number]
|
||||
|
||||
const zHostCredentialView = z.object({
|
||||
id: z.string(),
|
||||
host: z.string(),
|
||||
auth_scheme: z.enum(AUTH_SCHEMES),
|
||||
header_name: z.string().nullable(),
|
||||
query_param: z.string().nullable(),
|
||||
label: z.string().nullable(),
|
||||
match_subdomains: z.boolean(),
|
||||
enabled: z.boolean(),
|
||||
secret_last4: z.string().nullable(),
|
||||
created_at: z.number(),
|
||||
updated_at: z.number()
|
||||
})
|
||||
export type HostCredentialView = z.infer<typeof zHostCredentialView>
|
||||
|
||||
export interface HostCredentialUpsert {
|
||||
host: string
|
||||
secret: string
|
||||
auth_scheme?: AuthScheme
|
||||
header_name?: string | null
|
||||
query_param?: string | null
|
||||
label?: string | null
|
||||
match_subdomains?: boolean
|
||||
enabled?: boolean
|
||||
}
|
||||
|
||||
interface AvailabilityBase {
|
||||
url_allowed: boolean
|
||||
}
|
||||
|
||||
export type AvailabilityEntry =
|
||||
| (AvailabilityBase & { state: 'available' })
|
||||
| (AvailabilityBase & { state: 'missing' })
|
||||
| (AvailabilityBase & {
|
||||
state: 'downloading'
|
||||
download_id: string
|
||||
progress: number | null
|
||||
bytes_done: number
|
||||
total_bytes: number | null
|
||||
speed_bps: number | null
|
||||
})
|
||||
|
||||
export interface AvailabilityResponse {
|
||||
models: Record<string, AvailabilityEntry>
|
||||
}
|
||||
|
||||
const DOWNLOAD_ERROR_CODES = [
|
||||
'INVALID_JSON',
|
||||
'INVALID_BODY',
|
||||
'URL_NOT_ALLOWED',
|
||||
'INVALID_MODEL_ID',
|
||||
'INVALID_CREDENTIAL',
|
||||
'ALREADY_AVAILABLE',
|
||||
'ALREADY_DOWNLOADING',
|
||||
'DOWNLOAD_ACTIVE',
|
||||
'NOT_FOUND'
|
||||
] as const
|
||||
export type DownloadErrorCode = (typeof DOWNLOAD_ERROR_CODES)[number]
|
||||
|
||||
/**
|
||||
* Error envelope returned by every download-manager endpoint on failure.
|
||||
* `code` is the stable machine-readable discriminator; `message` is
|
||||
* user-facing; `details` is an open object (do not assume a shape).
|
||||
*/
|
||||
export class DownloadApiError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public readonly code: string,
|
||||
public readonly status: number,
|
||||
public readonly details?: Record<string, unknown>
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'DownloadApiError'
|
||||
}
|
||||
|
||||
is(code: DownloadErrorCode): boolean {
|
||||
return this.code === code
|
||||
}
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import {
|
||||
buildModelId,
|
||||
filenameFromUrl,
|
||||
hasModelExtension,
|
||||
hostFromUrl,
|
||||
isLikelyAllowedHost,
|
||||
isValidPathSegment
|
||||
} from './modelId'
|
||||
|
||||
describe('modelId utils', () => {
|
||||
describe('isValidPathSegment', () => {
|
||||
it('accepts valid filenames', () => {
|
||||
expect(isValidPathSegment('my_lora.safetensors')).toBe(true)
|
||||
expect(isValidPathSegment('sdxl-base.ckpt')).toBe(true)
|
||||
})
|
||||
|
||||
it('rejects traversal, leading dots, and slashes', () => {
|
||||
expect(isValidPathSegment('..')).toBe(false)
|
||||
expect(isValidPathSegment('.hidden')).toBe(false)
|
||||
expect(isValidPathSegment('nested/path')).toBe(false)
|
||||
expect(isValidPathSegment('')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('hasModelExtension', () => {
|
||||
it('matches known extensions case-insensitively', () => {
|
||||
expect(hasModelExtension('model.SAFETENSORS')).toBe(true)
|
||||
expect(hasModelExtension('weights.gguf')).toBe(true)
|
||||
})
|
||||
|
||||
it('rejects unknown extensions', () => {
|
||||
expect(hasModelExtension('readme.txt')).toBe(false)
|
||||
expect(hasModelExtension('archive.zip')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('buildModelId', () => {
|
||||
it('joins directory and filename with a single slash', () => {
|
||||
expect(buildModelId('loras', 'x.safetensors')).toBe('loras/x.safetensors')
|
||||
})
|
||||
|
||||
it('collapses redundant slashes at the join boundary', () => {
|
||||
expect(buildModelId('loras/', 'x.safetensors')).toBe(
|
||||
'loras/x.safetensors'
|
||||
)
|
||||
expect(buildModelId('loras', '/x.safetensors')).toBe(
|
||||
'loras/x.safetensors'
|
||||
)
|
||||
expect(buildModelId('loras//', '//x.safetensors')).toBe(
|
||||
'loras/x.safetensors'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('hostFromUrl', () => {
|
||||
it('extracts the lowercased host', () => {
|
||||
expect(hostFromUrl('https://HuggingFace.co/a/b')).toBe('huggingface.co')
|
||||
})
|
||||
|
||||
it('returns null for invalid URLs', () => {
|
||||
expect(hostFromUrl('not a url')).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('isLikelyAllowedHost', () => {
|
||||
it('allows built-in hosts and their subdomains', () => {
|
||||
expect(isLikelyAllowedHost('https://huggingface.co/x')).toBe(true)
|
||||
expect(isLikelyAllowedHost('https://cdn.huggingface.co/x')).toBe(true)
|
||||
expect(isLikelyAllowedHost('https://civitai.com/api/download/1')).toBe(
|
||||
true
|
||||
)
|
||||
})
|
||||
|
||||
it('flags unknown hosts', () => {
|
||||
expect(isLikelyAllowedHost('https://example.com/x')).toBe(false)
|
||||
expect(isLikelyAllowedHost('garbage')).toBe(false)
|
||||
})
|
||||
|
||||
it('does not treat a fake subdomain of an allowed IP literal as allowed', () => {
|
||||
expect(isLikelyAllowedHost('https://127.0.0.1/x')).toBe(true)
|
||||
expect(isLikelyAllowedHost('https://evil.127.0.0.1/x')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('filenameFromUrl', () => {
|
||||
it('returns the decoded trailing path segment', () => {
|
||||
expect(filenameFromUrl('https://h.co/a/my%20model.safetensors')).toBe(
|
||||
'my model.safetensors'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns empty string when no segment is present', () => {
|
||||
expect(filenameFromUrl('https://h.co/')).toBe('')
|
||||
expect(filenameFromUrl('bad')).toBe('')
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,79 +0,0 @@
|
||||
import { DEFAULT_ALLOWED_HOSTS, MODEL_EXTENSIONS } from '../types'
|
||||
|
||||
const SEGMENT_PATTERN = /^[A-Za-z0-9][A-Za-z0-9._-]*$/
|
||||
|
||||
export function isValidPathSegment(segment: string): boolean {
|
||||
return SEGMENT_PATTERN.test(segment)
|
||||
}
|
||||
|
||||
export function hasModelExtension(filename: string): boolean {
|
||||
const lower = filename.toLowerCase()
|
||||
return MODEL_EXTENSIONS.some((ext) => lower.endsWith(ext))
|
||||
}
|
||||
|
||||
export function buildModelId(directory: string, filename: string): string {
|
||||
return `${directory.replace(/\/+$/, '')}/${filename.replace(/^\/+/, '')}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Directory portion of a `directory/filename` model id, or an empty string when
|
||||
* the id has no directory prefix.
|
||||
*/
|
||||
export function directoryOf(modelId: string): string {
|
||||
const slash = modelId.indexOf('/')
|
||||
return slash === -1 ? '' : modelId.slice(0, slash)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filename portion of a `directory/filename` model id, falling back to the full
|
||||
* id when there is no directory prefix.
|
||||
*/
|
||||
export function filenameOf(modelId: string): string {
|
||||
const slash = modelId.indexOf('/')
|
||||
return slash === -1 ? modelId : modelId.slice(slash + 1)
|
||||
}
|
||||
|
||||
/**
|
||||
* Lowercased host of a URL, or `null` when the URL is unparseable.
|
||||
*/
|
||||
export function hostFromUrl(url: string): string | null {
|
||||
try {
|
||||
return new URL(url).hostname.toLowerCase()
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const IPV4_PATTERN = /^\d{1,3}(?:\.\d{1,3}){3}$/
|
||||
|
||||
function isIpLiteral(host: string): boolean {
|
||||
return IPV4_PATTERN.test(host) || host.includes(':')
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimistic client-side allowlist hint. The server can extend the
|
||||
* allowlist, so a `false` here is advisory — defer to `URL_NOT_ALLOWED`.
|
||||
*/
|
||||
export function isLikelyAllowedHost(url: string): boolean {
|
||||
const host = hostFromUrl(url)
|
||||
if (!host) return false
|
||||
return DEFAULT_ALLOWED_HOSTS.some(
|
||||
(allowed) =>
|
||||
host === allowed ||
|
||||
(!isIpLiteral(allowed) && host.endsWith(`.${allowed}`))
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Best-effort filename guess from a URL path, for prefilling the model_id
|
||||
* filename field. May be empty when the URL has no usable trailing segment.
|
||||
*/
|
||||
export function filenameFromUrl(url: string): string {
|
||||
try {
|
||||
const { pathname } = new URL(url)
|
||||
const last = pathname.split('/').filter(Boolean).pop() ?? ''
|
||||
return decodeURIComponent(last)
|
||||
} catch {
|
||||
return ''
|
||||
}
|
||||
}
|
||||
@@ -30,6 +30,39 @@ describe('TelemetryRegistry', () => {
|
||||
expect(b.trackSearchQuery).toHaveBeenCalledExactlyOnceWith(payload)
|
||||
})
|
||||
|
||||
it('dispatches trackBeginCheckout with intent metadata to every provider', () => {
|
||||
const a: TelemetryProvider = { trackBeginCheckout: vi.fn() }
|
||||
const b: TelemetryProvider = {}
|
||||
const registry = new TelemetryRegistry()
|
||||
registry.registerProvider(a)
|
||||
registry.registerProvider(b)
|
||||
|
||||
const metadata = {
|
||||
user_id: 'user-1',
|
||||
tier: 'pro' as const,
|
||||
cycle: 'monthly' as const,
|
||||
checkout_type: 'new' as const,
|
||||
payment_intent_source: 'subscribe_to_run' as const
|
||||
}
|
||||
registry.trackBeginCheckout(metadata)
|
||||
|
||||
expect(a.trackBeginCheckout).toHaveBeenCalledExactlyOnceWith(metadata)
|
||||
})
|
||||
|
||||
it('dispatches trackAddApiCreditButtonClicked with its source', () => {
|
||||
const provider: TelemetryProvider = {
|
||||
trackAddApiCreditButtonClicked: vi.fn()
|
||||
}
|
||||
const registry = new TelemetryRegistry()
|
||||
registry.registerProvider(provider)
|
||||
|
||||
registry.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
|
||||
expect(
|
||||
provider.trackAddApiCreditButtonClicked
|
||||
).toHaveBeenCalledExactlyOnceWith({ source: 'credits_panel' })
|
||||
})
|
||||
|
||||
it('skips providers that do not implement trackSearchQuery', () => {
|
||||
const empty: TelemetryProvider = {}
|
||||
const registry = new TelemetryRegistry()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
@@ -99,8 +100,10 @@ export class TelemetryRegistry implements TelemetryDispatcher {
|
||||
this.dispatch((provider) => provider.trackMonthlySubscriptionCancelled?.())
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.dispatch((provider) => provider.trackAddApiCreditButtonClicked?.())
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.dispatch((provider) =>
|
||||
provider.trackAddApiCreditButtonClicked?.(metadata)
|
||||
)
|
||||
}
|
||||
|
||||
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {
|
||||
|
||||
@@ -313,6 +313,42 @@ describe('PostHogTelemetryProvider', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('captures begin_checkout with intent metadata', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
provider.trackBeginCheckout({
|
||||
user_id: 'user-1',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
|
||||
expect(hoisted.mockCapture).toHaveBeenCalledWith(
|
||||
TelemetryEvents.BEGIN_CHECKOUT,
|
||||
{
|
||||
user_id: 'user-1',
|
||||
tier: 'pro',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'new',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
it('captures add-credit clicks with their source', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
provider.trackAddApiCreditButtonClicked({ source: 'credits_panel' })
|
||||
|
||||
expect(hoisted.mockCapture).toHaveBeenCalledWith(
|
||||
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
|
||||
{ source: 'credits_panel' }
|
||||
)
|
||||
})
|
||||
|
||||
it('captures share attribution events', async () => {
|
||||
const provider = createProvider()
|
||||
await vi.dynamicImportSettled()
|
||||
|
||||
@@ -10,7 +10,9 @@ import { remoteConfig } from '@/platform/remoteConfig/remoteConfig'
|
||||
import type { RemoteConfig } from '@/platform/remoteConfig/types'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
EnterLinearMetadata,
|
||||
ShareFlowMetadata,
|
||||
@@ -350,8 +352,12 @@ export class PostHogTelemetryProvider implements TelemetryProvider {
|
||||
this.trackEvent(eventName, metadata)
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.trackEvent(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
|
||||
}
|
||||
|
||||
trackBeginCheckout(metadata: BeginCheckoutMetadata): void {
|
||||
this.trackEvent(TelemetryEvents.BEGIN_CHECKOUT, metadata)
|
||||
}
|
||||
|
||||
trackMonthlySubscriptionSucceeded(
|
||||
|
||||
@@ -115,6 +115,17 @@ describe('HostTelemetrySink', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('forwards add-credit clicks with their source', () => {
|
||||
new HostTelemetrySink().trackAddApiCreditButtonClicked({
|
||||
source: 'avatar_menu'
|
||||
})
|
||||
|
||||
expect(state.capture).toHaveBeenCalledExactlyOnceWith(
|
||||
TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED,
|
||||
{ source: 'avatar_menu' }
|
||||
)
|
||||
})
|
||||
|
||||
it('does nothing when the host bridge is absent', () => {
|
||||
delete window.__comfyDesktop2
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
|
||||
import type {
|
||||
AddCreditsClickMetadata,
|
||||
AuthMetadata,
|
||||
BeginCheckoutMetadata,
|
||||
DefaultViewSetMetadata,
|
||||
@@ -126,8 +127,8 @@ export class HostTelemetrySink implements TelemetryProvider {
|
||||
this.capture(TelemetryEvents.MONTHLY_SUBSCRIPTION_CANCELLED)
|
||||
}
|
||||
|
||||
trackAddApiCreditButtonClicked(): void {
|
||||
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED)
|
||||
trackAddApiCreditButtonClicked(metadata?: AddCreditsClickMetadata): void {
|
||||
this.capture(TelemetryEvents.ADD_API_CREDIT_BUTTON_CLICKED, metadata)
|
||||
}
|
||||
|
||||
trackApiCreditTopupButtonPurchaseClicked(amount: number): void {
|
||||
|
||||
@@ -12,12 +12,29 @@
|
||||
* 3. Check dist/assets/*.js files contain no tracking code
|
||||
*/
|
||||
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import type { AuditLog } from '@/services/customerEventsService'
|
||||
import type { AppMode } from '@/utils/appMode'
|
||||
|
||||
export type PaymentIntentSource =
|
||||
| 'subscription_required'
|
||||
| 'out_of_credits'
|
||||
| 'top_up_blocked'
|
||||
| 'deep_link'
|
||||
| 'subscribe_to_run'
|
||||
| 'subscribe_now_button'
|
||||
| 'upgrade_to_add_credits'
|
||||
| 'settings_billing_panel'
|
||||
| 'avatar_menu_plans'
|
||||
| 'team_members_panel'
|
||||
| 'invite_member_upsell'
|
||||
| 'upload_model_upgrade'
|
||||
| 'team_upgrade_resume'
|
||||
|
||||
export type SubscriptionCheckoutType = 'new' | 'change'
|
||||
export type SubscriptionCheckoutTier = TierKey | 'team'
|
||||
|
||||
/**
|
||||
* Authentication metadata for sign-up tracking
|
||||
*/
|
||||
@@ -426,16 +443,23 @@ export interface CheckoutAttributionMetadata {
|
||||
|
||||
export interface SubscriptionMetadata {
|
||||
current_tier?: string
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
}
|
||||
|
||||
export interface AddCreditsClickMetadata {
|
||||
source: 'credits_panel' | 'avatar_menu' | 'settings_billing_panel'
|
||||
}
|
||||
|
||||
export interface BeginCheckoutMetadata
|
||||
extends Record<string, unknown>, CheckoutAttributionMetadata {
|
||||
user_id: string
|
||||
tier: TierKey
|
||||
tier: SubscriptionCheckoutTier
|
||||
cycle: BillingCycle
|
||||
checkout_type: 'new' | 'change'
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
checkout_attempt_id?: string
|
||||
billing_op_id?: string
|
||||
previous_tier?: TierKey
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
}
|
||||
|
||||
interface EcommerceItemMetadata {
|
||||
@@ -457,8 +481,9 @@ export interface SubscriptionSuccessMetadata extends Record<string, unknown> {
|
||||
checkout_attempt_id: string
|
||||
tier: TierKey
|
||||
cycle: BillingCycle
|
||||
checkout_type: 'new' | 'change'
|
||||
checkout_type: SubscriptionCheckoutType
|
||||
previous_tier?: TierKey
|
||||
payment_intent_source?: PaymentIntentSource
|
||||
value: number
|
||||
currency: string
|
||||
ecommerce: EcommerceMetadata
|
||||
@@ -489,7 +514,7 @@ export interface TelemetryProvider {
|
||||
metadata?: SubscriptionSuccessMetadata
|
||||
): void
|
||||
trackMonthlySubscriptionCancelled?(): void
|
||||
trackAddApiCreditButtonClicked?(): void
|
||||
trackAddApiCreditButtonClicked?(metadata?: AddCreditsClickMetadata): void
|
||||
trackApiCreditTopupButtonPurchaseClicked?(amount: number): void
|
||||
trackApiCreditTopupSucceeded?(): void
|
||||
trackWorkspaceInviteSent?(metadata: WorkspaceInviteMetadata): void
|
||||
|
||||
@@ -321,7 +321,7 @@ const handleOpenWorkspaceSettings = () => {
|
||||
}
|
||||
|
||||
const handleOpenPlansAndPricing = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'avatar_menu_plans' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
@@ -336,13 +336,12 @@ const handleOpenPlanAndCreditsSettings = () => {
|
||||
}
|
||||
|
||||
const handleUpgradeToAddCredits = () => {
|
||||
subscriptionDialog.showPricingTable()
|
||||
subscriptionDialog.showPricingTable({ reason: 'upgrade_to_add_credits' })
|
||||
emit('close')
|
||||
}
|
||||
|
||||
const handleTopUp = () => {
|
||||
// Track purchase credits entry from avatar popover
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked()
|
||||
useTelemetry()?.trackAddApiCreditButtonClicked({ source: 'avatar_menu' })
|
||||
dialogService.showTopUpCreditsDialog()
|
||||
emit('close')
|
||||
}
|
||||
|
||||
@@ -391,12 +391,13 @@ const showZeroState = computed(
|
||||
)
|
||||
|
||||
function handleSubscribeWorkspace() {
|
||||
showSubscriptionDialog()
|
||||
showSubscriptionDialog({ reason: 'settings_billing_panel' })
|
||||
}
|
||||
|
||||
function handleUpgrade() {
|
||||
if (isFreeTierPlan.value) showPricingTable()
|
||||
else showSubscriptionDialog()
|
||||
if (isFreeTierPlan.value)
|
||||
showPricingTable({ reason: 'settings_billing_panel' })
|
||||
else showSubscriptionDialog({ reason: 'settings_billing_panel' })
|
||||
}
|
||||
|
||||
function handleViewMoreDetails() {
|
||||
|
||||
@@ -113,7 +113,7 @@ import { cn } from '@comfyorg/tailwind-utils'
|
||||
import { useEventListener } from '@vueuse/core'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
|
||||
|
||||
import SubscriptionAddPaymentPreviewWorkspace from './SubscriptionAddPaymentPreviewWorkspace.vue'
|
||||
@@ -123,7 +123,7 @@ import UnifiedPricingTable from './UnifiedPricingTable.vue'
|
||||
|
||||
const { onClose, reason, initialPlanMode } = defineProps<{
|
||||
onClose: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
initialPlanMode?: 'personal' | 'team'
|
||||
}>()
|
||||
|
||||
@@ -152,7 +152,7 @@ const {
|
||||
handleConfirmTransition,
|
||||
handleTeamSubscribe,
|
||||
handleResubscribe
|
||||
} = useSubscriptionCheckout(emit)
|
||||
} = useSubscriptionCheckout(emit, reason)
|
||||
|
||||
// Backspace mirrors the back arrow on the confirm step, but never while an
|
||||
// editable element is focused (let it delete text there).
|
||||
|
||||
@@ -5,7 +5,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { ref } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
|
||||
import SubscriptionRequiredDialogContentWorkspace from './SubscriptionRequiredDialogContentWorkspace.vue'
|
||||
|
||||
@@ -17,25 +17,10 @@ const mockHandleResubscribe = vi.fn()
|
||||
const mockHandleSuccessClose = vi.fn()
|
||||
const mockCheckoutStep = ref<'pricing' | 'preview' | 'success'>('pricing')
|
||||
const mockPreviewData = ref<{ transition_type: string } | null>(null)
|
||||
const mockUseSubscriptionCheckout = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/platform/workspace/composables/useSubscriptionCheckout', () => ({
|
||||
useSubscriptionCheckout: () => ({
|
||||
checkoutStep: mockCheckoutStep,
|
||||
isLoadingPreview: ref(false),
|
||||
loadingTier: ref(null),
|
||||
isSubscribing: ref(false),
|
||||
isResubscribing: ref(false),
|
||||
previewData: mockPreviewData,
|
||||
selectedTierKey: ref('standard'),
|
||||
selectedBillingCycle: ref('yearly'),
|
||||
isPolling: ref(false),
|
||||
handleSubscribeClick: mockHandleSubscribeClick,
|
||||
handleBackToPricing: mockHandleBackToPricing,
|
||||
handleAddCreditCard: mockHandleAddCreditCard,
|
||||
handleConfirmTransition: mockHandleConfirmTransition,
|
||||
handleResubscribe: mockHandleResubscribe,
|
||||
handleSuccessClose: mockHandleSuccessClose
|
||||
})
|
||||
useSubscriptionCheckout: mockUseSubscriptionCheckout
|
||||
}))
|
||||
|
||||
const i18n = createI18n({
|
||||
@@ -91,7 +76,7 @@ const SuccessStub = {
|
||||
function renderComponent(
|
||||
props: {
|
||||
onClose?: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
isPersonal?: boolean
|
||||
} = {}
|
||||
) {
|
||||
@@ -121,6 +106,23 @@ function renderComponent(
|
||||
describe('SubscriptionRequiredDialogContentWorkspace', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseSubscriptionCheckout.mockReturnValue({
|
||||
checkoutStep: mockCheckoutStep,
|
||||
isLoadingPreview: ref(false),
|
||||
loadingTier: ref(null),
|
||||
isSubscribing: ref(false),
|
||||
isResubscribing: ref(false),
|
||||
previewData: mockPreviewData,
|
||||
selectedTierKey: ref('standard'),
|
||||
selectedBillingCycle: ref('yearly'),
|
||||
isPolling: ref(false),
|
||||
handleSubscribeClick: mockHandleSubscribeClick,
|
||||
handleBackToPricing: mockHandleBackToPricing,
|
||||
handleAddCreditCard: mockHandleAddCreditCard,
|
||||
handleConfirmTransition: mockHandleConfirmTransition,
|
||||
handleResubscribe: mockHandleResubscribe,
|
||||
handleSuccessClose: mockHandleSuccessClose
|
||||
})
|
||||
mockCheckoutStep.value = 'pricing'
|
||||
mockPreviewData.value = null
|
||||
})
|
||||
@@ -132,6 +134,15 @@ describe('SubscriptionRequiredDialogContentWorkspace', () => {
|
||||
expect(screen.queryByTestId('transition-preview')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('passes the reason into subscription checkout', () => {
|
||||
renderComponent({ reason: 'out_of_credits' })
|
||||
|
||||
expect(mockUseSubscriptionCheckout).toHaveBeenCalledWith(
|
||||
expect.any(Function),
|
||||
'out_of_credits'
|
||||
)
|
||||
})
|
||||
|
||||
it('shows the team workspace header by default', () => {
|
||||
renderComponent()
|
||||
expect(screen.getByText('Team Workspace')).toBeInTheDocument()
|
||||
|
||||
@@ -116,7 +116,7 @@
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import { useSubscriptionCheckout } from '@/platform/workspace/composables/useSubscriptionCheckout'
|
||||
|
||||
import PricingTableWorkspace from './PricingTableWorkspace.vue'
|
||||
@@ -130,7 +130,7 @@ const {
|
||||
isPersonal = false
|
||||
} = defineProps<{
|
||||
onClose: () => void
|
||||
reason?: SubscriptionDialogReason
|
||||
reason?: PaymentIntentSource
|
||||
isPersonal?: boolean
|
||||
}>()
|
||||
|
||||
@@ -154,7 +154,7 @@ const {
|
||||
handleConfirmTransition,
|
||||
handleResubscribe,
|
||||
handleSuccessClose
|
||||
} = useSubscriptionCheckout(emit)
|
||||
} = useSubscriptionCheckout(emit, reason)
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
||||
@@ -61,6 +61,9 @@ function onDismiss() {
|
||||
|
||||
function onUpgrade() {
|
||||
dialogStore.closeDialog({ key: 'invite-member-upsell' })
|
||||
subscriptionDialog.show({ planMode: 'team' })
|
||||
subscriptionDialog.show({
|
||||
planMode: 'team',
|
||||
reason: 'invite_member_upsell'
|
||||
})
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -277,7 +277,7 @@ export function useMembersPanel() {
|
||||
}
|
||||
|
||||
function showTeamPlans() {
|
||||
subscriptionDialog.show({ planMode: 'team' })
|
||||
subscriptionDialog.show({ planMode: 'team', reason: 'team_members_panel' })
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { computed } from 'vue'
|
||||
import { computed, reactive } from 'vue'
|
||||
|
||||
import type { PaymentIntentSource } from '@/platform/telemetry/types'
|
||||
import type { Plan } from '@/platform/workspace/api/workspaceApi'
|
||||
|
||||
import { findPlanSlug } from './useSubscriptionCheckout'
|
||||
@@ -75,7 +76,9 @@ const {
|
||||
mockPlans,
|
||||
mockResubscribe,
|
||||
mockToastAdd,
|
||||
mockStartOperation
|
||||
mockStartOperation,
|
||||
mockTrackBeginCheckout,
|
||||
mockUserId
|
||||
} = vi.hoisted(() => ({
|
||||
mockSubscribe: vi.fn(),
|
||||
mockPreviewSubscribe: vi.fn(),
|
||||
@@ -84,7 +87,9 @@ const {
|
||||
mockPlans: { value: [] as Plan[] },
|
||||
mockResubscribe: vi.fn(),
|
||||
mockToastAdd: vi.fn(),
|
||||
mockStartOperation: vi.fn()
|
||||
mockStartOperation: vi.fn(),
|
||||
mockTrackBeginCheckout: vi.fn(),
|
||||
mockUserId: { value: 'user-1' as string | null }
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
@@ -119,7 +124,14 @@ vi.mock('primevue/usetoast', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/telemetry', () => ({
|
||||
useTelemetry: () => ({ trackMonthlySubscriptionSucceeded: vi.fn() })
|
||||
useTelemetry: () => ({
|
||||
trackMonthlySubscriptionSucceeded: vi.fn(),
|
||||
trackBeginCheckout: mockTrackBeginCheckout
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/authStore', () => ({
|
||||
useAuthStore: () => reactive({ userId: computed(() => mockUserId.value) })
|
||||
}))
|
||||
|
||||
vi.mock('vue-i18n', async (importOriginal) => {
|
||||
@@ -135,10 +147,10 @@ vi.mock('vue-i18n', async (importOriginal) => {
|
||||
describe('useSubscriptionCheckout', () => {
|
||||
let emit: ReturnType<typeof vi.fn>
|
||||
|
||||
async function setup() {
|
||||
async function setup(paymentIntentSource?: PaymentIntentSource) {
|
||||
const { useSubscriptionCheckout } =
|
||||
await import('./useSubscriptionCheckout')
|
||||
return useSubscriptionCheckout(emit as never)
|
||||
return useSubscriptionCheckout(emit as never, paymentIntentSource)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -146,6 +158,7 @@ describe('useSubscriptionCheckout', () => {
|
||||
vi.clearAllMocks()
|
||||
mockPlans.value = allPlans()
|
||||
mockStartOperation.mockResolvedValue({ status: 'succeeded' })
|
||||
mockUserId.value = 'user-1'
|
||||
emit = vi.fn()
|
||||
})
|
||||
|
||||
@@ -459,6 +472,13 @@ describe('useSubscriptionCheckout', () => {
|
||||
cancelUrl: 'https://platform.comfy.org/payment/failed'
|
||||
})
|
||||
expect(checkout.checkoutStep.value).toBe('success')
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tier: 'team',
|
||||
checkout_type: 'new',
|
||||
billing_op_id: 'op-team-1'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('uses the annual plan slug for the yearly cycle', async () => {
|
||||
@@ -553,6 +573,39 @@ describe('useSubscriptionCheckout', () => {
|
||||
detail: 'Team payment failed'
|
||||
})
|
||||
)
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('keeps team checkout_type as change when the preview request fails', async () => {
|
||||
const checkout = await setup()
|
||||
mockPreviewSubscribe.mockRejectedValueOnce(new Error('not supported'))
|
||||
await checkout.handleSubscribeTeamClick({
|
||||
stop: {
|
||||
id: 'team_1400',
|
||||
usd: 1400,
|
||||
credits: 295_400,
|
||||
discountedUsd: 1295
|
||||
},
|
||||
billingCycle: 'monthly',
|
||||
isChange: true
|
||||
})
|
||||
mockSubscribe.mockResolvedValueOnce({
|
||||
status: 'subscribed',
|
||||
billing_op_id: 'op-team-change'
|
||||
})
|
||||
mockFetchStatus.mockResolvedValueOnce(undefined)
|
||||
mockFetchBalance.mockResolvedValueOnce(undefined)
|
||||
|
||||
await checkout.handleTeamSubscribe()
|
||||
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tier: 'team',
|
||||
cycle: 'monthly',
|
||||
checkout_type: 'change',
|
||||
billing_op_id: 'op-team-change'
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -603,6 +656,47 @@ describe('useSubscriptionCheckout', () => {
|
||||
expect(checkout.checkoutStep.value).toBe('success')
|
||||
})
|
||||
|
||||
it('skips begin_checkout when no user id is available', async () => {
|
||||
mockUserId.value = null
|
||||
const checkout = await setup('subscribe_to_run')
|
||||
checkout.selectedTierKey.value = 'standard'
|
||||
checkout.selectedBillingCycle.value = 'yearly'
|
||||
mockSubscribe.mockResolvedValueOnce({
|
||||
status: 'subscribed',
|
||||
billing_op_id: 'op-1'
|
||||
})
|
||||
mockFetchStatus.mockResolvedValueOnce(undefined)
|
||||
mockFetchBalance.mockResolvedValueOnce(undefined)
|
||||
|
||||
await checkout.handleAddCreditCard()
|
||||
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
mockUserId.value = 'user-1'
|
||||
})
|
||||
|
||||
it('fires begin_checkout carrying the payment intent source', async () => {
|
||||
const checkout = await setup('subscribe_to_run')
|
||||
checkout.selectedTierKey.value = 'standard'
|
||||
checkout.selectedBillingCycle.value = 'yearly'
|
||||
mockSubscribe.mockResolvedValueOnce({
|
||||
status: 'subscribed',
|
||||
billing_op_id: 'op-1'
|
||||
})
|
||||
mockFetchStatus.mockResolvedValueOnce(undefined)
|
||||
mockFetchBalance.mockResolvedValueOnce(undefined)
|
||||
|
||||
await checkout.handleAddCreditCard()
|
||||
|
||||
expect(mockTrackBeginCheckout).toHaveBeenCalledWith({
|
||||
user_id: 'user-1',
|
||||
tier: 'standard',
|
||||
cycle: 'yearly',
|
||||
checkout_type: 'new',
|
||||
billing_op_id: 'op-1',
|
||||
payment_intent_source: 'subscribe_to_run'
|
||||
})
|
||||
})
|
||||
|
||||
it('opens payment URL when needs_payment_method', async () => {
|
||||
const checkout = await setup()
|
||||
checkout.selectedTierKey.value = 'standard'
|
||||
@@ -720,6 +814,7 @@ describe('useSubscriptionCheckout', () => {
|
||||
detail: 'Payment failed'
|
||||
})
|
||||
)
|
||||
expect(mockTrackBeginCheckout).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -9,16 +9,26 @@ import type { TeamPlanSelection } from '@/platform/cloud/subscription/constants/
|
||||
import type { TierKey } from '@/platform/cloud/subscription/constants/tierPricing'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type {
|
||||
PaymentIntentSource,
|
||||
SubscriptionCheckoutType
|
||||
} from '@/platform/telemetry/types'
|
||||
import type {
|
||||
Plan,
|
||||
PreviewSubscribeResponse,
|
||||
SubscribeResponse
|
||||
} from '@/platform/workspace/api/workspaceApi'
|
||||
import { useBillingOperationStore } from '@/platform/workspace/stores/billingOperationStore'
|
||||
import { trackWorkspaceCheckoutStarted } from '@/platform/workspace/utils/workspaceCheckoutTelemetry'
|
||||
|
||||
type CheckoutStep = 'pricing' | 'preview' | 'success'
|
||||
type CheckoutTierKey = Exclude<TierKey, 'free' | 'founder'>
|
||||
|
||||
interface SelectedTeamCheckout {
|
||||
stop: TeamPlanSelection
|
||||
checkoutType: SubscriptionCheckoutType
|
||||
}
|
||||
|
||||
/**
|
||||
* Which screen the `preview` step shows. Only a change prorates: a team change
|
||||
* carries `previewData` (handleSubscribeTeamClick sets it solely for an immediate
|
||||
@@ -45,9 +55,12 @@ export function findPlanSlug(
|
||||
return plan?.slug ?? null
|
||||
}
|
||||
|
||||
export function useSubscriptionCheckout(emit: {
|
||||
(e: 'close', subscribed: boolean): void
|
||||
}) {
|
||||
export function useSubscriptionCheckout(
|
||||
emit: {
|
||||
(e: 'close', subscribed: boolean): void
|
||||
},
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
) {
|
||||
const { t } = useI18n()
|
||||
const toast = useToast()
|
||||
const {
|
||||
@@ -68,13 +81,16 @@ export function useSubscriptionCheckout(emit: {
|
||||
const isResubscribing = ref(false)
|
||||
const previewData = ref<PreviewSubscribeResponse | null>(null)
|
||||
const selectedTierKey = ref<CheckoutTierKey | null>(null)
|
||||
const selectedTeamStop = ref<TeamPlanSelection | null>(null)
|
||||
const selectedTeamCheckout = ref<SelectedTeamCheckout | null>(null)
|
||||
const selectedBillingCycle = ref<BillingCycle>('yearly')
|
||||
const isPolling = computed(() => billingOperationStore.hasPendingOperations)
|
||||
const isTeamCheckout = computed(() => selectedTeamStop.value !== null)
|
||||
const selectedTeamStop = computed(
|
||||
() => selectedTeamCheckout.value?.stop ?? null
|
||||
)
|
||||
const isTeamCheckout = computed(() => selectedTeamCheckout.value !== null)
|
||||
|
||||
const previewVariant = computed<PreviewVariant>(() => {
|
||||
if (selectedTeamStop.value) {
|
||||
if (selectedTeamCheckout.value) {
|
||||
return previewData.value ? 'team-change' : 'team-new'
|
||||
}
|
||||
if (previewData.value) {
|
||||
@@ -154,7 +170,10 @@ export function useSubscriptionCheckout(emit: {
|
||||
billingCycle: BillingCycle
|
||||
isChange?: boolean
|
||||
}) {
|
||||
selectedTeamStop.value = payload.stop
|
||||
selectedTeamCheckout.value = {
|
||||
stop: payload.stop,
|
||||
checkoutType: payload.isChange ? 'change' : 'new'
|
||||
}
|
||||
selectedBillingCycle.value = payload.billingCycle
|
||||
selectedTierKey.value = null
|
||||
previewData.value = null
|
||||
@@ -182,7 +201,7 @@ export function useSubscriptionCheckout(emit: {
|
||||
function handleBackToPricing() {
|
||||
checkoutStep.value = 'pricing'
|
||||
previewData.value = null
|
||||
selectedTeamStop.value = null
|
||||
selectedTeamCheckout.value = null
|
||||
}
|
||||
|
||||
function handleSuccessClose() {
|
||||
@@ -190,20 +209,34 @@ export function useSubscriptionCheckout(emit: {
|
||||
}
|
||||
|
||||
async function handleSubscription() {
|
||||
if (!selectedTierKey.value) return
|
||||
const tierKey = selectedTierKey.value
|
||||
if (!tierKey) return
|
||||
|
||||
const billingCycle = selectedBillingCycle.value
|
||||
const checkoutType =
|
||||
previewData.value &&
|
||||
previewData.value.transition_type !== 'new_subscription'
|
||||
? 'change'
|
||||
: 'new'
|
||||
|
||||
isSubscribing.value = true
|
||||
try {
|
||||
const planSlug = getApiPlanSlug(
|
||||
selectedTierKey.value,
|
||||
selectedBillingCycle.value
|
||||
)
|
||||
const planSlug = getApiPlanSlug(tierKey, billingCycle)
|
||||
if (!planSlug) return
|
||||
const response = await subscribe(planSlug, {
|
||||
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
|
||||
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
|
||||
})
|
||||
|
||||
if (response) {
|
||||
trackWorkspaceCheckoutStarted({
|
||||
tier: tierKey,
|
||||
cycle: billingCycle,
|
||||
checkoutType,
|
||||
billingOpId: response.billing_op_id,
|
||||
paymentIntentSource
|
||||
})
|
||||
}
|
||||
await handleSubscribeResponse(response)
|
||||
} catch (error) {
|
||||
showSubscribeError(error)
|
||||
@@ -269,8 +302,8 @@ export function useSubscriptionCheckout(emit: {
|
||||
}
|
||||
|
||||
async function handleTeamSubscription() {
|
||||
const stop = selectedTeamStop.value
|
||||
if (!stop?.id) {
|
||||
const teamCheckout = selectedTeamCheckout.value
|
||||
if (!teamCheckout?.stop.id) {
|
||||
toast.add({
|
||||
severity: 'error',
|
||||
summary: t('subscription.teamPlan.name'),
|
||||
@@ -279,16 +312,28 @@ export function useSubscriptionCheckout(emit: {
|
||||
return
|
||||
}
|
||||
|
||||
const { stop, checkoutType } = teamCheckout
|
||||
const billingCycle = selectedBillingCycle.value
|
||||
|
||||
isSubscribing.value = true
|
||||
try {
|
||||
const planSlug = getTeamPlanSlug(selectedBillingCycle.value)
|
||||
const planSlug = getTeamPlanSlug(billingCycle)
|
||||
const response = await subscribe(planSlug, {
|
||||
teamCreditStopId: stop.id,
|
||||
billingCycle: selectedBillingCycle.value,
|
||||
billingCycle,
|
||||
returnUrl: `${getComfyPlatformBaseUrl()}/payment/success`,
|
||||
cancelUrl: `${getComfyPlatformBaseUrl()}/payment/failed`
|
||||
})
|
||||
|
||||
if (response) {
|
||||
trackWorkspaceCheckoutStarted({
|
||||
tier: 'team',
|
||||
cycle: billingCycle,
|
||||
checkoutType,
|
||||
billingOpId: response.billing_op_id,
|
||||
paymentIntentSource
|
||||
})
|
||||
}
|
||||
await handleSubscribeResponse(response)
|
||||
} catch (error) {
|
||||
showSubscribeError(error)
|
||||
|
||||
@@ -2,6 +2,7 @@ import { computed, ref, shallowRef } from 'vue'
|
||||
|
||||
import { useBillingPlans } from '@/platform/cloud/subscription/composables/useBillingPlans'
|
||||
import { useSubscriptionDialog } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type {
|
||||
BillingBalanceResponse,
|
||||
BillingStatusResponse,
|
||||
@@ -275,12 +276,12 @@ export function useWorkspaceBilling(): BillingState & BillingActions {
|
||||
async function requireActiveSubscription(): Promise<void> {
|
||||
await fetchStatus()
|
||||
if (!isActiveSubscription.value) {
|
||||
subscriptionDialog.show()
|
||||
subscriptionDialog.show({ reason: 'subscription_required' })
|
||||
}
|
||||
}
|
||||
|
||||
function showSubscriptionDialog(): void {
|
||||
subscriptionDialog.show()
|
||||
function showSubscriptionDialog(options?: SubscriptionDialogOptions): void {
|
||||
subscriptionDialog.show(options)
|
||||
}
|
||||
|
||||
return {
|
||||
|
||||
38
src/platform/workspace/utils/workspaceCheckoutTelemetry.ts
Normal file
38
src/platform/workspace/utils/workspaceCheckoutTelemetry.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
import type {
|
||||
PaymentIntentSource,
|
||||
SubscriptionCheckoutTier,
|
||||
SubscriptionCheckoutType
|
||||
} from '@/platform/telemetry/types'
|
||||
import type { BillingCycle } from '@/platform/cloud/subscription/utils/subscriptionTierRank'
|
||||
import { useAuthStore } from '@/stores/authStore'
|
||||
|
||||
interface TrackWorkspaceCheckoutStartedOptions {
|
||||
tier: SubscriptionCheckoutTier
|
||||
cycle: BillingCycle
|
||||
checkoutType: SubscriptionCheckoutType
|
||||
billingOpId: string
|
||||
paymentIntentSource?: PaymentIntentSource
|
||||
}
|
||||
|
||||
export function trackWorkspaceCheckoutStarted({
|
||||
tier,
|
||||
cycle,
|
||||
checkoutType,
|
||||
billingOpId,
|
||||
paymentIntentSource
|
||||
}: TrackWorkspaceCheckoutStartedOptions) {
|
||||
const { userId } = useAuthStore()
|
||||
if (!userId) return
|
||||
|
||||
useTelemetry()?.trackBeginCheckout({
|
||||
user_id: userId,
|
||||
tier,
|
||||
cycle,
|
||||
checkout_type: checkoutType,
|
||||
billing_op_id: billingOpId,
|
||||
...(paymentIntentSource
|
||||
? { payment_intent_source: paymentIntentSource }
|
||||
: {})
|
||||
})
|
||||
}
|
||||
@@ -154,39 +154,6 @@ const zAssetDownloadWsMessage = z.object({
|
||||
error: z.string().optional()
|
||||
})
|
||||
|
||||
const DOWNLOAD_STATES = [
|
||||
'queued',
|
||||
'active',
|
||||
'paused',
|
||||
'verifying',
|
||||
'completed',
|
||||
'failed',
|
||||
'cancelled'
|
||||
] as const
|
||||
|
||||
const zDownloadSegment = z.object({
|
||||
idx: z.number().int().nonnegative(),
|
||||
bytes_done: z.number().int().nonnegative(),
|
||||
length: z.number().int().nonnegative()
|
||||
})
|
||||
|
||||
const zDownloadStatus = z.object({
|
||||
download_id: z.string(),
|
||||
model_id: z.string(),
|
||||
url: z.string(),
|
||||
status: z.enum(DOWNLOAD_STATES),
|
||||
priority: z.number().int(),
|
||||
total_bytes: z.number().int().nonnegative().nullable(),
|
||||
bytes_done: z.number().int().nonnegative(),
|
||||
progress: z.number().nullable(),
|
||||
speed_bps: z.number().nonnegative().nullable(),
|
||||
eta_seconds: z.number().nonnegative().nullable(),
|
||||
segments: z.array(zDownloadSegment).nullable(),
|
||||
error: z.string().nullable(),
|
||||
created_at: z.number(),
|
||||
updated_at: z.number()
|
||||
})
|
||||
|
||||
const zAssetExportWsMessage = z.object({
|
||||
task_id: z.string(),
|
||||
export_name: z.string().optional(),
|
||||
@@ -221,8 +188,6 @@ export type ProgressStateWsMessage = z.infer<typeof zProgressStateWsMessage>
|
||||
export type FeatureFlagsWsMessage = z.infer<typeof zFeatureFlagsWsMessage>
|
||||
export type AssetDownloadWsMessage = z.infer<typeof zAssetDownloadWsMessage>
|
||||
export type AssetExportWsMessage = z.infer<typeof zAssetExportWsMessage>
|
||||
export type DownloadState = (typeof DOWNLOAD_STATES)[number]
|
||||
export type DownloadStatus = z.infer<typeof zDownloadStatus>
|
||||
// End of ws messages
|
||||
|
||||
export type NotificationWsMessage = z.infer<typeof zNotificationWsMessage>
|
||||
|
||||
@@ -34,7 +34,6 @@ import type {
|
||||
AssetDownloadWsMessage,
|
||||
AssetExportWsMessage,
|
||||
CustomNodesI18n,
|
||||
DownloadStatus,
|
||||
EmbeddingsResponse,
|
||||
ExecutedWsMessage,
|
||||
ExecutingWsMessage,
|
||||
@@ -187,7 +186,6 @@ interface BackendApiCalls {
|
||||
feature_flags: FeatureFlagsWsMessage
|
||||
asset_download: AssetDownloadWsMessage
|
||||
asset_export: AssetExportWsMessage
|
||||
download_progress: DownloadStatus
|
||||
}
|
||||
|
||||
/** Dictionary of all api calls */
|
||||
|
||||
@@ -18,7 +18,7 @@ import type {
|
||||
} from '@/stores/dialogStore'
|
||||
|
||||
import type { ComponentAttrs } from 'vue-component-type-helpers'
|
||||
import type { SubscriptionDialogReason } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { SubscriptionDialogOptions } from '@/platform/cloud/subscription/composables/useSubscriptionDialog'
|
||||
import type { WorkspaceRole } from '@/platform/workspace/api/workspaceApi'
|
||||
|
||||
// Lazy loaders for dialogs - components are loaded on first use
|
||||
@@ -442,9 +442,9 @@ export const useDialogService = () => {
|
||||
})
|
||||
}
|
||||
|
||||
async function showSubscriptionRequiredDialog(options?: {
|
||||
reason?: SubscriptionDialogReason
|
||||
}) {
|
||||
async function showSubscriptionRequiredDialog(
|
||||
options?: SubscriptionDialogOptions
|
||||
) {
|
||||
if (!isCloud || !window.__CONFIG__?.subscription_required) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -113,27 +113,6 @@ describe('commandStore', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('unregisterCommand', () => {
|
||||
it('removes a registered command', () => {
|
||||
const store = useCommandStore()
|
||||
store.registerCommand({ id: 'to.remove', function: vi.fn() })
|
||||
|
||||
store.unregisterCommand('to.remove')
|
||||
|
||||
expect(store.isRegistered('to.remove')).toBe(false)
|
||||
})
|
||||
|
||||
it('is a no-op for an unregistered id', () => {
|
||||
const store = useCommandStore()
|
||||
store.registerCommand({ id: 'keep.me', function: vi.fn() })
|
||||
|
||||
store.unregisterCommand('nonexistent')
|
||||
|
||||
expect(store.isRegistered('keep.me')).toBe(true)
|
||||
expect(store.commands).toHaveLength(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isRegistered', () => {
|
||||
it('returns false for unregistered command', () => {
|
||||
const store = useCommandStore()
|
||||
|
||||
@@ -88,12 +88,6 @@ export const useCommandStore = defineStore('command', () => {
|
||||
}
|
||||
}
|
||||
|
||||
const unregisterCommand = (commandId: string) => {
|
||||
if (!commandsById.value[commandId]) return
|
||||
const { [commandId]: _removed, ...rest } = commandsById.value
|
||||
commandsById.value = rest
|
||||
}
|
||||
|
||||
const getCommand = (command: string) => {
|
||||
return commandsById.value[command]
|
||||
}
|
||||
@@ -145,7 +139,6 @@ export const useCommandStore = defineStore('command', () => {
|
||||
getCommand,
|
||||
registerCommand,
|
||||
registerCommands,
|
||||
unregisterCommand,
|
||||
isRegistered,
|
||||
loadExtensionCommands,
|
||||
formatKeySequence
|
||||
|
||||
@@ -5,19 +5,12 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { useSidebarTabStore } from '@/stores/workspace/sidebarTabStore'
|
||||
|
||||
const {
|
||||
mockGetSetting,
|
||||
mockRegisterCommand,
|
||||
mockRegisterCommands,
|
||||
mockUnregisterCommand,
|
||||
mockServerSideModelDownloads
|
||||
} = vi.hoisted(() => ({
|
||||
mockGetSetting: vi.fn(),
|
||||
mockRegisterCommand: vi.fn(),
|
||||
mockRegisterCommands: vi.fn(),
|
||||
mockUnregisterCommand: vi.fn(),
|
||||
mockServerSideModelDownloads: { value: false }
|
||||
}))
|
||||
const { mockGetSetting, mockRegisterCommand, mockRegisterCommands } =
|
||||
vi.hoisted(() => ({
|
||||
mockGetSetting: vi.fn(),
|
||||
mockRegisterCommand: vi.fn(),
|
||||
mockRegisterCommands: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => ({
|
||||
@@ -28,43 +21,10 @@ vi.mock('@/platform/settings/settingStore', () => ({
|
||||
vi.mock('@/stores/commandStore', () => ({
|
||||
useCommandStore: () => ({
|
||||
registerCommand: mockRegisterCommand,
|
||||
unregisterCommand: mockUnregisterCommand,
|
||||
commands: []
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useFeatureFlags', async () => {
|
||||
const { ref } = await import('vue')
|
||||
const serverSideModelDownloadsRef = ref(mockServerSideModelDownloads.value)
|
||||
Object.defineProperty(mockServerSideModelDownloads, 'value', {
|
||||
get: () => serverSideModelDownloadsRef.value,
|
||||
set: (value: boolean) => {
|
||||
serverSideModelDownloadsRef.value = value
|
||||
}
|
||||
})
|
||||
return {
|
||||
useFeatureFlags: () => ({
|
||||
flags: {
|
||||
get serverSideModelDownloads() {
|
||||
return mockServerSideModelDownloads.value
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock(
|
||||
'@/platform/modelManager/composables/useModelManagerSidebarTab',
|
||||
() => ({
|
||||
useModelManagerSidebarTab: () => ({
|
||||
id: 'model-manager',
|
||||
title: 'model-manager',
|
||||
type: 'vue',
|
||||
component: {}
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock('@/stores/menuItemStore', () => ({
|
||||
useMenuItemStore: () => ({
|
||||
registerCommands: mockRegisterCommands
|
||||
@@ -139,8 +99,6 @@ describe('useSidebarTabStore', () => {
|
||||
mockGetSetting.mockReset()
|
||||
mockRegisterCommand.mockClear()
|
||||
mockRegisterCommands.mockClear()
|
||||
mockUnregisterCommand.mockClear()
|
||||
mockServerSideModelDownloads.value = false
|
||||
})
|
||||
|
||||
it('registers the job history tab when QPO V2 is enabled', () => {
|
||||
@@ -202,57 +160,4 @@ describe('useSidebarTabStore', () => {
|
||||
])
|
||||
expect(mockRegisterCommand).toHaveBeenCalledTimes(6)
|
||||
})
|
||||
|
||||
it('registers the model-manager tab when the feature flag starts enabled', () => {
|
||||
mockServerSideModelDownloads.value = true
|
||||
|
||||
const store = useSidebarTabStore()
|
||||
store.registerCoreSidebarTabs()
|
||||
|
||||
expect(store.sidebarTabs.map((tab) => tab.id)).toContain('model-manager')
|
||||
})
|
||||
|
||||
it('does not register the model-manager tab when the feature flag is disabled', () => {
|
||||
const store = useSidebarTabStore()
|
||||
store.registerCoreSidebarTabs()
|
||||
|
||||
expect(store.sidebarTabs.map((tab) => tab.id)).not.toContain(
|
||||
'model-manager'
|
||||
)
|
||||
})
|
||||
|
||||
it('registers the model-manager tab when the feature flag turns on later', async () => {
|
||||
const store = useSidebarTabStore()
|
||||
store.registerCoreSidebarTabs()
|
||||
expect(store.sidebarTabs.map((tab) => tab.id)).not.toContain(
|
||||
'model-manager'
|
||||
)
|
||||
|
||||
mockServerSideModelDownloads.value = true
|
||||
await nextTick()
|
||||
|
||||
expect(store.sidebarTabs.map((tab) => tab.id)).toContain('model-manager')
|
||||
expect(mockRegisterCommand).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
id: 'Workspace.ToggleSidebarTab.model-manager'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('unregisters the model-manager tab and its command when the flag turns off', async () => {
|
||||
mockServerSideModelDownloads.value = true
|
||||
const store = useSidebarTabStore()
|
||||
store.registerCoreSidebarTabs()
|
||||
expect(store.sidebarTabs.map((tab) => tab.id)).toContain('model-manager')
|
||||
|
||||
mockServerSideModelDownloads.value = false
|
||||
await nextTick()
|
||||
|
||||
expect(store.sidebarTabs.map((tab) => tab.id)).not.toContain(
|
||||
'model-manager'
|
||||
)
|
||||
expect(mockUnregisterCommand).toHaveBeenCalledWith(
|
||||
'Workspace.ToggleSidebarTab.model-manager'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -5,9 +5,7 @@ import { useAssetsSidebarTab } from '@/composables/sidebarTabs/useAssetsSidebarT
|
||||
import { useJobHistorySidebarTab } from '@/composables/sidebarTabs/useJobHistorySidebarTab'
|
||||
import { useModelLibrarySidebarTab } from '@/composables/sidebarTabs/useModelLibrarySidebarTab'
|
||||
import { useNodeLibrarySidebarTab } from '@/composables/sidebarTabs/useNodeLibrarySidebarTab'
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { t, te } from '@/i18n'
|
||||
import { useModelManagerSidebarTab } from '@/platform/modelManager/composables/useModelManagerSidebarTab'
|
||||
import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import { useAppsSidebarTab } from '@/platform/workflow/management/composables/useAppsSidebarTab'
|
||||
import { useWorkflowsSidebarTab } from '@/platform/workflow/management/composables/useWorkflowsSidebarTab'
|
||||
@@ -55,8 +53,7 @@ export const useSidebarTabStore = defineStore('sidebarTab', () => {
|
||||
'model-library': 'sideToolbar.modelLibrary',
|
||||
workflows: 'sideToolbar.workflows',
|
||||
assets: 'sideToolbar.assets',
|
||||
'job-history': 'queue.jobHistory',
|
||||
'model-manager': 'modelManager.title'
|
||||
'job-history': 'queue.jobHistory'
|
||||
}
|
||||
|
||||
const key = menubarLabelKeys[tab.id]
|
||||
@@ -135,27 +132,6 @@ export const useSidebarTabStore = defineStore('sidebarTab', () => {
|
||||
(enabled) => syncJobHistoryTab(enabled)
|
||||
)
|
||||
|
||||
const modelManagerTabId = 'model-manager'
|
||||
const { flags } = useFeatureFlags()
|
||||
const syncModelManagerTab = (enabled: boolean) => {
|
||||
const hasTab = sidebarTabs.value.some(
|
||||
(tab) => tab.id === modelManagerTabId
|
||||
)
|
||||
if (enabled && !hasTab) {
|
||||
registerSidebarTab(useModelManagerSidebarTab())
|
||||
} else if (!enabled && hasTab) {
|
||||
unregisterSidebarTab(modelManagerTabId)
|
||||
useCommandStore().unregisterCommand(
|
||||
`Workspace.ToggleSidebarTab.${modelManagerTabId}`
|
||||
)
|
||||
}
|
||||
}
|
||||
watch(
|
||||
() => flags.serverSideModelDownloads,
|
||||
(enabled) => syncModelManagerTab(enabled),
|
||||
{ immediate: true }
|
||||
)
|
||||
|
||||
registerSidebarTab(useAssetsSidebarTab())
|
||||
registerSidebarTab(useNodeLibrarySidebarTab())
|
||||
registerSidebarTab(useModelLibrarySidebarTab())
|
||||
|
||||
@@ -58,7 +58,6 @@ import { useQueuePolling } from '@/platform/remote/comfyui/useQueuePolling'
|
||||
import { useErrorHandling } from '@/composables/useErrorHandling'
|
||||
import { useReconnectQueueRefresh } from '@/composables/useReconnectQueueRefresh'
|
||||
import { useReconnectingNotification } from '@/composables/useReconnectingNotification'
|
||||
import { useModelDownloadEffects } from '@/platform/modelManager/composables/useModelDownloadEffects'
|
||||
import { useProgressFavicon } from '@/composables/useProgressFavicon'
|
||||
import { SERVER_CONFIG_ITEMS } from '@/constants/serverConfig'
|
||||
import type { ServerConfig, ServerConfigValue } from '@/constants/serverConfig'
|
||||
@@ -228,7 +227,6 @@ useMenuItemStore().registerCoreMenuCommands()
|
||||
useKeybindingService().registerCoreKeybindings()
|
||||
useSidebarTabStore().registerCoreSidebarTabs()
|
||||
void useBottomPanelStore().registerCoreBottomPanelTabs()
|
||||
useModelDownloadEffects()
|
||||
|
||||
useQueuePolling()
|
||||
const queuePendingTaskCountStore = useQueuePendingTaskCountStore()
|
||||
|
||||
Reference in New Issue
Block a user