mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-07-03 13:48:49 +00:00
Compare commits
19 Commits
feat/creat
...
model_down
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5be31cdeb2 | ||
|
|
cf321ef95e | ||
|
|
7263e5068a | ||
|
|
d867d3ed8f | ||
|
|
fa4e8b317a | ||
|
|
375721c4a0 | ||
|
|
f0664df0a7 | ||
|
|
611ffffbe9 | ||
|
|
30386b1004 | ||
|
|
364268d5fd | ||
|
|
807bc24f3a | ||
|
|
9e92b9f747 | ||
|
|
4059950d36 | ||
|
|
70a41cc9ca | ||
|
|
ffe0760430 | ||
|
|
6a3afc42c6 | ||
|
|
fb5b087586 | ||
|
|
7c715c22ab | ||
|
|
9cf5c9a93f |
@@ -15,7 +15,7 @@ const { categories } = defineProps<{
|
||||
|
||||
const activeSection = ref(categories[0]?.value ?? '')
|
||||
|
||||
const HEADER_OFFSET = -144
|
||||
const HEADER_OFFSET_PX = -144
|
||||
const BOTTOM_THRESHOLD_PX = 4
|
||||
const SCROLL_SAFETY_MS = 1500
|
||||
|
||||
@@ -52,7 +52,7 @@ function scrollToSection(id: string) {
|
||||
const el = document.getElementById(id)
|
||||
if (el) {
|
||||
scrollTo(el, {
|
||||
offset: HEADER_OFFSET,
|
||||
offset: HEADER_OFFSET_PX,
|
||||
duration: 0.8,
|
||||
immediate: prefersReducedMotion(),
|
||||
onComplete: clearScrollLock
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<li
|
||||
class="flex items-start gap-2 text-primary-comfy-canvas before:mt-1.5 before:size-1.5 before:shrink-0 before:rounded-full before:bg-primary-comfy-yellow before:content-['']"
|
||||
class="flex items-start gap-2 text-primary-comfy-canvas before:mt-1.5 before:size-1.5 before:shrink-0 before:rounded-full before:bg-primary-comfy-yellow"
|
||||
>
|
||||
<slot />
|
||||
</li>
|
||||
|
||||
304
browser_tests/tests/sidebar/modelManagerSidebarTab.spec.ts
Normal file
304
browser_tests/tests/sidebar/modelManagerSidebarTab.spec.ts
Normal file
@@ -0,0 +1,304 @@
|
||||
import { expect, mergeTests } from '@playwright/test'
|
||||
import type { Page } from '@playwright/test'
|
||||
|
||||
import type { ComfyPage } from '@e2e/fixtures/ComfyPage'
|
||||
import { comfyPageFixture } from '@e2e/fixtures/ComfyPage'
|
||||
import { ConfirmDialog } from '@e2e/fixtures/components/ConfirmDialog'
|
||||
import { jsonRoute } from '@e2e/fixtures/utils/jsonRoute'
|
||||
import { webSocketFixture } from '@e2e/fixtures/ws'
|
||||
|
||||
import type {
|
||||
DownloadStatus,
|
||||
EnqueueResponse,
|
||||
HostCredentialUpsert,
|
||||
HostCredentialView
|
||||
} from '@/platform/modelManager/types'
|
||||
|
||||
const test = mergeTests(comfyPageFixture, webSocketFixture)
|
||||
|
||||
const DOWNLOADS_ROUTE = /\/api\/download$/
|
||||
const ENQUEUE_ROUTE = /\/api\/download\/enqueue$/
|
||||
const CREDENTIALS_ROUTE = /\/api\/download\/credentials$/
|
||||
const CREDENTIAL_ROUTE = /\/api\/download\/credentials\/([^/?]+)$/
|
||||
const CLEAR_ROUTE = /\/api\/download\/clear$/
|
||||
|
||||
const DOWNLOAD_ID = 'e2e-download-1'
|
||||
const MODEL_URL =
|
||||
'https://huggingface.co/e2e/test/resolve/main/model.safetensors'
|
||||
const MODEL_ID = 'checkpoints/e2e-test-model.safetensors'
|
||||
|
||||
function makeDownloadStatus(
|
||||
overrides: Partial<DownloadStatus> = {}
|
||||
): DownloadStatus {
|
||||
const now = Math.floor(Date.now() / 1000)
|
||||
return {
|
||||
download_id: DOWNLOAD_ID,
|
||||
model_id: MODEL_ID,
|
||||
url: MODEL_URL,
|
||||
status: 'queued',
|
||||
priority: 0,
|
||||
total_bytes: null,
|
||||
bytes_done: 0,
|
||||
progress: null,
|
||||
speed_bps: null,
|
||||
eta_seconds: null,
|
||||
segments: null,
|
||||
error: null,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
async function enableServerSideModelDownloads(comfyPage: ComfyPage) {
|
||||
await comfyPage.page.evaluate(() => {
|
||||
window.app!.api.serverFeatureFlags.value = {
|
||||
...window.app!.api.serverFeatureFlags.value,
|
||||
server_side_model_downloads: true
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/** Keeps GET /download consistent so the 5s stale-poll can't wipe test state. */
|
||||
async function mockDownloadsList(
|
||||
page: Page,
|
||||
getDownloads: () => DownloadStatus[]
|
||||
) {
|
||||
await page.route(DOWNLOADS_ROUTE, async (route) => {
|
||||
if (route.request().method() !== 'GET') return route.fallback()
|
||||
await route.fulfill(jsonRoute({ downloads: getDownloads() }))
|
||||
})
|
||||
}
|
||||
|
||||
function modelDownloaderTabButton(comfyPage: ComfyPage) {
|
||||
return comfyPage.page.locator('.model-manager-tab-button')
|
||||
}
|
||||
|
||||
function modelDownloaderBadge(comfyPage: ComfyPage) {
|
||||
return modelDownloaderTabButton(comfyPage).locator('.sidebar-icon-badge')
|
||||
}
|
||||
|
||||
function modelDownloaderPanel(comfyPage: ComfyPage) {
|
||||
return comfyPage.page.locator('.sidebar-content-container')
|
||||
}
|
||||
|
||||
async function openModelDownloaderTab(comfyPage: ComfyPage) {
|
||||
await modelDownloaderTabButton(comfyPage).click()
|
||||
const panel = modelDownloaderPanel(comfyPage)
|
||||
await expect(panel.getByText('Downloads', { exact: true })).toBeVisible()
|
||||
// Toolbar buttons are only interactive while the tab panel is hovered
|
||||
// (`group-hover/sidebar-tab`), so keep the cursor over it for later clicks.
|
||||
await panel.hover()
|
||||
}
|
||||
|
||||
test.describe('Model Downloader sidebar', { tag: '@ui' }, () => {
|
||||
test.beforeEach(async ({ comfyPage }) => {
|
||||
await comfyPage.modelLibrary.mockFoldersWithFiles({ checkpoints: [] })
|
||||
// Isolate from any real downloads already tracked by the backend.
|
||||
await mockDownloadsList(comfyPage.page, () => [])
|
||||
await enableServerSideModelDownloads(comfyPage)
|
||||
})
|
||||
|
||||
test('adds a model by URL and reflects live progress over the websocket', async ({
|
||||
comfyPage,
|
||||
getWebSocket
|
||||
}) => {
|
||||
let downloads: DownloadStatus[] = []
|
||||
await mockDownloadsList(comfyPage.page, () => downloads)
|
||||
await comfyPage.page.route(ENQUEUE_ROUTE, async (route) => {
|
||||
downloads = [makeDownloadStatus()]
|
||||
const response: EnqueueResponse = {
|
||||
download_id: DOWNLOAD_ID,
|
||||
accepted: true
|
||||
}
|
||||
await route.fulfill({
|
||||
status: 202,
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify(response)
|
||||
})
|
||||
})
|
||||
|
||||
await openModelDownloaderTab(comfyPage)
|
||||
const panel = modelDownloaderPanel(comfyPage)
|
||||
await panel.getByTitle('Add model').click()
|
||||
|
||||
const addDialog = comfyPage.page.getByRole('dialog')
|
||||
await expect(addDialog.getByText('Add model')).toBeVisible()
|
||||
|
||||
await addDialog.getByLabel('URL').fill(MODEL_URL)
|
||||
await expect(addDialog.getByLabel('Filename')).toHaveValue(
|
||||
'model.safetensors'
|
||||
)
|
||||
await addDialog.getByLabel('Filename').fill('e2e-test-model.safetensors')
|
||||
await addDialog.getByRole('combobox', { name: 'Select a folder' }).click()
|
||||
await comfyPage.page
|
||||
.getByRole('option', { name: 'Checkpoints', exact: true })
|
||||
.click()
|
||||
|
||||
await addDialog
|
||||
.getByRole('button', { name: 'Download', exact: true })
|
||||
.click()
|
||||
await expect(addDialog).toBeHidden()
|
||||
|
||||
await expect(panel.getByText('e2e-test-model.safetensors')).toBeVisible()
|
||||
await expect(panel.getByText('Queued', { exact: true })).toBeVisible()
|
||||
await expect(modelDownloaderBadge(comfyPage)).toHaveText('1')
|
||||
|
||||
const ws = await getWebSocket()
|
||||
function sendProgress(overrides: Partial<DownloadStatus>) {
|
||||
const payload = makeDownloadStatus(overrides)
|
||||
downloads = [payload]
|
||||
ws.send(JSON.stringify({ type: 'download_progress', data: payload }))
|
||||
}
|
||||
|
||||
sendProgress({
|
||||
status: 'active',
|
||||
progress: 0.4,
|
||||
bytes_done: 400,
|
||||
total_bytes: 1000,
|
||||
speed_bps: 500_000
|
||||
})
|
||||
|
||||
await expect(panel.getByText('Downloading', { exact: true })).toBeVisible()
|
||||
await expect(panel.getByText(/40%/)).toBeVisible()
|
||||
await expect(modelDownloaderBadge(comfyPage)).toHaveText('1')
|
||||
|
||||
sendProgress({
|
||||
status: 'completed',
|
||||
progress: 1,
|
||||
bytes_done: 1000,
|
||||
total_bytes: 1000,
|
||||
speed_bps: null,
|
||||
eta_seconds: null
|
||||
})
|
||||
|
||||
await expect(panel.getByText('Completed', { exact: true })).toBeVisible()
|
||||
await expect(panel.getByText('History', { exact: true })).toBeVisible()
|
||||
await expect(modelDownloaderBadge(comfyPage)).toHaveCount(0)
|
||||
})
|
||||
|
||||
test('clears history and the cleared rows stay gone after reopening the tab', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
let downloads: DownloadStatus[] = [
|
||||
makeDownloadStatus({ status: 'completed', progress: 1, bytes_done: 1000 })
|
||||
]
|
||||
await mockDownloadsList(comfyPage.page, () => downloads)
|
||||
await comfyPage.page.route(CLEAR_ROUTE, async (route) => {
|
||||
const count = downloads.length
|
||||
downloads = []
|
||||
await route.fulfill(jsonRoute({ deleted: count }))
|
||||
})
|
||||
|
||||
await openModelDownloaderTab(comfyPage)
|
||||
const panel = modelDownloaderPanel(comfyPage)
|
||||
await expect(panel.getByText('History', { exact: true })).toBeVisible()
|
||||
await expect(panel.getByText('e2e-test-model.safetensors')).toBeVisible()
|
||||
|
||||
await panel.getByRole('button', { name: 'Clear history' }).click()
|
||||
await expect(panel.getByText('No downloads yet')).toBeVisible()
|
||||
|
||||
// Reopening the tab re-runs hydrate() -> GET /api/download. The bug was
|
||||
// that the cleared row reappeared here; the persisted delete must prevent
|
||||
// the backend list from returning it again.
|
||||
await modelDownloaderTabButton(comfyPage).click()
|
||||
await openModelDownloaderTab(comfyPage)
|
||||
await expect(
|
||||
modelDownloaderPanel(comfyPage).getByText('No downloads yet')
|
||||
).toBeVisible()
|
||||
await expect(
|
||||
modelDownloaderPanel(comfyPage).getByText('e2e-test-model.safetensors')
|
||||
).toBeHidden()
|
||||
})
|
||||
|
||||
test('manages download credentials: add, edit, and delete', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
let credentials: HostCredentialView[] = []
|
||||
let nextId = 1
|
||||
|
||||
await comfyPage.page.route(CREDENTIALS_ROUTE, async (route) => {
|
||||
const method = route.request().method()
|
||||
if (method === 'GET') {
|
||||
await route.fulfill(jsonRoute({ credentials }))
|
||||
return
|
||||
}
|
||||
if (method === 'POST') {
|
||||
const body = route.request().postDataJSON() as HostCredentialUpsert
|
||||
const existing = credentials.find((c) => c.host === body.host)
|
||||
const now = Math.floor(Date.now() / 1000)
|
||||
const view: HostCredentialView = {
|
||||
id: existing?.id ?? `cred-${nextId++}`,
|
||||
host: body.host,
|
||||
auth_scheme: body.auth_scheme ?? 'bearer',
|
||||
header_name: body.header_name ?? null,
|
||||
query_param: body.query_param ?? null,
|
||||
label: body.label ?? null,
|
||||
match_subdomains: body.match_subdomains ?? false,
|
||||
enabled: body.enabled ?? true,
|
||||
secret_last4: body.secret.slice(-4),
|
||||
created_at: existing?.created_at ?? now,
|
||||
updated_at: now
|
||||
}
|
||||
credentials = existing
|
||||
? credentials.map((c) => (c.id === view.id ? view : c))
|
||||
: [...credentials, view]
|
||||
await route.fulfill(jsonRoute(view))
|
||||
return
|
||||
}
|
||||
await route.fallback()
|
||||
})
|
||||
await comfyPage.page.route(CREDENTIAL_ROUTE, async (route) => {
|
||||
if (route.request().method() !== 'DELETE') return route.fallback()
|
||||
const id = route.request().url().match(CREDENTIAL_ROUTE)?.[1]
|
||||
credentials = credentials.filter((c) => c.id !== id)
|
||||
await route.fulfill(jsonRoute({ deleted: true }))
|
||||
})
|
||||
|
||||
const dialog = comfyPage.page
|
||||
.getByRole('dialog')
|
||||
.filter({ hasText: 'Credentials Manager' })
|
||||
|
||||
// The success toast shown after saving can momentarily steal focus and
|
||||
// close the dialog, so re-opening is retried until it sticks.
|
||||
async function ensureCredentialsDialogOpen() {
|
||||
await expect(async () => {
|
||||
const panel = modelDownloaderPanel(comfyPage)
|
||||
await panel.hover()
|
||||
await panel.getByTitle('Credentials Manager').click()
|
||||
await expect(dialog).toBeVisible({ timeout: 1000 })
|
||||
}).toPass({ timeout: 10_000 })
|
||||
}
|
||||
|
||||
await openModelDownloaderTab(comfyPage)
|
||||
await ensureCredentialsDialogOpen()
|
||||
|
||||
await dialog.getByLabel('Host').fill('huggingface.co')
|
||||
await dialog.getByLabel('API key').fill('secret-key-1234')
|
||||
await dialog.getByRole('button', { name: 'Save', exact: true }).click()
|
||||
|
||||
await ensureCredentialsDialogOpen()
|
||||
await expect(
|
||||
dialog.getByText('huggingface.co · Bearer token · ••••1234')
|
||||
).toBeVisible()
|
||||
|
||||
await dialog.getByTitle('Edit').click()
|
||||
await expect(dialog.getByText('Update credential')).toBeVisible()
|
||||
await expect(dialog.getByLabel('API key')).toHaveValue('')
|
||||
await dialog.getByLabel('Label').fill('My HF Key')
|
||||
await dialog.getByLabel('API key').fill('updated-secret-5678')
|
||||
await dialog.getByRole('button', { name: 'Save', exact: true }).click()
|
||||
|
||||
await ensureCredentialsDialogOpen()
|
||||
await expect(dialog.getByText('My HF Key')).toBeVisible()
|
||||
await expect(
|
||||
dialog.getByText('huggingface.co · Bearer token · ••••5678')
|
||||
).toBeVisible()
|
||||
|
||||
await dialog.getByTitle('Delete').click()
|
||||
const confirm = new ConfirmDialog(comfyPage.page)
|
||||
await confirm.click('delete')
|
||||
|
||||
await expect(dialog.getByText('My HF Key')).toBeHidden()
|
||||
})
|
||||
})
|
||||
@@ -4,74 +4,37 @@
|
||||
data-testid="bounding-boxes"
|
||||
@pointerdown.stop
|
||||
>
|
||||
<div class="flex flex-col">
|
||||
<div
|
||||
class="flex h-9 items-center gap-1 rounded-t-sm border border-b-0 border-component-node-border bg-component-node-widget-background px-2"
|
||||
>
|
||||
<Button
|
||||
v-tooltip.bottom="{ value: $t('boundingBoxes.grid'), showDelay: 300 }"
|
||||
variant="textonly"
|
||||
size="unset"
|
||||
:aria-pressed="grid"
|
||||
:aria-label="$t('boundingBoxes.grid')"
|
||||
:class="
|
||||
cn(
|
||||
actionBtnClass,
|
||||
grid && 'bg-component-node-widget-background-selected'
|
||||
)
|
||||
"
|
||||
@click="grid = !grid"
|
||||
>
|
||||
<i class="icon-[lucide--grid-3x3] size-4" />
|
||||
<span>{{ $t('boundingBoxes.grid') }}</span>
|
||||
</Button>
|
||||
<Button
|
||||
v-tooltip.bottom="{
|
||||
value: $t('boundingBoxes.clearAll'),
|
||||
showDelay: 300
|
||||
}"
|
||||
variant="textonly"
|
||||
size="unset"
|
||||
:aria-label="$t('boundingBoxes.clearAll')"
|
||||
:class="cn(actionBtnClass, 'ml-auto')"
|
||||
@click="clearAll"
|
||||
>
|
||||
<i class="icon-[lucide--undo-2] size-4" />
|
||||
<span>{{ $t('boundingBoxes.clearAll') }}</span>
|
||||
</Button>
|
||||
</div>
|
||||
<div
|
||||
ref="canvasContainer"
|
||||
class="relative w-full shrink-0 overflow-hidden rounded-b-sm border border-t-0 border-component-node-border bg-base-background"
|
||||
:style="canvasStyle"
|
||||
>
|
||||
<canvas
|
||||
ref="canvasEl"
|
||||
tabindex="0"
|
||||
class="absolute inset-0 size-full rounded-sm outline-none"
|
||||
:style="{ cursor: canvasCursor }"
|
||||
@pointerdown="onPointerDown"
|
||||
@pointermove="onCanvasPointerMove"
|
||||
@pointerup="onDocPointerUp"
|
||||
@pointercancel="onDocPointerUp"
|
||||
@pointerleave="onPointerLeave"
|
||||
@lostpointercapture="onDocPointerUp"
|
||||
@dblclick="onDoubleClick"
|
||||
@keydown="onCanvasKeyDown"
|
||||
@focus="focused = true"
|
||||
@blur="focused = false"
|
||||
/>
|
||||
<textarea
|
||||
v-if="inlineEditor"
|
||||
ref="inlineEditorEl"
|
||||
v-model="inlineEditor.value"
|
||||
class="absolute box-border resize-none rounded-sm border-2 bg-black/90 p-1 font-mono text-xs text-white outline-none"
|
||||
:style="inlineEditor.style"
|
||||
data-capture-wheel="true"
|
||||
@keydown.stop="onInlineKeyDown"
|
||||
@blur="commitInlineEditor"
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
ref="canvasContainer"
|
||||
class="relative w-full shrink-0 overflow-hidden rounded-sm border border-component-node-border bg-node-component-surface"
|
||||
:style="canvasStyle"
|
||||
>
|
||||
<canvas
|
||||
ref="canvasEl"
|
||||
tabindex="0"
|
||||
class="absolute inset-0 size-full rounded-sm outline-none"
|
||||
:style="{ cursor: canvasCursor }"
|
||||
@pointerdown="onPointerDown"
|
||||
@pointermove="onCanvasPointerMove"
|
||||
@pointerup="onDocPointerUp"
|
||||
@pointercancel="onDocPointerUp"
|
||||
@pointerleave="onPointerLeave"
|
||||
@lostpointercapture="onDocPointerUp"
|
||||
@dblclick="onDoubleClick"
|
||||
@keydown="onCanvasKeyDown"
|
||||
@focus="focused = true"
|
||||
@blur="focused = false"
|
||||
/>
|
||||
<textarea
|
||||
v-if="inlineEditor"
|
||||
ref="inlineEditorEl"
|
||||
v-model="inlineEditor.value"
|
||||
class="absolute box-border resize-none rounded-sm border-2 bg-black/90 p-1 font-mono text-xs text-white outline-none"
|
||||
:style="inlineEditor.style"
|
||||
data-capture-wheel="true"
|
||||
@keydown.stop="onInlineKeyDown"
|
||||
@blur="commitInlineEditor"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div
|
||||
@@ -159,6 +122,16 @@
|
||||
<div v-else-if="hasRegions" class="text-node-text-muted px-1 text-xs">
|
||||
{{ $t('boundingBoxes.clickRegionToEdit') }}
|
||||
</div>
|
||||
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="md"
|
||||
class="gap-2 rounded-lg border border-component-node-border bg-component-node-background text-xs text-muted-foreground hover:text-base-foreground"
|
||||
@click="clearAll"
|
||||
>
|
||||
<i class="icon-[lucide--undo-2]" />
|
||||
{{ $t('boundingBoxes.clearAll') }}
|
||||
</Button>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -174,9 +147,6 @@ import { useBoundingBoxes } from '@/composables/boundingBoxes/useBoundingBoxes'
|
||||
import type { BoundingBox } from '@/types/boundingBoxes'
|
||||
import type { NodeId } from '@/types/nodeId'
|
||||
|
||||
const actionBtnClass =
|
||||
'flex shrink-0 items-center gap-1.5 rounded-md border-0 bg-transparent px-2 py-1 text-sm text-base-foreground outline-none transition-colors hover:bg-component-node-widget-background-hovered'
|
||||
|
||||
const { nodeId } = defineProps<{ nodeId: NodeId }>()
|
||||
const modelValue = defineModel<BoundingBox[]>({ default: () => [] })
|
||||
|
||||
@@ -202,8 +172,7 @@ const {
|
||||
commitInlineEditor,
|
||||
setActiveType,
|
||||
clearAll,
|
||||
syncState,
|
||||
grid
|
||||
syncState
|
||||
} = useBoundingBoxes(nodeId, {
|
||||
canvasEl,
|
||||
canvasContainer,
|
||||
|
||||
@@ -243,7 +243,7 @@ onMounted(() => {
|
||||
--sidebar-padding: 4px;
|
||||
--sidebar-icon-size: 1rem;
|
||||
|
||||
--sidebar-default-floating-width: 48px;
|
||||
--sidebar-default-floating-width: 50px;
|
||||
--sidebar-default-connected-width: calc(
|
||||
var(--sidebar-default-floating-width) + var(--sidebar-padding) * 2
|
||||
);
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
<script setup lang="ts">
|
||||
import type { DialogContentEmits, DialogContentProps } from 'reka-ui'
|
||||
import { DialogContent, useForwardPropsEmits } from 'reka-ui'
|
||||
import {
|
||||
DialogContent,
|
||||
injectDialogRootContext,
|
||||
useForwardPropsEmits
|
||||
} from 'reka-ui'
|
||||
import type { HTMLAttributes } from 'vue'
|
||||
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
|
||||
import type { DialogContentSize } from './dialog.variants'
|
||||
import { dialogContentVariants } from './dialog.variants'
|
||||
import { useModalPointerLock } from './useModalPointerLock'
|
||||
|
||||
const {
|
||||
size,
|
||||
@@ -23,6 +28,11 @@ const {
|
||||
|
||||
const emits = defineEmits<DialogContentEmits>()
|
||||
const forwarded = useForwardPropsEmits(restProps, emits)
|
||||
|
||||
const dialogRootContext = injectDialogRootContext(null)
|
||||
if (dialogRootContext?.modal.value) {
|
||||
useModalPointerLock(() => dialogRootContext.open.value)
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
|
||||
114
src/components/ui/dialog/useModalPointerLock.test.ts
Normal file
114
src/components/ui/dialog/useModalPointerLock.test.ts
Normal file
@@ -0,0 +1,114 @@
|
||||
import { render } from '@testing-library/vue'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import {
|
||||
SelectContent,
|
||||
SelectPortal,
|
||||
SelectRoot,
|
||||
SelectTrigger,
|
||||
SelectViewport
|
||||
} from 'reka-ui'
|
||||
import { defineComponent, h, nextTick, ref } from 'vue'
|
||||
|
||||
import Dialog from './Dialog.vue'
|
||||
import DialogContent from './DialogContent.vue'
|
||||
import DialogPortal from './DialogPortal.vue'
|
||||
|
||||
async function flush() {
|
||||
await nextTick()
|
||||
await nextTick()
|
||||
await new Promise((resolve) => setTimeout(resolve, 0))
|
||||
await nextTick()
|
||||
}
|
||||
|
||||
function mountDialogWithSelect(modal: boolean) {
|
||||
const dialogOpen = ref(true)
|
||||
const selectOpen = ref(false)
|
||||
|
||||
const Parent = defineComponent({
|
||||
setup() {
|
||||
return () =>
|
||||
h(
|
||||
Dialog,
|
||||
{
|
||||
modal,
|
||||
open: dialogOpen.value,
|
||||
'onUpdate:open': (value: boolean) => (dialogOpen.value = value)
|
||||
},
|
||||
() =>
|
||||
h(DialogPortal, null, () =>
|
||||
h(DialogContent, null, () =>
|
||||
h(
|
||||
SelectRoot,
|
||||
{
|
||||
open: selectOpen.value,
|
||||
'onUpdate:open': (value: boolean) =>
|
||||
(selectOpen.value = value)
|
||||
},
|
||||
() => [
|
||||
h(SelectTrigger, null, () => 'trigger'),
|
||||
h(SelectPortal, null, () =>
|
||||
h(SelectContent, { position: 'popper' }, () =>
|
||||
h(SelectViewport, null, () => 'items')
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
return { ...render(Parent), dialogOpen, selectOpen }
|
||||
}
|
||||
|
||||
describe('modal dialog pointer lock', () => {
|
||||
it('keeps body inert after a nested combobox popover opens and closes', async () => {
|
||||
const { selectOpen } = mountDialogWithSelect(true)
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('none')
|
||||
|
||||
selectOpen.value = true
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('none')
|
||||
|
||||
// Reka restores body pointer events when the popover layer unmounts; the
|
||||
// lock must re-assert it so the canvas behind the dialog stays inert.
|
||||
selectOpen.value = false
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('none')
|
||||
})
|
||||
|
||||
it('restores body pointer events once the dialog closes', async () => {
|
||||
const { dialogOpen } = mountDialogWithSelect(true)
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('none')
|
||||
|
||||
dialogOpen.value = false
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('')
|
||||
})
|
||||
|
||||
it('does not lock body for a non-modal dialog', async () => {
|
||||
mountDialogWithSelect(false)
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('')
|
||||
})
|
||||
|
||||
it('holds the body lock until every open modal dialog closes', async () => {
|
||||
const first = mountDialogWithSelect(true)
|
||||
const second = mountDialogWithSelect(true)
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('none')
|
||||
|
||||
// With one modal still open the shared lock must keep the body inert.
|
||||
first.dialogOpen.value = false
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).toBe('none')
|
||||
|
||||
// Once the last modal closes the lock releases and stops forcing inert.
|
||||
second.dialogOpen.value = false
|
||||
await flush()
|
||||
expect(document.body.style.pointerEvents).not.toBe('none')
|
||||
})
|
||||
})
|
||||
54
src/components/ui/dialog/useModalPointerLock.ts
Normal file
54
src/components/ui/dialog/useModalPointerLock.ts
Normal file
@@ -0,0 +1,54 @@
|
||||
import { onScopeDispose, toValue, watch } from 'vue'
|
||||
import type { MaybeRefOrGetter } from 'vue'
|
||||
|
||||
/**
|
||||
* Keeps the canvas behind a modal dialog inert by holding `document.body`'s
|
||||
* pointer-events lock for as long as at least one modal dialog is open.
|
||||
*
|
||||
* Reka-UI locks body pointer events per modal layer, but a nested dismissable
|
||||
* layer that is portalled to the body — e.g. a `Select` popover inside the
|
||||
* dialog — restores the body's pointer events when it closes, even while the
|
||||
* outer modal dialog is still open. That momentarily re-enables the canvas, so
|
||||
* combobox clicks leak through to it and can select a node or dismiss the
|
||||
* dialog. Reka still performs the initial lock and final restore; the
|
||||
* `MutationObserver` only re-asserts `none` if the lock is cleared while a
|
||||
* modal dialog is still open.
|
||||
*/
|
||||
let openModalCount = 0
|
||||
let observer: MutationObserver | null = null
|
||||
|
||||
function enforceLock() {
|
||||
if (openModalCount > 0 && document.body.style.pointerEvents !== 'none') {
|
||||
document.body.style.pointerEvents = 'none'
|
||||
}
|
||||
}
|
||||
|
||||
function acquire() {
|
||||
openModalCount += 1
|
||||
if (observer) return
|
||||
observer = new MutationObserver(enforceLock)
|
||||
observer.observe(document.body, {
|
||||
attributes: true,
|
||||
attributeFilter: ['style']
|
||||
})
|
||||
}
|
||||
|
||||
function release() {
|
||||
openModalCount = Math.max(0, openModalCount - 1)
|
||||
if (openModalCount > 0) return
|
||||
observer?.disconnect()
|
||||
observer = null
|
||||
document.body.style.pointerEvents = ''
|
||||
}
|
||||
|
||||
export function useModalPointerLock(isOpen: MaybeRefOrGetter<boolean>) {
|
||||
let holding = false
|
||||
const sync = (open: boolean) => {
|
||||
if (open === holding) return
|
||||
holding = open
|
||||
if (open) acquire()
|
||||
else release()
|
||||
}
|
||||
watch(() => toValue(isOpen), sync, { immediate: true })
|
||||
onScopeDispose(() => sync(false))
|
||||
}
|
||||
@@ -8,32 +8,14 @@ import { useBoundingBoxes } from './useBoundingBoxes'
|
||||
import type { BoundingBox } from '@/types/boundingBoxes'
|
||||
import { toNodeId } from '@/types/nodeId'
|
||||
|
||||
const { appState, outputState } = vi.hoisted(() => ({
|
||||
appState: { node: null as unknown },
|
||||
outputState: {
|
||||
outputs: undefined as unknown,
|
||||
nodeOutputs: null as { value: Record<string, unknown> } | null
|
||||
}
|
||||
const { appState } = vi.hoisted(() => ({
|
||||
appState: { node: null as unknown }
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: { canvas: { graph: { getNodeById: () => appState.node } } }
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/nodeOutputStore', async () => {
|
||||
const { ref } = await import('vue')
|
||||
const nodeOutputs = ref<Record<string, unknown>>({})
|
||||
outputState.nodeOutputs = nodeOutputs
|
||||
return {
|
||||
useNodeOutputStore: () => ({
|
||||
nodeOutputs,
|
||||
nodePreviewImages: ref({}),
|
||||
getNodeImageUrls: () => undefined,
|
||||
getNodeOutputs: () => outputState.outputs
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
const ctx = {
|
||||
measureText: (s: string) => ({ width: s.length * 7 }),
|
||||
setTransform: () => {},
|
||||
@@ -45,8 +27,6 @@ const ctx = {
|
||||
save: () => {},
|
||||
restore: () => {},
|
||||
beginPath: () => {},
|
||||
arc: () => {},
|
||||
fill: () => {},
|
||||
rect: () => {},
|
||||
clip: () => {},
|
||||
font: '',
|
||||
@@ -148,22 +128,9 @@ const box = (over: Partial<BoundingBox> = {}): BoundingBox => ({
|
||||
...over
|
||||
})
|
||||
|
||||
function makeConnectedNode() {
|
||||
return {
|
||||
widgets: [
|
||||
{ name: 'width', value: 512 },
|
||||
{ name: 'height', value: 512 }
|
||||
],
|
||||
findInputSlot: (name: string) => (name === 'bboxes' ? 1 : -1),
|
||||
getInputNode: () => null,
|
||||
isInputConnected: () => true
|
||||
}
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
appState.node = makeNode()
|
||||
outputState.outputs = undefined
|
||||
vi.stubGlobal('requestAnimationFrame', (cb: FrameRequestCallback) => {
|
||||
void Promise.resolve().then(() => cb(0))
|
||||
return 1
|
||||
@@ -272,68 +239,6 @@ describe('useBoundingBoxes inline editor', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('useBoundingBoxes incoming bboxes input', () => {
|
||||
it('overrides the canvas when the bboxes input is connected', () => {
|
||||
appState.node = makeConnectedNode()
|
||||
outputState.outputs = {
|
||||
input_bboxes: [box({ x: 0, y: 0, width: 100, height: 100 })]
|
||||
}
|
||||
const c = setup([])
|
||||
expect(c.modelValue.value).toHaveLength(1)
|
||||
expect(c.modelValue.value[0].width).toBe(100)
|
||||
})
|
||||
|
||||
it('replaces existing drawn boxes with the incoming ones', () => {
|
||||
appState.node = makeConnectedNode()
|
||||
outputState.outputs = { input_bboxes: [box({ x: 0, width: 100 })] }
|
||||
const c = setup([box({ x: 200, width: 300 }), box({ x: 400, width: 50 })])
|
||||
expect(c.modelValue.value).toHaveLength(1)
|
||||
expect(c.modelValue.value[0].width).toBe(100)
|
||||
})
|
||||
|
||||
it('ignores incoming output when the input is not connected', () => {
|
||||
outputState.outputs = { input_bboxes: [box({ x: 0, width: 100 })] }
|
||||
const c = setup([])
|
||||
expect(c.modelValue.value).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('applies incoming boxes when outputs stream in after mount', async () => {
|
||||
appState.node = makeConnectedNode()
|
||||
const c = setup([])
|
||||
expect(c.modelValue.value).toHaveLength(0)
|
||||
|
||||
outputState.outputs = { input_bboxes: [box({ x: 0, width: 100 })] }
|
||||
outputState.nodeOutputs!.value = { updated: true }
|
||||
await flush()
|
||||
|
||||
expect(c.modelValue.value).toHaveLength(1)
|
||||
expect(c.modelValue.value[0].width).toBe(100)
|
||||
})
|
||||
})
|
||||
|
||||
describe('useBoundingBoxes grid snapping', () => {
|
||||
it('snaps a drawn box to the grid when grid is enabled (default)', async () => {
|
||||
const c = setup()
|
||||
c.onPointerDown(pe(10, 10))
|
||||
c.onCanvasPointerMove(pe(60, 60))
|
||||
c.onDocPointerUp(pe(60, 60))
|
||||
await flush()
|
||||
expect(c.modelValue.value).toHaveLength(1)
|
||||
expect(c.modelValue.value[0].x).toBe(64)
|
||||
expect(c.modelValue.value[0].width).toBe(256)
|
||||
})
|
||||
|
||||
it('does not snap when grid is disabled', async () => {
|
||||
const c = setup()
|
||||
c.grid.value = false
|
||||
c.onPointerDown(pe(10, 10))
|
||||
c.onCanvasPointerMove(pe(55, 55))
|
||||
c.onDocPointerUp(pe(55, 55))
|
||||
await flush()
|
||||
expect(c.modelValue.value[0].width).toBe(230)
|
||||
})
|
||||
})
|
||||
|
||||
describe('useBoundingBoxes hover cursor', () => {
|
||||
it('switches to a pointer cursor over a tag', async () => {
|
||||
const c = setup([box({ x: 10, y: 10, width: 256, height: 256 })])
|
||||
|
||||
@@ -15,7 +15,6 @@ import type {
|
||||
Region
|
||||
} from '@/composables/boundingBoxes/boundingBoxesUtil'
|
||||
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
|
||||
import type { NodeOutputWith } from '@/schemas/apiSchema'
|
||||
import { app } from '@/scripts/app'
|
||||
import { useNodeOutputStore } from '@/stores/nodeOutputStore'
|
||||
import type { BoundingBox } from '@/types/boundingBoxes'
|
||||
@@ -26,10 +25,6 @@ const HANDLE_PX = 8
|
||||
const DIMENSION_STEP = 16
|
||||
const BG_DIM = 0.75
|
||||
const MAX_ELEMENT_COLORS = 5
|
||||
const GRID_PX = 32
|
||||
const MAX_GRID_CELLS = 64
|
||||
const DOT_COLOR = 'rgba(255,255,255,0.18)'
|
||||
const DOT_RADIUS = 1
|
||||
|
||||
interface InlineEditorState {
|
||||
value: string
|
||||
@@ -62,7 +57,6 @@ export function useBoundingBoxes(
|
||||
const hoverTagIndex = ref<number | null>(null)
|
||||
const bgImage = ref<HTMLImageElement | null>(null)
|
||||
const inlineEditor = ref<InlineEditorState | null>(null)
|
||||
const grid = ref(true)
|
||||
|
||||
const { width: containerWidth } = useElementSize(canvasContainer)
|
||||
|
||||
@@ -102,63 +96,6 @@ export function useBoundingBoxes(
|
||||
return Math.max(0, Math.min(1, n))
|
||||
}
|
||||
|
||||
function gridSpec() {
|
||||
const stepX = Math.max(
|
||||
GRID_PX,
|
||||
Math.ceil(widthValue.value / MAX_GRID_CELLS)
|
||||
)
|
||||
const stepY = Math.max(
|
||||
GRID_PX,
|
||||
Math.ceil(heightValue.value / MAX_GRID_CELLS)
|
||||
)
|
||||
return {
|
||||
fx: stepX / (widthValue.value || 1),
|
||||
fy: stepY / (heightValue.value || 1)
|
||||
}
|
||||
}
|
||||
|
||||
function snapFraction(value: number, step: number) {
|
||||
return step > 0 ? clampToCanvas(Math.round(value / step) * step) : value
|
||||
}
|
||||
|
||||
function snapRegion(region: Region, mode: HitMode): Region {
|
||||
if (!grid.value) return region
|
||||
const { fx, fy } = gridSpec()
|
||||
if (mode === 'move') {
|
||||
return {
|
||||
...region,
|
||||
x: Math.min(snapFraction(region.x, fx), 1 - region.w),
|
||||
y: Math.min(snapFraction(region.y, fy), 1 - region.h)
|
||||
}
|
||||
}
|
||||
const x1 = snapFraction(region.x, fx)
|
||||
const y1 = snapFraction(region.y, fy)
|
||||
const x2 = snapFraction(region.x + region.w, fx)
|
||||
const y2 = snapFraction(region.y + region.h, fy)
|
||||
return {
|
||||
...region,
|
||||
x: x1,
|
||||
y: y1,
|
||||
w: Math.max(0, x2 - x1),
|
||||
h: Math.max(0, y2 - y1)
|
||||
}
|
||||
}
|
||||
|
||||
function drawDots(ctx: CanvasRenderingContext2D, W: number, H: number) {
|
||||
const { fx, fy } = gridSpec()
|
||||
if (fx <= 0 || fy <= 0) return
|
||||
ctx.save()
|
||||
ctx.fillStyle = DOT_COLOR
|
||||
for (let gx = 0; gx <= 1.0001; gx += fx) {
|
||||
for (let gy = 0; gy <= 1.0001; gy += fy) {
|
||||
ctx.beginPath()
|
||||
ctx.arc(gx * W, gy * H, DOT_RADIUS, 0, Math.PI * 2)
|
||||
ctx.fill()
|
||||
}
|
||||
}
|
||||
ctx.restore()
|
||||
}
|
||||
|
||||
function logicalSize() {
|
||||
const el = canvasEl.value
|
||||
return { w: el?.clientWidth || 1, h: el?.clientHeight || 1 }
|
||||
@@ -209,8 +146,6 @@ export function useBoundingBoxes(
|
||||
ctx.fillRect(0, 0, W, H)
|
||||
}
|
||||
|
||||
if (grid.value) drawDots(ctx, W, H)
|
||||
|
||||
const showActive = focused.value || isNodeSelected.value
|
||||
const aIdx = showActive ? activeIndex.value : -1
|
||||
const order = state.value.regions
|
||||
@@ -431,7 +366,7 @@ export function useBoundingBoxes(
|
||||
const dx = mN.x - dragStartNorm.value.x
|
||||
const dy = mN.y - dragStartNorm.value.y
|
||||
const nb = applyDrag(dragMode.value, boxAtStart.value, dx, dy)
|
||||
state.value.regions[activeIndex.value] = snapRegion(nb, dragMode.value)
|
||||
state.value.regions[activeIndex.value] = nb
|
||||
requestDraw()
|
||||
}
|
||||
|
||||
@@ -595,23 +530,6 @@ export function useBoundingBoxes(
|
||||
watch(isNodeSelected, () => requestDraw())
|
||||
watch([widthValue, heightValue], () => syncState())
|
||||
|
||||
watch(
|
||||
litegraphNode,
|
||||
(node) => {
|
||||
const props = node?.properties as { bboxGrid?: unknown } | undefined
|
||||
if (props && typeof props.bboxGrid === 'boolean')
|
||||
grid.value = props.bboxGrid
|
||||
},
|
||||
{ immediate: true }
|
||||
)
|
||||
watch(grid, (enabled) => {
|
||||
const props = litegraphNode.value?.properties as
|
||||
| Record<string, unknown>
|
||||
| undefined
|
||||
if (props) props.bboxGrid = enabled
|
||||
requestDraw()
|
||||
})
|
||||
|
||||
const nodeOutputStore = useNodeOutputStore()
|
||||
function applyImageDimensions(naturalWidth: number, naturalHeight: number) {
|
||||
const node = litegraphNode.value
|
||||
@@ -662,38 +580,10 @@ export function useBoundingBoxes(
|
||||
}
|
||||
img.src = url
|
||||
}
|
||||
let lastIncoming = ''
|
||||
function applyIncomingBoxes() {
|
||||
const node = litegraphNode.value
|
||||
if (!node) return
|
||||
const slot = node.findInputSlot('bboxes')
|
||||
if (slot < 0 || !node.isInputConnected(slot)) {
|
||||
lastIncoming = ''
|
||||
return
|
||||
}
|
||||
const outputs = nodeOutputStore.getNodeOutputs(node) as
|
||||
| NodeOutputWith<{ input_bboxes?: BoundingBox[] }>
|
||||
| undefined
|
||||
const incoming = outputs?.input_bboxes
|
||||
if (!incoming?.length) return
|
||||
const key = JSON.stringify(incoming)
|
||||
if (key === lastIncoming) return
|
||||
lastIncoming = key
|
||||
state.value.regions = fromBoundingBoxes(
|
||||
incoming,
|
||||
widthValue.value,
|
||||
heightValue.value
|
||||
)
|
||||
activeIndex.value = state.value.regions.length ? 0 : -1
|
||||
syncState()
|
||||
}
|
||||
|
||||
watch(() => nodeOutputStore.nodeOutputs, updateBgImage, { deep: true })
|
||||
watch(() => nodeOutputStore.nodePreviewImages, updateBgImage, { deep: true })
|
||||
watch(() => nodeOutputStore.nodeOutputs, applyIncomingBoxes, { deep: true })
|
||||
|
||||
updateBgImage()
|
||||
applyIncomingBoxes()
|
||||
void nextTick(() => requestDraw())
|
||||
|
||||
onBeforeUnmount(() => {
|
||||
@@ -718,7 +608,6 @@ export function useBoundingBoxes(
|
||||
commitInlineEditor,
|
||||
setActiveType,
|
||||
clearAll,
|
||||
syncState,
|
||||
grid
|
||||
syncState
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,7 +30,8 @@ export enum ServerFeatureFlag {
|
||||
COMFYHUB_PROFILE_GATE_ENABLED = 'comfyhub_profile_gate_enabled',
|
||||
SHOW_SIGNIN_BUTTON = 'show_signin_button',
|
||||
UNIFIED_CLOUD_AUTH = 'unified_cloud_auth',
|
||||
SIGNUP_TURNSTILE = 'signup_turnstile'
|
||||
SIGNUP_TURNSTILE = 'signup_turnstile',
|
||||
SERVER_SIDE_MODEL_DOWNLOADS = 'server_side_model_downloads'
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -181,6 +182,12 @@ export function useFeatureFlags() {
|
||||
remoteConfig.value.signup_turnstile,
|
||||
'off'
|
||||
)
|
||||
},
|
||||
get serverSideModelDownloads() {
|
||||
return api.getServerFeature(
|
||||
ServerFeatureFlag.SERVER_SIDE_MODEL_DOWNLOADS,
|
||||
false
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -1069,6 +1069,74 @@
|
||||
"update": "Update",
|
||||
"description": "Check out the latest improvements and features in this update."
|
||||
},
|
||||
"modelManager": {
|
||||
"title": "Downloads",
|
||||
"active": "Active",
|
||||
"history": "History",
|
||||
"empty": "No downloads yet",
|
||||
"clearHistory": "Clear history",
|
||||
"addModel": "Add model",
|
||||
"addModelDescription": "Download a model directly to your ComfyUI models folder.",
|
||||
"url": "URL",
|
||||
"urlPlaceholder": "https://huggingface.co/.../model.safetensors",
|
||||
"hostNotAllowedHint": "This host may not be on the download allowlist.",
|
||||
"folder": "Model folder",
|
||||
"selectFolder": "Select a folder",
|
||||
"filename": "Filename",
|
||||
"filenamePlaceholder": "model.safetensors",
|
||||
"allowAnyExtension": "Allow any file extension (advanced)",
|
||||
"download": "Download",
|
||||
"downloadQueued": "Download queued",
|
||||
"alreadyInstalled": "Model already installed",
|
||||
"alreadyDownloading": "Model is already downloading",
|
||||
"resume": "Resume",
|
||||
"raisePriority": "Raise priority",
|
||||
"removeFromList": "Remove from list",
|
||||
"addCredentials": "Add credentials",
|
||||
"authErrorHint": "{host} needs an API key. Add one in the Credentials Manager, then resume.",
|
||||
"authErrorHintNoHost": "This host/model needs an API key. Add one in the Credentials Manager, then resume.",
|
||||
"gatedModelHint": "This model is gated. Accept its license on the model's page, then add an API key and resume.",
|
||||
"openModelPage": "Accept license",
|
||||
"actionFailed": "Action failed",
|
||||
"cancelConfirmTitle": "Cancel download?",
|
||||
"cancelConfirmMessage": "This will stop the download and delete the partial file for \"{name}\".",
|
||||
"cancelConfirm": "Cancel download",
|
||||
"status": {
|
||||
"queued": "Queued",
|
||||
"active": "Downloading",
|
||||
"paused": "Paused",
|
||||
"verifying": "Verifying",
|
||||
"completed": "Completed",
|
||||
"failed": "Failed",
|
||||
"cancelled": "Cancelled"
|
||||
},
|
||||
"credentials": {
|
||||
"title": "Credentials Manager",
|
||||
"description": "Store API keys per host (e.g. HuggingFace, Civitai). Secrets are write-only and never shown again.",
|
||||
"add": "Add credential",
|
||||
"update": "Update credential",
|
||||
"edit": "Edit",
|
||||
"save": "Save",
|
||||
"saved": "Credential saved",
|
||||
"host": "Host",
|
||||
"secret": "API key",
|
||||
"secretPlaceholder": "Paste API key",
|
||||
"authScheme": "Auth scheme",
|
||||
"headerName": "Header name",
|
||||
"queryParam": "Query parameter",
|
||||
"label": "Label",
|
||||
"disabled": "Disabled",
|
||||
"matchSubdomains": "Match subdomains",
|
||||
"matchSubdomainsWarning": "Not recommended: hubs redirect to sibling CDN hosts that must not receive your key.",
|
||||
"deleteTitle": "Delete credential?",
|
||||
"deleteMessage": "Delete the stored credential for \"{host}\"?",
|
||||
"scheme": {
|
||||
"bearer": "Bearer token",
|
||||
"header": "Custom header",
|
||||
"query": "Query parameter"
|
||||
}
|
||||
}
|
||||
},
|
||||
"menu": {
|
||||
"hideMenu": "Hide Menu",
|
||||
"showMenu": "Show Menu",
|
||||
@@ -2166,8 +2234,7 @@
|
||||
"descLabel": "description",
|
||||
"textPlaceholder": "text to render (verbatim)",
|
||||
"descPlaceholder": "description of this region",
|
||||
"colors": "color_palette",
|
||||
"grid": "grid"
|
||||
"colors": "color_palette"
|
||||
},
|
||||
"palette": {
|
||||
"addColor": "Add a color",
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { DownloadApiError } from '@/platform/modelManager/types'
|
||||
|
||||
import {
|
||||
downloadModel,
|
||||
fetchModelMetadata,
|
||||
@@ -7,13 +9,23 @@ import {
|
||||
toBrowsableUrl
|
||||
} from './missingModelDownload'
|
||||
|
||||
const { fetchMock, mockIsDesktop, mockSidebarTabStore, mockStartDownload } =
|
||||
vi.hoisted(() => ({
|
||||
fetchMock: vi.fn(),
|
||||
mockIsDesktop: { value: false },
|
||||
mockSidebarTabStore: { activeSidebarTabId: null as string | null },
|
||||
mockStartDownload: vi.fn()
|
||||
}))
|
||||
const {
|
||||
fetchMock,
|
||||
mockIsDesktop,
|
||||
mockSidebarTabStore,
|
||||
mockStartDownload,
|
||||
mockEnqueue,
|
||||
mockToastAdd,
|
||||
mockFlags
|
||||
} = vi.hoisted(() => ({
|
||||
fetchMock: vi.fn(),
|
||||
mockIsDesktop: { value: false },
|
||||
mockSidebarTabStore: { activeSidebarTabId: null as string | null },
|
||||
mockStartDownload: vi.fn(),
|
||||
mockEnqueue: vi.fn(),
|
||||
mockToastAdd: vi.fn(),
|
||||
mockFlags: { serverSideModelDownloads: false }
|
||||
}))
|
||||
|
||||
vi.stubGlobal('fetch', fetchMock)
|
||||
|
||||
@@ -33,12 +45,38 @@ vi.mock('@/stores/workspace/sidebarTabStore', () => ({
|
||||
useSidebarTabStore: () => mockSidebarTabStore
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useFeatureFlags', () => ({
|
||||
useFeatureFlags: () => ({ flags: mockFlags })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/modelManager/stores/modelDownloadStore', () => ({
|
||||
useModelDownloadStore: () => ({ enqueue: mockEnqueue })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: () => ({ add: mockToastAdd })
|
||||
}))
|
||||
|
||||
const mockRefreshModelFolder = vi.fn()
|
||||
const mockRefreshMissingModels = vi.fn()
|
||||
|
||||
vi.mock('@/stores/modelStore', () => ({
|
||||
useModelStore: () => ({ refreshModelFolder: mockRefreshModelFolder })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/missingModel/missingModelStore', () => ({
|
||||
useMissingModelStore: () => ({
|
||||
refreshMissingModels: mockRefreshMissingModels
|
||||
})
|
||||
}))
|
||||
|
||||
let testId = 0
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
vi.resetAllMocks()
|
||||
delete window.__comfyDesktop2
|
||||
mockFlags.serverSideModelDownloads = false
|
||||
})
|
||||
|
||||
describe('fetchModelMetadata', () => {
|
||||
@@ -117,6 +155,26 @@ describe('fetchModelMetadata', () => {
|
||||
expect(fetchMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('returns null metadata when the Civitai request throws', async () => {
|
||||
fetchMock.mockRejectedValueOnce(new Error('network down'))
|
||||
|
||||
const metadata = await fetchModelMetadata(
|
||||
`https://civitai.com/api/download/models/${testId}`
|
||||
)
|
||||
|
||||
expect(metadata).toEqual({ fileSize: null, gatedRepoUrl: null })
|
||||
})
|
||||
|
||||
it('returns null metadata when the HEAD request throws', async () => {
|
||||
fetchMock.mockRejectedValueOnce(new Error('network down'))
|
||||
|
||||
const metadata = await fetchModelMetadata(
|
||||
`https://huggingface.co/org/model/resolve/main/throws-${testId}.safetensors`
|
||||
)
|
||||
|
||||
expect(metadata).toEqual({ fileSize: null, gatedRepoUrl: null })
|
||||
})
|
||||
|
||||
it('returns cached metadata on second call', async () => {
|
||||
const url = `https://huggingface.co/org/model/resolve/main/cached-${testId}.safetensors`
|
||||
|
||||
@@ -240,6 +298,16 @@ describe('isModelDownloadable', () => {
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('allows explicitly whitelisted URLs from an otherwise disallowed host', () => {
|
||||
expect(
|
||||
isModelDownloadable({
|
||||
name: 'RealESRGAN_x4plus.pth',
|
||||
url: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
|
||||
directory: 'upscale_models'
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('downloadModel', () => {
|
||||
@@ -379,6 +447,152 @@ describe('downloadModel', () => {
|
||||
expect(anchorClick).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('uses the browser fallback on web when server-side downloads are disabled', () => {
|
||||
const anchorClick = vi
|
||||
.spyOn(HTMLAnchorElement.prototype, 'click')
|
||||
.mockImplementation(() => {})
|
||||
|
||||
downloadModel(
|
||||
{
|
||||
name: 'model.safetensors',
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
directory: 'checkpoints'
|
||||
},
|
||||
{ checkpoints: ['/models/checkpoints'] }
|
||||
)
|
||||
|
||||
expect(anchorClick).toHaveBeenCalledTimes(1)
|
||||
expect(mockEnqueue).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('enqueues a server-side download and reveals the manager when enabled', async () => {
|
||||
mockFlags.serverSideModelDownloads = true
|
||||
mockEnqueue.mockResolvedValue({ download_id: 'd1', accepted: true })
|
||||
const anchorClick = vi
|
||||
.spyOn(HTMLAnchorElement.prototype, 'click')
|
||||
.mockImplementation(() => {})
|
||||
|
||||
downloadModel(
|
||||
{
|
||||
name: 'model.safetensors',
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
directory: 'checkpoints'
|
||||
},
|
||||
{ checkpoints: ['/models/checkpoints'] }
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockSidebarTabStore.activeSidebarTabId).toBe('model-manager')
|
||||
})
|
||||
expect(mockEnqueue).toHaveBeenCalledWith({
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
model_id: 'checkpoints/model.safetensors'
|
||||
})
|
||||
expect(anchorClick).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows a toast when a server-side enqueue fails', async () => {
|
||||
mockFlags.serverSideModelDownloads = true
|
||||
mockEnqueue.mockRejectedValue(new Error('boom'))
|
||||
|
||||
downloadModel(
|
||||
{
|
||||
name: 'model.safetensors',
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
directory: 'checkpoints'
|
||||
},
|
||||
{ checkpoints: ['/models/checkpoints'] }
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error', detail: 'boom' })
|
||||
)
|
||||
})
|
||||
expect(mockSidebarTabStore.activeSidebarTabId).toBeNull()
|
||||
})
|
||||
|
||||
it('reveals the download manager and shows an info toast for an in-progress download', async () => {
|
||||
mockFlags.serverSideModelDownloads = true
|
||||
mockEnqueue.mockRejectedValue(
|
||||
new DownloadApiError('exists', 'ALREADY_DOWNLOADING', 409)
|
||||
)
|
||||
|
||||
downloadModel(
|
||||
{
|
||||
name: 'model.safetensors',
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
directory: 'checkpoints'
|
||||
},
|
||||
{ checkpoints: ['/models/checkpoints'] }
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockSidebarTabStore.activeSidebarTabId).toBe('model-manager')
|
||||
})
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
severity: 'info',
|
||||
detail: 'model.safetensors'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('refreshes the model folder and re-scans missing models when already available', async () => {
|
||||
mockFlags.serverSideModelDownloads = true
|
||||
mockEnqueue.mockRejectedValue(
|
||||
new DownloadApiError('already there', 'ALREADY_AVAILABLE', 409)
|
||||
)
|
||||
mockRefreshModelFolder.mockResolvedValue(undefined)
|
||||
|
||||
downloadModel(
|
||||
{
|
||||
name: 'model.safetensors',
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
directory: 'checkpoints'
|
||||
},
|
||||
{ checkpoints: ['/models/checkpoints'] }
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
severity: 'info',
|
||||
detail: 'model.safetensors'
|
||||
})
|
||||
)
|
||||
})
|
||||
await vi.waitFor(() => {
|
||||
expect(mockRefreshModelFolder).toHaveBeenCalledWith('checkpoints')
|
||||
})
|
||||
expect(mockRefreshMissingModels).toHaveBeenCalled()
|
||||
expect(mockSidebarTabStore.activeSidebarTabId).toBeNull()
|
||||
})
|
||||
|
||||
it('still re-scans missing models when the post-available folder refresh fails', async () => {
|
||||
mockFlags.serverSideModelDownloads = true
|
||||
mockEnqueue.mockRejectedValue(
|
||||
new DownloadApiError('already there', 'ALREADY_AVAILABLE', 409)
|
||||
)
|
||||
const consoleWarn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
mockRefreshModelFolder.mockRejectedValue(new Error('boom'))
|
||||
|
||||
downloadModel(
|
||||
{
|
||||
name: 'model.safetensors',
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
directory: 'checkpoints'
|
||||
},
|
||||
{ checkpoints: ['/models/checkpoints'] }
|
||||
)
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockRefreshMissingModels).toHaveBeenCalled()
|
||||
})
|
||||
expect(consoleWarn).toHaveBeenCalled()
|
||||
consoleWarn.mockRestore()
|
||||
})
|
||||
|
||||
it('opens the model library sidebar before starting a desktop download', () => {
|
||||
mockIsDesktop.value = true
|
||||
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
import { downloadUrlToHfRepoUrl, isCivitaiModelUrl } from '@/utils/formatUtil'
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { t } from '@/i18n'
|
||||
import { isDesktop } from '@/platform/distribution/types'
|
||||
import { useModelDownloadStore } from '@/platform/modelManager/stores/modelDownloadStore'
|
||||
import { DownloadApiError } from '@/platform/modelManager/types'
|
||||
import { buildModelId } from '@/platform/modelManager/utils/modelId'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
import { useElectronDownloadStore } from '@/stores/electronDownloadStore'
|
||||
import { useSidebarTabStore } from '@/stores/workspace/sidebarTabStore'
|
||||
import type { ComfyDesktop2Bridge } from '@/types'
|
||||
@@ -29,6 +35,7 @@ const WHITE_LISTED_URLS: ReadonlySet<string> = new Set([
|
||||
])
|
||||
|
||||
const MODEL_LIBRARY_TAB_ID = 'model-library'
|
||||
const MODEL_MANAGER_TAB_ID = 'model-manager'
|
||||
|
||||
export interface ModelWithUrl {
|
||||
name: string
|
||||
@@ -47,6 +54,96 @@ async function startDesktop2ModelDownload(
|
||||
}
|
||||
}
|
||||
|
||||
function revealDownloadManager(): void {
|
||||
useSidebarTabStore().activeSidebarTabId = MODEL_MANAGER_TAB_ID
|
||||
}
|
||||
|
||||
/**
|
||||
* Already on disk: surface a confirmation, refresh the model folder, and
|
||||
* re-scan missing models so any node error for this model clears. Loaded
|
||||
* lazily to keep this module's import graph (and its unit tests) light.
|
||||
*/
|
||||
async function refreshAfterModelAvailable(model: ModelWithUrl): Promise<void> {
|
||||
try {
|
||||
const [{ useModelStore }, { useMissingModelStore }] = await Promise.all([
|
||||
import('@/stores/modelStore'),
|
||||
import('@/platform/missingModel/missingModelStore')
|
||||
])
|
||||
if (model.directory) {
|
||||
try {
|
||||
await useModelStore().refreshModelFolder(model.directory)
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
'[MissingModel] Failed to refresh model folder after model available',
|
||||
error
|
||||
)
|
||||
}
|
||||
}
|
||||
void useMissingModelStore().refreshMissingModels()
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
'[MissingModel] Failed to refresh after model available',
|
||||
error
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Enqueues a server-side download and reveals the Model Manager panel so the
|
||||
* user can watch live progress, status, and completion. The two benign `409`
|
||||
* cases get an info toast: `ALREADY_DOWNLOADING` links to the existing job,
|
||||
* `ALREADY_AVAILABLE` confirms it's installed and clears the node error. Any
|
||||
* other failure is reported via an error toast.
|
||||
*/
|
||||
async function startServerSideModelDownload(
|
||||
model: ModelWithUrl
|
||||
): Promise<void> {
|
||||
const toast = useToastStore()
|
||||
try {
|
||||
await useModelDownloadStore().enqueue({
|
||||
url: model.url,
|
||||
model_id: buildModelId(model.directory, model.name)
|
||||
})
|
||||
revealDownloadManager()
|
||||
} catch (error: unknown) {
|
||||
if (error instanceof DownloadApiError && error.is('ALREADY_DOWNLOADING')) {
|
||||
revealDownloadManager()
|
||||
toast.add({
|
||||
severity: 'info',
|
||||
summary: t('modelManager.alreadyDownloading'),
|
||||
detail: model.name,
|
||||
life: 4000
|
||||
})
|
||||
return
|
||||
}
|
||||
if (error instanceof DownloadApiError && error.is('ALREADY_AVAILABLE')) {
|
||||
toast.add({
|
||||
severity: 'info',
|
||||
summary: t('modelManager.alreadyInstalled'),
|
||||
detail: model.name,
|
||||
life: 4000
|
||||
})
|
||||
void refreshAfterModelAvailable(model)
|
||||
return
|
||||
}
|
||||
toast.add({
|
||||
severity: 'error',
|
||||
summary: t('modelManager.actionFailed'),
|
||||
detail: error instanceof Error ? error.message : String(error),
|
||||
life: 5000
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
function startBrowserModelDownload(model: ModelWithUrl): void {
|
||||
const link = document.createElement('a')
|
||||
link.href = model.url
|
||||
link.download = model.name
|
||||
link.target = '_blank'
|
||||
link.rel = 'noopener noreferrer'
|
||||
link.click()
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a model download URL to a browsable page URL.
|
||||
* - HuggingFace: `/resolve/` → `/blob/` (file page with model info)
|
||||
@@ -82,12 +179,11 @@ export function downloadModel(
|
||||
}
|
||||
|
||||
if (!isDesktop) {
|
||||
const link = document.createElement('a')
|
||||
link.href = model.url
|
||||
link.download = model.name
|
||||
link.target = '_blank'
|
||||
link.rel = 'noopener noreferrer'
|
||||
link.click()
|
||||
if (useFeatureFlags().flags.serverSideModelDownloads) {
|
||||
void startServerSideModelDownload(model)
|
||||
} else {
|
||||
startBrowserModelDownload(model)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
240
src/platform/modelManager/api/modelDownloadApi.test.ts
Normal file
240
src/platform/modelManager/api/modelDownloadApi.test.ts
Normal file
@@ -0,0 +1,240 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { api } from '@/scripts/api'
|
||||
|
||||
import { DownloadApiError } from '../types'
|
||||
import {
|
||||
cancelDownload,
|
||||
checkAvailability,
|
||||
clearDownloads,
|
||||
deleteCredential,
|
||||
deleteDownload,
|
||||
enqueueDownload,
|
||||
listCredentials,
|
||||
listDownloads,
|
||||
pauseDownload,
|
||||
resumeDownload,
|
||||
setDownloadPriority,
|
||||
upsertCredential
|
||||
} from './modelDownloadApi'
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
fetchApi: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
const fetchApi = vi.mocked(api.fetchApi)
|
||||
|
||||
function jsonResponse(status: number, body: unknown): Response {
|
||||
return {
|
||||
ok: status >= 200 && status < 300,
|
||||
status,
|
||||
statusText: `status ${status}`,
|
||||
json: () => Promise.resolve(body)
|
||||
} as unknown as Response
|
||||
}
|
||||
|
||||
function nonJsonErrorResponse(status: number, statusText: string): Response {
|
||||
return {
|
||||
ok: false,
|
||||
status,
|
||||
statusText,
|
||||
json: () => Promise.reject(new Error('not json'))
|
||||
} as unknown as Response
|
||||
}
|
||||
|
||||
describe('modelDownloadApi', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks()
|
||||
})
|
||||
|
||||
describe('enqueueDownload', () => {
|
||||
it('returns the enqueue response on 202', async () => {
|
||||
fetchApi.mockResolvedValue(
|
||||
jsonResponse(202, { download_id: 'd1', accepted: true })
|
||||
)
|
||||
|
||||
const result = await enqueueDownload({
|
||||
url: 'https://huggingface.co/x.safetensors',
|
||||
model_id: 'loras/x.safetensors'
|
||||
})
|
||||
|
||||
expect(result).toEqual({ download_id: 'd1', accepted: true })
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/enqueue',
|
||||
expect.objectContaining({ method: 'POST' })
|
||||
)
|
||||
})
|
||||
|
||||
it('throws a DownloadApiError carrying the error code on 409', async () => {
|
||||
fetchApi.mockResolvedValue(
|
||||
jsonResponse(409, {
|
||||
error: { code: 'ALREADY_DOWNLOADING', message: 'exists' }
|
||||
})
|
||||
)
|
||||
|
||||
await expect(
|
||||
enqueueDownload({ url: 'u', model_id: 'loras/x.safetensors' })
|
||||
).rejects.toMatchObject({
|
||||
code: 'ALREADY_DOWNLOADING',
|
||||
status: 409,
|
||||
message: 'exists'
|
||||
})
|
||||
})
|
||||
|
||||
it('falls back to statusText and an UNKNOWN code for a non-JSON error body', async () => {
|
||||
fetchApi.mockResolvedValue(nonJsonErrorResponse(502, 'Bad Gateway'))
|
||||
|
||||
await expect(
|
||||
enqueueDownload({ url: 'u', model_id: 'loras/x.safetensors' })
|
||||
).rejects.toMatchObject({
|
||||
message: 'Bad Gateway',
|
||||
code: 'UNKNOWN',
|
||||
status: 502
|
||||
})
|
||||
})
|
||||
|
||||
it('exposes the code through DownloadApiError.is()', async () => {
|
||||
fetchApi.mockResolvedValue(
|
||||
jsonResponse(400, {
|
||||
error: { code: 'URL_NOT_ALLOWED', message: 'nope' }
|
||||
})
|
||||
)
|
||||
|
||||
const error = await enqueueDownload({
|
||||
url: 'u',
|
||||
model_id: 'loras/x.safetensors'
|
||||
}).catch((e) => e)
|
||||
|
||||
expect(error).toBeInstanceOf(DownloadApiError)
|
||||
expect(error.is('URL_NOT_ALLOWED')).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('listDownloads', () => {
|
||||
it('unwraps the downloads array', async () => {
|
||||
fetchApi.mockResolvedValue(
|
||||
jsonResponse(200, { downloads: [{ download_id: 'd1' }] })
|
||||
)
|
||||
|
||||
const result = await listDownloads()
|
||||
|
||||
expect(result).toEqual([{ download_id: 'd1' }])
|
||||
})
|
||||
})
|
||||
|
||||
describe('actions', () => {
|
||||
it('posts to the pause route', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
|
||||
|
||||
await pauseDownload('d1')
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/d1/pause',
|
||||
expect.objectContaining({ method: 'POST' })
|
||||
)
|
||||
})
|
||||
|
||||
it('posts to the resume route', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
|
||||
|
||||
await resumeDownload('d1')
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/d1/resume',
|
||||
expect.objectContaining({ method: 'POST' })
|
||||
)
|
||||
})
|
||||
|
||||
it('posts to the cancel route', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
|
||||
|
||||
await cancelDownload('d1')
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/d1/cancel',
|
||||
expect.objectContaining({ method: 'POST' })
|
||||
)
|
||||
})
|
||||
|
||||
it('sends the priority in the body', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(200, { ok: true }))
|
||||
|
||||
await setDownloadPriority('d1', 5)
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/d1/priority',
|
||||
expect.objectContaining({ body: JSON.stringify({ priority: 5 }) })
|
||||
)
|
||||
})
|
||||
|
||||
it('sends a DELETE to the download route', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(200, { deleted: true }))
|
||||
|
||||
await deleteDownload('d1')
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/d1',
|
||||
expect.objectContaining({ method: 'DELETE' })
|
||||
)
|
||||
})
|
||||
|
||||
it('posts to the clear route and returns the deleted count', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(200, { deleted: 3 }))
|
||||
|
||||
const result = await clearDownloads()
|
||||
|
||||
expect(result).toBe(3)
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/clear',
|
||||
expect.objectContaining({ method: 'POST' })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('checkAvailability', () => {
|
||||
it('wraps the models map in the request body', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(200, { models: {} }))
|
||||
|
||||
await checkAvailability({ 'loras/x.safetensors': 'https://h.co/x' })
|
||||
|
||||
expect(fetchApi).toHaveBeenCalledWith(
|
||||
'/download/availability',
|
||||
expect.objectContaining({
|
||||
body: JSON.stringify({
|
||||
models: { 'loras/x.safetensors': 'https://h.co/x' }
|
||||
})
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('credentials', () => {
|
||||
it('unwraps the credentials list', async () => {
|
||||
fetchApi.mockResolvedValue(
|
||||
jsonResponse(200, { credentials: [{ id: 'c1' }] })
|
||||
)
|
||||
|
||||
expect(await listCredentials()).toEqual([{ id: 'c1' }])
|
||||
})
|
||||
|
||||
it('returns the created credential view', async () => {
|
||||
fetchApi.mockResolvedValue(jsonResponse(201, { id: 'c1', host: 'h' }))
|
||||
|
||||
const result = await upsertCredential({ host: 'h', secret: 's' })
|
||||
|
||||
expect(result).toEqual({ id: 'c1', host: 'h' })
|
||||
})
|
||||
|
||||
it('throws on a failed delete', async () => {
|
||||
fetchApi.mockResolvedValue(
|
||||
jsonResponse(404, { error: { code: 'NOT_FOUND', message: 'gone' } })
|
||||
)
|
||||
|
||||
await expect(deleteCredential('c1')).rejects.toMatchObject({
|
||||
code: 'NOT_FOUND'
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
121
src/platform/modelManager/api/modelDownloadApi.ts
Normal file
121
src/platform/modelManager/api/modelDownloadApi.ts
Normal file
@@ -0,0 +1,121 @@
|
||||
import { api } from '@/scripts/api'
|
||||
|
||||
import type {
|
||||
AvailabilityResponse,
|
||||
DownloadStatus,
|
||||
EnqueueRequest,
|
||||
EnqueueResponse,
|
||||
HostCredentialUpsert,
|
||||
HostCredentialView
|
||||
} from '../types'
|
||||
import { DownloadApiError } from '../types'
|
||||
|
||||
const BASE = '/download'
|
||||
|
||||
interface ErrorEnvelope {
|
||||
error?: {
|
||||
code?: string
|
||||
message?: string
|
||||
details?: Record<string, unknown>
|
||||
}
|
||||
}
|
||||
|
||||
async function throwFromResponse(response: Response): Promise<never> {
|
||||
let body: ErrorEnvelope = {}
|
||||
try {
|
||||
body = await response.json()
|
||||
} catch {
|
||||
// Non-JSON error body
|
||||
}
|
||||
const error = body.error
|
||||
throw new DownloadApiError(
|
||||
error?.message ?? response.statusText,
|
||||
error?.code ?? 'UNKNOWN',
|
||||
response.status,
|
||||
error?.details
|
||||
)
|
||||
}
|
||||
|
||||
async function parseJson<T>(response: Response): Promise<T> {
|
||||
if (!response.ok) return throwFromResponse(response)
|
||||
return response.json() as Promise<T>
|
||||
}
|
||||
|
||||
function postJson(route: string, body: unknown): Promise<Response> {
|
||||
return api.fetchApi(route, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(body)
|
||||
})
|
||||
}
|
||||
|
||||
export async function enqueueDownload(
|
||||
body: EnqueueRequest
|
||||
): Promise<EnqueueResponse> {
|
||||
const response = await postJson(`${BASE}/enqueue`, body)
|
||||
if (response.status === 202) {
|
||||
return response.json() as Promise<EnqueueResponse>
|
||||
}
|
||||
return throwFromResponse(response)
|
||||
}
|
||||
|
||||
export async function listDownloads(): Promise<DownloadStatus[]> {
|
||||
const response = await api.fetchApi(BASE)
|
||||
const data = await parseJson<{ downloads: DownloadStatus[] }>(response)
|
||||
return data.downloads
|
||||
}
|
||||
|
||||
async function postAction(id: string, action: string): Promise<void> {
|
||||
const response = await postJson(`${BASE}/${id}/${action}`, undefined)
|
||||
await parseJson<{ ok: boolean }>(response)
|
||||
}
|
||||
|
||||
export const pauseDownload = (id: string) => postAction(id, 'pause')
|
||||
export const resumeDownload = (id: string) => postAction(id, 'resume')
|
||||
export const cancelDownload = (id: string) => postAction(id, 'cancel')
|
||||
|
||||
export async function deleteDownload(id: string): Promise<void> {
|
||||
const response = await api.fetchApi(`${BASE}/${id}`, { method: 'DELETE' })
|
||||
await parseJson<{ deleted: boolean }>(response)
|
||||
}
|
||||
|
||||
export async function clearDownloads(): Promise<number> {
|
||||
const response = await postJson(`${BASE}/clear`, undefined)
|
||||
const data = await parseJson<{ deleted: number }>(response)
|
||||
return data.deleted
|
||||
}
|
||||
|
||||
export async function setDownloadPriority(
|
||||
id: string,
|
||||
priority: number
|
||||
): Promise<void> {
|
||||
const response = await postJson(`${BASE}/${id}/priority`, { priority })
|
||||
await parseJson<{ ok: boolean }>(response)
|
||||
}
|
||||
|
||||
export async function checkAvailability(
|
||||
models: Record<string, string>
|
||||
): Promise<AvailabilityResponse> {
|
||||
const response = await postJson(`${BASE}/availability`, { models })
|
||||
return parseJson<AvailabilityResponse>(response)
|
||||
}
|
||||
|
||||
export async function listCredentials(): Promise<HostCredentialView[]> {
|
||||
const response = await api.fetchApi(`${BASE}/credentials`)
|
||||
const data = await parseJson<{ credentials: HostCredentialView[] }>(response)
|
||||
return data.credentials
|
||||
}
|
||||
|
||||
export async function upsertCredential(
|
||||
body: HostCredentialUpsert
|
||||
): Promise<HostCredentialView> {
|
||||
const response = await postJson(`${BASE}/credentials`, body)
|
||||
return parseJson<HostCredentialView>(response)
|
||||
}
|
||||
|
||||
export async function deleteCredential(id: string): Promise<void> {
|
||||
const response = await api.fetchApi(`${BASE}/credentials/${id}`, {
|
||||
method: 'DELETE'
|
||||
})
|
||||
await parseJson<{ deleted: boolean }>(response)
|
||||
}
|
||||
253
src/platform/modelManager/components/AddModelByUrlDialog.test.ts
Normal file
253
src/platform/modelManager/components/AddModelByUrlDialog.test.ts
Normal file
@@ -0,0 +1,253 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { defineComponent, ref } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import { DownloadApiError } from '../types'
|
||||
import AddModelByUrlDialog from './AddModelByUrlDialog.vue'
|
||||
|
||||
const mockEnqueue = vi.fn()
|
||||
const mockToastAdd = vi.fn()
|
||||
const mockFetchModelTypes = vi.fn()
|
||||
const mockModelTypes = ref([
|
||||
{ name: 'Checkpoints', value: 'checkpoints' },
|
||||
{ name: 'LoRA', value: 'loras' }
|
||||
])
|
||||
|
||||
vi.mock('../stores/modelDownloadStore', () => ({
|
||||
useModelDownloadStore: () => ({ enqueue: mockEnqueue })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/assets/composables/useModelTypes', () => ({
|
||||
useModelTypes: () => ({
|
||||
modelTypes: mockModelTypes,
|
||||
isLoading: ref(false),
|
||||
fetchModelTypes: mockFetchModelTypes
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: () => ({ add: mockToastAdd })
|
||||
}))
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages },
|
||||
missingWarn: false,
|
||||
fallbackWarn: false
|
||||
})
|
||||
|
||||
const stubs = {
|
||||
Dialog: { template: '<div><slot /></div>' },
|
||||
DialogPortal: { template: '<div><slot /></div>' },
|
||||
DialogOverlay: { template: '<div />' },
|
||||
DialogContent: defineComponent({
|
||||
emits: ['open-auto-focus'],
|
||||
mounted() {
|
||||
this.$emit('open-auto-focus')
|
||||
},
|
||||
template: '<div><slot /></div>'
|
||||
}),
|
||||
DialogHeader: { template: '<div><slot /></div>' },
|
||||
DialogTitle: { template: '<div><slot /></div>' },
|
||||
DialogDescription: { template: '<div><slot /></div>' },
|
||||
DialogFooter: { template: '<div><slot /></div>' },
|
||||
SingleSelect: {
|
||||
props: ['modelValue', 'options', 'label', 'loading'],
|
||||
emits: ['update:modelValue'],
|
||||
template: `
|
||||
<select
|
||||
data-testid="folder-select"
|
||||
:value="modelValue ?? ''"
|
||||
@change="$emit('update:modelValue', $event.target.value)"
|
||||
>
|
||||
<option value="" disabled>{{ label }}</option>
|
||||
<option v-for="opt in options" :key="opt.value" :value="opt.value">
|
||||
{{ opt.name }}
|
||||
</option>
|
||||
</select>
|
||||
`
|
||||
}
|
||||
}
|
||||
|
||||
function mountDialog(open = true) {
|
||||
return render(AddModelByUrlDialog, {
|
||||
props: { open },
|
||||
global: { plugins: [i18n], stubs }
|
||||
})
|
||||
}
|
||||
|
||||
async function fillValidForm() {
|
||||
await userEvent.type(
|
||||
screen.getByLabelText('URL'),
|
||||
'https://huggingface.co/org/model/resolve/main/model.safetensors'
|
||||
)
|
||||
await userEvent.selectOptions(
|
||||
screen.getByTestId('folder-select'),
|
||||
'checkpoints'
|
||||
)
|
||||
}
|
||||
|
||||
describe('AddModelByUrlDialog', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('auto-fills the filename from the url until edited manually', async () => {
|
||||
mountDialog()
|
||||
const urlInput = screen.getByLabelText('URL')
|
||||
const filenameInput = screen.getByLabelText('Filename')
|
||||
|
||||
await userEvent.type(
|
||||
urlInput,
|
||||
'https://huggingface.co/org/model/resolve/main/model.safetensors'
|
||||
)
|
||||
expect(filenameInput).toHaveValue('model.safetensors')
|
||||
|
||||
await userEvent.clear(filenameInput)
|
||||
await userEvent.type(filenameInput, 'custom.safetensors')
|
||||
await userEvent.type(urlInput, '?download=true')
|
||||
|
||||
expect(filenameInput).toHaveValue('custom.safetensors')
|
||||
})
|
||||
|
||||
it('shows a hint for a url on a non-allowlisted host', async () => {
|
||||
mountDialog()
|
||||
|
||||
await userEvent.type(
|
||||
screen.getByLabelText('URL'),
|
||||
'https://example.com/model.safetensors'
|
||||
)
|
||||
|
||||
expect(
|
||||
screen.getByText('This host may not be on the download allowlist.')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('disables submit until url, folder, and a valid filename are set', async () => {
|
||||
mountDialog()
|
||||
const submit = screen.getByRole('button', { name: 'Download' })
|
||||
expect(submit).toBeDisabled()
|
||||
|
||||
await userEvent.type(
|
||||
screen.getByLabelText('URL'),
|
||||
'https://huggingface.co/org/model/resolve/main/model.safetensors'
|
||||
)
|
||||
expect(submit).toBeDisabled()
|
||||
|
||||
await userEvent.selectOptions(
|
||||
screen.getByTestId('folder-select'),
|
||||
'checkpoints'
|
||||
)
|
||||
expect(submit).toBeEnabled()
|
||||
})
|
||||
|
||||
it('requires a known model extension', async () => {
|
||||
mountDialog()
|
||||
await userEvent.type(
|
||||
screen.getByLabelText('URL'),
|
||||
'https://huggingface.co/org/model/resolve/main/readme.txt'
|
||||
)
|
||||
await userEvent.selectOptions(
|
||||
screen.getByTestId('folder-select'),
|
||||
'checkpoints'
|
||||
)
|
||||
const submit = screen.getByRole('button', { name: 'Download' })
|
||||
expect(submit).toBeDisabled()
|
||||
})
|
||||
|
||||
it('closes without submitting when cancel is clicked', async () => {
|
||||
const { emitted } = mountDialog()
|
||||
await fillValidForm()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Cancel' }))
|
||||
|
||||
expect(mockEnqueue).not.toHaveBeenCalled()
|
||||
expect(emitted('update:open')?.at(-1)).toEqual([false])
|
||||
})
|
||||
|
||||
it('enqueues the download and closes on success', async () => {
|
||||
mockEnqueue.mockResolvedValue({ download_id: 'd1', accepted: true })
|
||||
const { emitted } = mountDialog()
|
||||
await fillValidForm()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
|
||||
|
||||
expect(mockEnqueue).toHaveBeenCalledWith({
|
||||
url: 'https://huggingface.co/org/model/resolve/main/model.safetensors',
|
||||
model_id: 'checkpoints/model.safetensors',
|
||||
allow_any_extension: false
|
||||
})
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'success' })
|
||||
)
|
||||
})
|
||||
expect(emitted('update:open')?.at(-1)).toEqual([false])
|
||||
})
|
||||
|
||||
it('shows an info toast and closes when the model is already downloading', async () => {
|
||||
mockEnqueue.mockRejectedValue(
|
||||
new DownloadApiError('exists', 'ALREADY_DOWNLOADING', 409)
|
||||
)
|
||||
const { emitted } = mountDialog()
|
||||
await fillValidForm()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'info' })
|
||||
)
|
||||
})
|
||||
expect(emitted('update:open')?.at(-1)).toEqual([false])
|
||||
})
|
||||
|
||||
it('shows an info toast and closes when the model is already installed', async () => {
|
||||
mockEnqueue.mockRejectedValue(
|
||||
new DownloadApiError('already there', 'ALREADY_AVAILABLE', 409)
|
||||
)
|
||||
const { emitted } = mountDialog()
|
||||
await fillValidForm()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'info' })
|
||||
)
|
||||
})
|
||||
expect(emitted('update:open')?.at(-1)).toEqual([false])
|
||||
})
|
||||
|
||||
it('shows an inline error for other API failures and stays open', async () => {
|
||||
mockEnqueue.mockRejectedValue(
|
||||
new DownloadApiError('url not allowed', 'URL_NOT_ALLOWED', 400)
|
||||
)
|
||||
const { emitted } = mountDialog()
|
||||
await fillValidForm()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(screen.getByText('url not allowed')).toBeInTheDocument()
|
||||
})
|
||||
expect(emitted('update:open')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('shows a generic error message for a non-api error', async () => {
|
||||
mockEnqueue.mockRejectedValue(new Error('network down'))
|
||||
mountDialog()
|
||||
await fillValidForm()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Download' }))
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(screen.getByText('network down')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
237
src/platform/modelManager/components/AddModelByUrlDialog.vue
Normal file
237
src/platform/modelManager/components/AddModelByUrlDialog.vue
Normal file
@@ -0,0 +1,237 @@
|
||||
<template>
|
||||
<Dialog v-model:open="isOpen">
|
||||
<DialogPortal>
|
||||
<DialogOverlay class="bg-black/70" />
|
||||
<DialogContent
|
||||
size="md"
|
||||
class="flex flex-col gap-4 p-6"
|
||||
@open-auto-focus="onOpen"
|
||||
>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{{ $t('modelManager.addModel') }}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{{ $t('modelManager.addModelDescription') }}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground" for="download-url">
|
||||
{{ $t('modelManager.url') }}
|
||||
</label>
|
||||
<Input
|
||||
id="download-url"
|
||||
v-model="url"
|
||||
:placeholder="$t('modelManager.urlPlaceholder')"
|
||||
@update:model-value="onUrlChanged"
|
||||
/>
|
||||
<p v-if="hostHint" class="text-xs text-amber-400">{{ hostHint }}</p>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground">
|
||||
{{ $t('modelManager.folder') }}
|
||||
</label>
|
||||
<SingleSelect
|
||||
v-model="directory"
|
||||
:options="folderOptions"
|
||||
:label="$t('modelManager.selectFolder')"
|
||||
:loading="isLoadingFolders"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground" for="download-filename">
|
||||
{{ $t('modelManager.filename') }}
|
||||
</label>
|
||||
<Input
|
||||
id="download-filename"
|
||||
v-model="filename"
|
||||
:placeholder="$t('modelManager.filenamePlaceholder')"
|
||||
@update:model-value="onFilenameEdited"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- TODO: re-enable once we think we'd want to allow any extension
|
||||
<label class="flex w-fit items-center gap-2 text-xs text-muted-foreground">
|
||||
<input v-model="allowAnyExtension" type="checkbox" class="size-4" />
|
||||
{{ $t('modelManager.allowAnyExtension') }}
|
||||
</label>
|
||||
-->
|
||||
|
||||
<p
|
||||
v-if="modelId"
|
||||
class="truncate rounded-md bg-secondary-background px-2 py-1 text-xs text-muted-foreground"
|
||||
>
|
||||
{{ modelId }}
|
||||
</p>
|
||||
|
||||
<p v-if="errorMessage" class="text-xs text-red-400">
|
||||
{{ errorMessage }}
|
||||
</p>
|
||||
|
||||
<DialogFooter>
|
||||
<Button variant="secondary" @click="isOpen = false">
|
||||
{{ $t('g.cancel') }}
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
:disabled="!canSubmit"
|
||||
:loading="isSubmitting"
|
||||
@click="submit"
|
||||
>
|
||||
{{ $t('modelManager.download') }}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</DialogPortal>
|
||||
</Dialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import Dialog from '@/components/ui/dialog/Dialog.vue'
|
||||
import DialogContent from '@/components/ui/dialog/DialogContent.vue'
|
||||
import DialogDescription from '@/components/ui/dialog/DialogDescription.vue'
|
||||
import DialogFooter from '@/components/ui/dialog/DialogFooter.vue'
|
||||
import DialogHeader from '@/components/ui/dialog/DialogHeader.vue'
|
||||
import DialogOverlay from '@/components/ui/dialog/DialogOverlay.vue'
|
||||
import DialogPortal from '@/components/ui/dialog/DialogPortal.vue'
|
||||
import DialogTitle from '@/components/ui/dialog/DialogTitle.vue'
|
||||
import Input from '@/components/ui/input/Input.vue'
|
||||
import SingleSelect from '@/components/ui/single-select/SingleSelect.vue'
|
||||
import { useModelTypes } from '@/platform/assets/composables/useModelTypes'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
|
||||
import { useModelDownloadStore } from '../stores/modelDownloadStore'
|
||||
import { DownloadApiError } from '../types'
|
||||
import {
|
||||
buildModelId,
|
||||
filenameFromUrl,
|
||||
hasModelExtension,
|
||||
isLikelyAllowedHost,
|
||||
isValidPathSegment
|
||||
} from '../utils/modelId'
|
||||
|
||||
const isOpen = defineModel<boolean>('open', { required: true })
|
||||
|
||||
const { t } = useI18n()
|
||||
const store = useModelDownloadStore()
|
||||
const {
|
||||
modelTypes,
|
||||
isLoading: isLoadingFolders,
|
||||
fetchModelTypes
|
||||
} = useModelTypes()
|
||||
|
||||
const url = ref('')
|
||||
const directory = ref<string | undefined>(undefined)
|
||||
const filename = ref('')
|
||||
const isFilenameUserEdited = ref(false)
|
||||
const allowAnyExtension = ref(false)
|
||||
const isSubmitting = ref(false)
|
||||
const errorMessage = ref('')
|
||||
|
||||
const folderOptions = computed(() => modelTypes.value)
|
||||
|
||||
const modelId = computed(() =>
|
||||
directory.value && filename.value
|
||||
? buildModelId(directory.value, filename.value)
|
||||
: ''
|
||||
)
|
||||
|
||||
const hostHint = computed(() =>
|
||||
url.value && !isLikelyAllowedHost(url.value)
|
||||
? t('modelManager.hostNotAllowedHint')
|
||||
: ''
|
||||
)
|
||||
|
||||
const canSubmit = computed(
|
||||
() =>
|
||||
!!url.value &&
|
||||
!!directory.value &&
|
||||
isValidPathSegment(filename.value) &&
|
||||
(allowAnyExtension.value || hasModelExtension(filename.value))
|
||||
)
|
||||
|
||||
function onOpen() {
|
||||
void fetchModelTypes()
|
||||
errorMessage.value = ''
|
||||
}
|
||||
|
||||
function onUrlChanged() {
|
||||
if (!isFilenameUserEdited.value) {
|
||||
filename.value = filenameFromUrl(url.value)
|
||||
}
|
||||
}
|
||||
|
||||
function onFilenameEdited() {
|
||||
isFilenameUserEdited.value = true
|
||||
}
|
||||
|
||||
function reset() {
|
||||
url.value = ''
|
||||
directory.value = undefined
|
||||
filename.value = ''
|
||||
isFilenameUserEdited.value = false
|
||||
allowAnyExtension.value = false
|
||||
errorMessage.value = ''
|
||||
}
|
||||
|
||||
async function submit() {
|
||||
if (!canSubmit.value) return
|
||||
isSubmitting.value = true
|
||||
errorMessage.value = ''
|
||||
try {
|
||||
await store.enqueue({
|
||||
url: url.value,
|
||||
model_id: modelId.value,
|
||||
allow_any_extension: allowAnyExtension.value
|
||||
})
|
||||
useToastStore().add({
|
||||
severity: 'success',
|
||||
summary: t('modelManager.downloadQueued'),
|
||||
detail: modelId.value,
|
||||
life: 4000
|
||||
})
|
||||
reset()
|
||||
isOpen.value = false
|
||||
} catch (error) {
|
||||
handleEnqueueError(error)
|
||||
} finally {
|
||||
isSubmitting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function handleEnqueueError(error: unknown) {
|
||||
if (error instanceof DownloadApiError) {
|
||||
if (error.is('ALREADY_AVAILABLE')) {
|
||||
useToastStore().add({
|
||||
severity: 'info',
|
||||
summary: t('modelManager.alreadyInstalled'),
|
||||
detail: modelId.value,
|
||||
life: 4000
|
||||
})
|
||||
reset()
|
||||
isOpen.value = false
|
||||
return
|
||||
}
|
||||
if (error.is('ALREADY_DOWNLOADING')) {
|
||||
useToastStore().add({
|
||||
severity: 'info',
|
||||
summary: t('modelManager.alreadyDownloading'),
|
||||
detail: modelId.value,
|
||||
life: 4000
|
||||
})
|
||||
reset()
|
||||
isOpen.value = false
|
||||
return
|
||||
}
|
||||
errorMessage.value = error.message
|
||||
return
|
||||
}
|
||||
errorMessage.value =
|
||||
error instanceof Error ? error.message : t('modelManager.actionFailed')
|
||||
}
|
||||
</script>
|
||||
@@ -0,0 +1,313 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { defineComponent, nextTick, reactive, ref } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import type { HostCredentialView } from '../types'
|
||||
import HostCredentialsDialog from './HostCredentialsDialog.vue'
|
||||
|
||||
const mockCloseDialog = vi.fn()
|
||||
const mockToastAdd = vi.fn()
|
||||
|
||||
const mockCredentialsStore = reactive({
|
||||
credentials: ref<HostCredentialView[]>([]),
|
||||
isLoading: ref(false),
|
||||
fetchCredentials: vi.fn(),
|
||||
upsert: vi.fn(),
|
||||
remove: vi.fn()
|
||||
})
|
||||
|
||||
vi.mock('../stores/hostCredentialsStore', () => ({
|
||||
useHostCredentialsStore: () => mockCredentialsStore
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/dialogStore', () => ({
|
||||
useDialogStore: () => ({ closeDialog: mockCloseDialog })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: () => ({ add: mockToastAdd })
|
||||
}))
|
||||
|
||||
vi.mock('@/components/dialog/confirm/confirmDialog')
|
||||
|
||||
const mockShowConfirmDialog = vi.mocked(showConfirmDialog)
|
||||
|
||||
interface CapturedConfirmOptions {
|
||||
footerProps?: {
|
||||
onConfirm?: () => void | Promise<void>
|
||||
onCancel?: () => void
|
||||
}
|
||||
}
|
||||
|
||||
function capturedOptions(): CapturedConfirmOptions {
|
||||
return mockShowConfirmDialog.mock
|
||||
.calls[0][0] as unknown as CapturedConfirmOptions
|
||||
}
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages },
|
||||
missingWarn: false,
|
||||
fallbackWarn: false
|
||||
})
|
||||
|
||||
const stubs = {
|
||||
Dialog: { template: '<div><slot /></div>' },
|
||||
DialogPortal: { template: '<div><slot /></div>' },
|
||||
DialogOverlay: { template: '<div />' },
|
||||
DialogContent: defineComponent({
|
||||
emits: ['open-auto-focus'],
|
||||
mounted() {
|
||||
this.$emit('open-auto-focus')
|
||||
},
|
||||
template: '<div><slot /></div>'
|
||||
}),
|
||||
DialogHeader: { template: '<div><slot /></div>' },
|
||||
DialogTitle: { template: '<div><slot /></div>' },
|
||||
DialogDescription: { template: '<div><slot /></div>' },
|
||||
SingleSelect: {
|
||||
props: ['modelValue', 'options', 'label'],
|
||||
emits: ['update:modelValue'],
|
||||
template: `
|
||||
<select
|
||||
data-testid="scheme-select"
|
||||
:value="modelValue ?? ''"
|
||||
@change="$emit('update:modelValue', $event.target.value)"
|
||||
>
|
||||
<option v-for="opt in options" :key="opt.value" :value="opt.value">
|
||||
{{ opt.name }}
|
||||
</option>
|
||||
</select>
|
||||
`
|
||||
}
|
||||
}
|
||||
|
||||
function createCredential(
|
||||
overrides: Partial<HostCredentialView> = {}
|
||||
): HostCredentialView {
|
||||
return {
|
||||
id: 'c1',
|
||||
host: 'huggingface.co',
|
||||
auth_scheme: 'bearer',
|
||||
header_name: null,
|
||||
query_param: null,
|
||||
label: null,
|
||||
match_subdomains: false,
|
||||
enabled: true,
|
||||
secret_last4: '1234',
|
||||
created_at: 0,
|
||||
updated_at: 0,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function mountDialog(props: { open?: boolean; prefillHost?: string } = {}) {
|
||||
return render(HostCredentialsDialog, {
|
||||
props: { open: true, ...props },
|
||||
global: { plugins: [i18n], stubs }
|
||||
})
|
||||
}
|
||||
|
||||
describe('HostCredentialsDialog', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockCredentialsStore.credentials = []
|
||||
mockCredentialsStore.fetchCredentials.mockResolvedValue(undefined)
|
||||
mockShowConfirmDialog.mockReturnValue(
|
||||
{} as ReturnType<typeof showConfirmDialog>
|
||||
)
|
||||
})
|
||||
|
||||
it('renders existing credentials with host, scheme, and last 4 of the secret', () => {
|
||||
mockCredentialsStore.credentials = [
|
||||
createCredential({
|
||||
host: 'huggingface.co',
|
||||
auth_scheme: 'bearer',
|
||||
secret_last4: '9f21'
|
||||
})
|
||||
]
|
||||
mountDialog()
|
||||
|
||||
expect(
|
||||
screen.getByText('huggingface.co · Bearer token · ••••9f21')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows a disabled marker for a disabled credential', () => {
|
||||
mockCredentialsStore.credentials = [createCredential({ enabled: false })]
|
||||
mountDialog()
|
||||
|
||||
expect(screen.getByText(/Disabled/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('prefills the host from the prefillHost prop on open', async () => {
|
||||
mountDialog({ prefillHost: 'civitai.com' })
|
||||
await nextTick()
|
||||
|
||||
expect(screen.getByLabelText('Host')).toHaveValue('civitai.com')
|
||||
})
|
||||
|
||||
it('populates the form when editing an existing credential', async () => {
|
||||
mockCredentialsStore.credentials = [
|
||||
createCredential({
|
||||
id: 'c1',
|
||||
host: 'civitai.com',
|
||||
auth_scheme: 'header',
|
||||
header_name: 'X-Api-Key',
|
||||
label: 'My Civitai key'
|
||||
})
|
||||
]
|
||||
mountDialog()
|
||||
|
||||
await userEvent.click(screen.getByTitle('Edit'))
|
||||
|
||||
expect(screen.getByLabelText('Host')).toHaveValue('civitai.com')
|
||||
expect(screen.getByLabelText('Label')).toHaveValue('My Civitai key')
|
||||
expect(screen.getByText('Update credential')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('disables submit until host and secret are filled', async () => {
|
||||
mountDialog()
|
||||
const submit = screen.getByRole('button', { name: 'Save' })
|
||||
expect(submit).toBeDisabled()
|
||||
|
||||
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
|
||||
expect(submit).toBeDisabled()
|
||||
|
||||
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
|
||||
expect(submit).toBeEnabled()
|
||||
})
|
||||
|
||||
it('requires a query parameter name when the scheme is query', async () => {
|
||||
mountDialog()
|
||||
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
|
||||
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
|
||||
await userEvent.selectOptions(screen.getByTestId('scheme-select'), 'query')
|
||||
|
||||
const submit = screen.getByRole('button', { name: 'Save' })
|
||||
expect(submit).toBeDisabled()
|
||||
|
||||
await userEvent.type(screen.getByLabelText('Query parameter'), 'token')
|
||||
expect(submit).toBeEnabled()
|
||||
})
|
||||
|
||||
it('shows the header name field and a subdomain warning for the header scheme', async () => {
|
||||
mountDialog()
|
||||
|
||||
await userEvent.selectOptions(screen.getByTestId('scheme-select'), 'header')
|
||||
expect(screen.getByLabelText('Header name')).toBeInTheDocument()
|
||||
|
||||
await userEvent.click(screen.getByLabelText('Match subdomains'))
|
||||
expect(
|
||||
screen.getByText(
|
||||
'Not recommended: hubs redirect to sibling CDN hosts that must not receive your key.'
|
||||
)
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('cancels an in-progress edit and resets the form', async () => {
|
||||
mockCredentialsStore.credentials = [
|
||||
createCredential({ id: 'c1', host: 'civitai.com' })
|
||||
]
|
||||
mountDialog()
|
||||
await userEvent.click(screen.getByTitle('Edit'))
|
||||
expect(screen.getByText('Update credential')).toBeInTheDocument()
|
||||
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Cancel' }))
|
||||
|
||||
expect(screen.getByText('Add credential')).toBeInTheDocument()
|
||||
expect(screen.getByLabelText('Host')).toHaveValue('')
|
||||
})
|
||||
|
||||
it('submits the form payload and resets on success', async () => {
|
||||
mockCredentialsStore.upsert.mockResolvedValue(createCredential())
|
||||
mountDialog()
|
||||
|
||||
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
|
||||
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Save' }))
|
||||
|
||||
expect(mockCredentialsStore.upsert).toHaveBeenCalledWith({
|
||||
host: 'huggingface.co',
|
||||
secret: 's3cret',
|
||||
auth_scheme: 'bearer',
|
||||
header_name: null,
|
||||
query_param: null,
|
||||
label: null,
|
||||
match_subdomains: false
|
||||
})
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'success' })
|
||||
)
|
||||
})
|
||||
expect(screen.getByLabelText('Host')).toHaveValue('')
|
||||
})
|
||||
|
||||
it('shows an inline error message when saving fails', async () => {
|
||||
mockCredentialsStore.upsert.mockRejectedValue(new Error('save failed'))
|
||||
mountDialog()
|
||||
|
||||
await userEvent.type(screen.getByLabelText('Host'), 'huggingface.co')
|
||||
await userEvent.type(screen.getByLabelText('API key'), 's3cret')
|
||||
await userEvent.click(screen.getByRole('button', { name: 'Save' }))
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(screen.getByText('save failed')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('shows an inline error when the initial credentials fetch fails', async () => {
|
||||
mockCredentialsStore.fetchCredentials.mockRejectedValue(
|
||||
new Error('offline')
|
||||
)
|
||||
mountDialog()
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(screen.getByText('offline')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('deletes a credential through a confirm dialog', async () => {
|
||||
mockCredentialsStore.remove.mockResolvedValue(undefined)
|
||||
mockCredentialsStore.credentials = [createCredential({ id: 'c1' })]
|
||||
mountDialog()
|
||||
|
||||
await userEvent.click(screen.getByTitle('Delete'))
|
||||
await capturedOptions().footerProps?.onConfirm?.()
|
||||
|
||||
expect(mockCredentialsStore.remove).toHaveBeenCalledWith('c1')
|
||||
expect(mockCloseDialog).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows an error toast when deletion fails', async () => {
|
||||
mockCredentialsStore.remove.mockRejectedValue(new Error('boom'))
|
||||
mockCredentialsStore.credentials = [createCredential({ id: 'c1' })]
|
||||
mountDialog()
|
||||
|
||||
await userEvent.click(screen.getByTitle('Delete'))
|
||||
await capturedOptions().footerProps?.onConfirm?.()
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error' })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
it('does not delete when the confirm dialog is dismissed', async () => {
|
||||
mockCredentialsStore.credentials = [createCredential({ id: 'c1' })]
|
||||
mountDialog()
|
||||
|
||||
await userEvent.click(screen.getByTitle('Delete'))
|
||||
capturedOptions().footerProps?.onCancel?.()
|
||||
|
||||
expect(mockCredentialsStore.remove).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
333
src/platform/modelManager/components/HostCredentialsDialog.vue
Normal file
333
src/platform/modelManager/components/HostCredentialsDialog.vue
Normal file
@@ -0,0 +1,333 @@
|
||||
<template>
|
||||
<Dialog v-model:open="isOpen">
|
||||
<DialogPortal>
|
||||
<DialogOverlay class="bg-black/70" />
|
||||
<DialogContent
|
||||
size="md"
|
||||
class="flex flex-col gap-4 p-6"
|
||||
@open-auto-focus="onOpen"
|
||||
>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{{ $t('modelManager.credentials.title') }}</DialogTitle>
|
||||
<DialogDescription>
|
||||
{{ $t('modelManager.credentials.description') }}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<div v-if="credentials.length" class="flex flex-col gap-2">
|
||||
<div
|
||||
v-for="credential in credentials"
|
||||
:key="credential.id"
|
||||
class="flex items-center gap-2 rounded-lg border border-border-default bg-secondary-background px-3 py-2"
|
||||
>
|
||||
<div class="flex min-w-0 flex-1 flex-col">
|
||||
<span class="truncate text-sm font-medium text-base-foreground">
|
||||
{{ credential.label || credential.host }}
|
||||
</span>
|
||||
<span class="truncate text-xs text-muted-foreground">
|
||||
{{ credential.host }} ·
|
||||
{{
|
||||
$t(
|
||||
`modelManager.credentials.scheme.${credential.auth_scheme}`
|
||||
)
|
||||
}}
|
||||
<template v-if="credential.secret_last4">
|
||||
· ••••{{ credential.secret_last4 }}
|
||||
</template>
|
||||
<template v-if="!credential.enabled">
|
||||
· {{ $t('modelManager.credentials.disabled') }}
|
||||
</template>
|
||||
</span>
|
||||
</div>
|
||||
<Button
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('modelManager.credentials.edit')"
|
||||
@click="editCredential(credential)"
|
||||
>
|
||||
<i class="icon-[lucide--pencil] size-4" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('g.delete')"
|
||||
@click="confirmDelete(credential)"
|
||||
>
|
||||
<i class="icon-[lucide--trash-2] size-4 text-red-400" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-3 border-t border-border-default pt-3">
|
||||
<span class="text-sm font-medium text-base-foreground">
|
||||
{{
|
||||
form.id
|
||||
? $t('modelManager.credentials.update')
|
||||
: $t('modelManager.credentials.add')
|
||||
}}
|
||||
</span>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground" for="cred-host">
|
||||
{{ $t('modelManager.credentials.host') }}
|
||||
</label>
|
||||
<Input
|
||||
id="cred-host"
|
||||
v-model="form.host"
|
||||
:placeholder="HOST_EXAMPLE"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground" for="cred-secret">
|
||||
{{ $t('modelManager.credentials.secret') }}
|
||||
</label>
|
||||
<Input
|
||||
id="cred-secret"
|
||||
v-model="form.secret"
|
||||
type="password"
|
||||
:placeholder="$t('modelManager.credentials.secretPlaceholder')"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground">
|
||||
{{ $t('modelManager.credentials.authScheme') }}
|
||||
</label>
|
||||
<SingleSelect v-model="form.auth_scheme" :options="schemeOptions" />
|
||||
</div>
|
||||
|
||||
<div v-if="form.auth_scheme === 'header'" class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground" for="cred-header">
|
||||
{{ $t('modelManager.credentials.headerName') }}
|
||||
</label>
|
||||
<Input
|
||||
id="cred-header"
|
||||
v-model="form.header_name"
|
||||
:placeholder="HEADER_NAME_EXAMPLE"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div v-if="form.auth_scheme === 'query'" class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground" for="cred-query">
|
||||
{{ $t('modelManager.credentials.queryParam') }}
|
||||
</label>
|
||||
<Input
|
||||
id="cred-query"
|
||||
v-model="form.query_param"
|
||||
:placeholder="QUERY_PARAM_EXAMPLE"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-col gap-1">
|
||||
<label class="text-xs text-muted-foreground" for="cred-label">
|
||||
{{ $t('modelManager.credentials.label') }}
|
||||
</label>
|
||||
<Input id="cred-label" v-model="form.label" />
|
||||
</div>
|
||||
|
||||
<label class="flex items-center gap-2 text-xs text-muted-foreground">
|
||||
<input
|
||||
v-model="form.match_subdomains"
|
||||
type="checkbox"
|
||||
class="size-4"
|
||||
/>
|
||||
{{ $t('modelManager.credentials.matchSubdomains') }}
|
||||
</label>
|
||||
<p v-if="form.match_subdomains" class="text-xs text-amber-400">
|
||||
{{ $t('modelManager.credentials.matchSubdomainsWarning') }}
|
||||
</p>
|
||||
|
||||
<p v-if="errorMessage" class="text-xs text-red-400">
|
||||
{{ errorMessage }}
|
||||
</p>
|
||||
|
||||
<div class="flex justify-end gap-2">
|
||||
<Button v-if="form.id" variant="secondary" @click="resetForm">
|
||||
{{ $t('g.cancel') }}
|
||||
</Button>
|
||||
<Button
|
||||
variant="primary"
|
||||
:disabled="!canSubmit"
|
||||
:loading="isSubmitting"
|
||||
@click="submit"
|
||||
>
|
||||
{{ $t('modelManager.credentials.save') }}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</DialogPortal>
|
||||
</Dialog>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, reactive, ref } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import Dialog from '@/components/ui/dialog/Dialog.vue'
|
||||
import DialogContent from '@/components/ui/dialog/DialogContent.vue'
|
||||
import DialogDescription from '@/components/ui/dialog/DialogDescription.vue'
|
||||
import DialogHeader from '@/components/ui/dialog/DialogHeader.vue'
|
||||
import DialogOverlay from '@/components/ui/dialog/DialogOverlay.vue'
|
||||
import DialogPortal from '@/components/ui/dialog/DialogPortal.vue'
|
||||
import DialogTitle from '@/components/ui/dialog/DialogTitle.vue'
|
||||
import Input from '@/components/ui/input/Input.vue'
|
||||
import SingleSelect from '@/components/ui/single-select/SingleSelect.vue'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
import { storeToRefs } from 'pinia'
|
||||
import { useDialogStore } from '@/stores/dialogStore'
|
||||
|
||||
import { useHostCredentialsStore } from '../stores/hostCredentialsStore'
|
||||
import { AUTH_SCHEMES } from '../types'
|
||||
import type { AuthScheme, HostCredentialView } from '../types'
|
||||
|
||||
const HOST_EXAMPLE = 'huggingface.co'
|
||||
const HEADER_NAME_EXAMPLE = 'Authorization'
|
||||
const QUERY_PARAM_EXAMPLE = 'token'
|
||||
|
||||
const { prefillHost = '' } = defineProps<{ prefillHost?: string }>()
|
||||
|
||||
const isOpen = defineModel<boolean>('open', { required: true })
|
||||
|
||||
const { t } = useI18n()
|
||||
const store = useHostCredentialsStore()
|
||||
const { credentials } = storeToRefs(store)
|
||||
const dialogStore = useDialogStore()
|
||||
|
||||
const isSubmitting = ref(false)
|
||||
const errorMessage = ref('')
|
||||
|
||||
interface CredentialForm {
|
||||
id: string | null
|
||||
host: string
|
||||
secret: string
|
||||
auth_scheme: AuthScheme
|
||||
header_name: string
|
||||
query_param: string
|
||||
label: string
|
||||
match_subdomains: boolean
|
||||
}
|
||||
|
||||
function emptyForm(): CredentialForm {
|
||||
return {
|
||||
id: null,
|
||||
host: '',
|
||||
secret: '',
|
||||
auth_scheme: 'bearer',
|
||||
header_name: '',
|
||||
query_param: '',
|
||||
label: '',
|
||||
match_subdomains: false
|
||||
}
|
||||
}
|
||||
|
||||
const form = reactive<CredentialForm>(emptyForm())
|
||||
|
||||
const schemeOptions = computed(() =>
|
||||
AUTH_SCHEMES.map((scheme) => ({
|
||||
value: scheme,
|
||||
name: t(`modelManager.credentials.scheme.${scheme}`)
|
||||
}))
|
||||
)
|
||||
|
||||
const canSubmit = computed(
|
||||
() =>
|
||||
!!form.host.trim() &&
|
||||
!!form.secret &&
|
||||
(form.auth_scheme !== 'query' || !!form.query_param.trim()) &&
|
||||
(form.auth_scheme !== 'header' || !!form.header_name.trim())
|
||||
)
|
||||
|
||||
async function onOpen() {
|
||||
resetForm()
|
||||
if (prefillHost) {
|
||||
form.host = prefillHost
|
||||
}
|
||||
try {
|
||||
await store.fetchCredentials()
|
||||
} catch (error) {
|
||||
errorMessage.value =
|
||||
error instanceof Error ? error.message : t('modelManager.actionFailed')
|
||||
}
|
||||
}
|
||||
|
||||
function resetForm() {
|
||||
Object.assign(form, emptyForm())
|
||||
errorMessage.value = ''
|
||||
}
|
||||
|
||||
function editCredential(credential: HostCredentialView) {
|
||||
Object.assign(form, {
|
||||
id: credential.id,
|
||||
host: credential.host,
|
||||
secret: '',
|
||||
auth_scheme: credential.auth_scheme,
|
||||
header_name: credential.header_name ?? '',
|
||||
query_param: credential.query_param ?? '',
|
||||
label: credential.label ?? '',
|
||||
match_subdomains: credential.match_subdomains
|
||||
})
|
||||
errorMessage.value = ''
|
||||
}
|
||||
|
||||
async function submit() {
|
||||
if (!canSubmit.value) return
|
||||
isSubmitting.value = true
|
||||
errorMessage.value = ''
|
||||
try {
|
||||
await store.upsert({
|
||||
host: form.host,
|
||||
secret: form.secret,
|
||||
auth_scheme: form.auth_scheme,
|
||||
header_name: form.auth_scheme === 'header' ? form.header_name : null,
|
||||
query_param: form.auth_scheme === 'query' ? form.query_param : null,
|
||||
label: form.label || null,
|
||||
match_subdomains: form.match_subdomains
|
||||
})
|
||||
useToastStore().add({
|
||||
severity: 'success',
|
||||
summary: t('modelManager.credentials.saved'),
|
||||
detail: form.host,
|
||||
life: 4000
|
||||
})
|
||||
resetForm()
|
||||
} catch (error) {
|
||||
errorMessage.value =
|
||||
error instanceof Error ? error.message : t('modelManager.actionFailed')
|
||||
} finally {
|
||||
isSubmitting.value = false
|
||||
}
|
||||
}
|
||||
|
||||
function confirmDelete(credential: HostCredentialView) {
|
||||
const dialog = showConfirmDialog({
|
||||
headerProps: { title: t('modelManager.credentials.deleteTitle') },
|
||||
props: {
|
||||
promptText: t('modelManager.credentials.deleteMessage', {
|
||||
host: credential.host
|
||||
})
|
||||
},
|
||||
footerProps: {
|
||||
confirmText: t('g.delete'),
|
||||
confirmVariant: 'destructive' as const,
|
||||
onCancel: () => dialogStore.closeDialog(dialog),
|
||||
onConfirm: async () => {
|
||||
dialogStore.closeDialog(dialog)
|
||||
try {
|
||||
await store.remove(credential.id)
|
||||
} catch (error) {
|
||||
useToastStore().add({
|
||||
severity: 'error',
|
||||
summary: t('modelManager.actionFailed'),
|
||||
detail: error instanceof Error ? error.message : String(error),
|
||||
life: 5000
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
</script>
|
||||
336
src/platform/modelManager/components/ModelDownloadRow.test.ts
Normal file
336
src/platform/modelManager/components/ModelDownloadRow.test.ts
Normal file
@@ -0,0 +1,336 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import type { DownloadStatus } from '../types'
|
||||
import ModelDownloadRow from './ModelDownloadRow.vue'
|
||||
|
||||
const mockPause = vi.fn()
|
||||
const mockResume = vi.fn()
|
||||
const mockCancel = vi.fn()
|
||||
const mockRaisePriority = vi.fn()
|
||||
const mockRemove = vi.fn()
|
||||
|
||||
vi.mock('../composables/useModelDownloadActions', () => ({
|
||||
useModelDownloadActions: () => ({
|
||||
pause: mockPause,
|
||||
resume: mockResume,
|
||||
cancel: mockCancel,
|
||||
raisePriority: mockRaisePriority,
|
||||
remove: mockRemove,
|
||||
toastError: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages },
|
||||
missingWarn: false,
|
||||
fallbackWarn: false
|
||||
})
|
||||
|
||||
function createDownload(
|
||||
overrides: Partial<DownloadStatus> = {}
|
||||
): DownloadStatus {
|
||||
return {
|
||||
download_id: 'd1',
|
||||
model_id: 'loras/x.safetensors',
|
||||
url: 'https://huggingface.co/org/x.safetensors',
|
||||
status: 'active',
|
||||
priority: 0,
|
||||
total_bytes: 2048,
|
||||
bytes_done: 1024,
|
||||
progress: 0.5,
|
||||
speed_bps: 512,
|
||||
eta_seconds: 125,
|
||||
segments: null,
|
||||
error: null,
|
||||
created_at: 0,
|
||||
updated_at: 0,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function mountRow(
|
||||
download: DownloadStatus,
|
||||
onOpenCredentials?: (host: string) => void
|
||||
) {
|
||||
return render(ModelDownloadRow, {
|
||||
props: {
|
||||
download,
|
||||
...(onOpenCredentials ? { onOpenCredentials } : {})
|
||||
},
|
||||
global: { plugins: [i18n] }
|
||||
})
|
||||
}
|
||||
|
||||
describe('ModelDownloadRow', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('splits the model id into directory and filename', () => {
|
||||
mountRow(createDownload({ model_id: 'loras/x.safetensors' }))
|
||||
|
||||
expect(screen.getByText('x.safetensors')).toBeInTheDocument()
|
||||
expect(screen.getByText('loras')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders an empty directory when the model id has no folder', () => {
|
||||
mountRow(createDownload({ model_id: 'x.safetensors' }))
|
||||
|
||||
expect(screen.getByText('x.safetensors')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('formats the meta line with percent, size, speed, and eta while active', () => {
|
||||
mountRow(createDownload({ status: 'active' }))
|
||||
|
||||
expect(
|
||||
screen.getByText('50% · 1 KB / 2 KB · 512 B/s · 2:05')
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders an empty meta line when no progress metrics are known yet', () => {
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'queued',
|
||||
progress: null,
|
||||
total_bytes: null,
|
||||
bytes_done: 0,
|
||||
speed_bps: null,
|
||||
eta_seconds: null
|
||||
})
|
||||
)
|
||||
|
||||
expect(screen.getByTestId('meta-line')).toBeEmptyDOMElement()
|
||||
})
|
||||
|
||||
it('omits the eta once a download is no longer active', () => {
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'paused',
|
||||
progress: 0.5,
|
||||
eta_seconds: 125
|
||||
})
|
||||
)
|
||||
|
||||
expect(screen.getByText('50% · 1 KB / 2 KB · 512 B/s')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
describe('progress bar visibility', () => {
|
||||
it('shows a progress bar for an active download with known progress', () => {
|
||||
mountRow(createDownload({ status: 'active', progress: 0.5 }))
|
||||
|
||||
expect(screen.getByTestId('progress-bar')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('hides the progress bar for a cancelled download', () => {
|
||||
mountRow(createDownload({ status: 'cancelled', progress: 0.5 }))
|
||||
|
||||
expect(screen.queryByTestId('progress-bar')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('Cancelled')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('action buttons by status', () => {
|
||||
it('shows pause and cancel for a queued download, plus raise priority', () => {
|
||||
mountRow(createDownload({ status: 'queued' }))
|
||||
|
||||
expect(screen.getByTitle('Raise priority')).toBeInTheDocument()
|
||||
expect(screen.getByTitle('Pause')).toBeInTheDocument()
|
||||
expect(screen.getByTitle('Cancel')).toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Resume')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Remove from list')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows only resume for a failed download without an auth error', () => {
|
||||
mountRow(createDownload({ status: 'failed', error: 'disk full' }))
|
||||
|
||||
expect(screen.getByTitle('Resume')).toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Pause')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Cancel')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Remove from list')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Add credentials')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows the remove action for terminal downloads', () => {
|
||||
mountRow(createDownload({ status: 'completed' }))
|
||||
|
||||
expect(screen.getByTitle('Remove from list')).toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Cancel')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Resume')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('auth errors', () => {
|
||||
it('shows the credentials button and a host-specific hint', async () => {
|
||||
const onOpenCredentials = vi.fn()
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'failed',
|
||||
url: 'https://huggingface.co/org/x.safetensors',
|
||||
error: '401 Unauthorized'
|
||||
}),
|
||||
onOpenCredentials
|
||||
)
|
||||
|
||||
expect(
|
||||
screen.getByText(
|
||||
'huggingface.co needs an API key. Add one in the Credentials Manager, then resume.'
|
||||
)
|
||||
).toBeInTheDocument()
|
||||
|
||||
await userEvent.click(screen.getByTitle('Add credentials'))
|
||||
expect(onOpenCredentials).toHaveBeenCalledWith('huggingface.co')
|
||||
})
|
||||
|
||||
it('falls back to a hostless hint when the url cannot be parsed', () => {
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'failed',
|
||||
url: 'not-a-url',
|
||||
error: '403 forbidden'
|
||||
})
|
||||
)
|
||||
|
||||
expect(
|
||||
screen.getByText(
|
||||
'This host/model needs an API key. Add one in the Credentials Manager, then resume.'
|
||||
)
|
||||
).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows the raw error message for a non-auth failure', () => {
|
||||
mountRow(createDownload({ status: 'failed', error: 'disk full' }))
|
||||
|
||||
expect(screen.getByText('disk full')).toBeInTheDocument()
|
||||
expect(screen.queryByTitle('Add credentials')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('hides a leftover error while the download is not yet terminal', () => {
|
||||
mountRow(createDownload({ status: 'active', error: 'disk full' }))
|
||||
|
||||
expect(screen.queryByText('disk full')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('gated models', () => {
|
||||
const gatedError =
|
||||
'https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/vae/diffusion_pytorch_model.safetensors is a gated model — Access to model black-forest-labs/FLUX.2-dev is restricted. You must have access to it and be authenticated to access it.'
|
||||
|
||||
it('shows the gated hint, credentials button, and a link to accept the license on the model page', () => {
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'failed',
|
||||
url: 'https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/vae/diffusion_pytorch_model.safetensors',
|
||||
error: gatedError
|
||||
})
|
||||
)
|
||||
|
||||
expect(
|
||||
screen.getByText(
|
||||
"This model is gated. Accept its license on the model's page, then add an API key and resume."
|
||||
)
|
||||
).toBeInTheDocument()
|
||||
expect(screen.getByTitle('Add credentials')).toBeInTheDocument()
|
||||
|
||||
const link = screen.getByRole('link', { name: 'Accept license' })
|
||||
expect(link).toHaveAttribute(
|
||||
'href',
|
||||
'https://huggingface.co/black-forest-labs/FLUX.2-dev'
|
||||
)
|
||||
})
|
||||
|
||||
it('derives the model page from the error when the download url is a cdn link', () => {
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'failed',
|
||||
url: 'https://cas-bridge.xethub.hf.co/xet-bridge-us/abc/def',
|
||||
error: gatedError
|
||||
})
|
||||
)
|
||||
|
||||
expect(
|
||||
screen.getByRole('link', { name: 'Accept license' })
|
||||
).toHaveAttribute(
|
||||
'href',
|
||||
'https://huggingface.co/black-forest-labs/FLUX.2-dev'
|
||||
)
|
||||
})
|
||||
|
||||
it('hides the raw backend error text for gated failures', () => {
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'failed',
|
||||
url: 'https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/model.safetensors',
|
||||
error: gatedError
|
||||
})
|
||||
)
|
||||
|
||||
expect(screen.queryByText(gatedError)).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('omits the accept-license link when no huggingface url is present', () => {
|
||||
mountRow(
|
||||
createDownload({
|
||||
status: 'failed',
|
||||
url: 'https://example.com/model.safetensors',
|
||||
error: 'Access to this model is restricted, request access.'
|
||||
})
|
||||
)
|
||||
|
||||
expect(
|
||||
screen.queryByRole('link', { name: 'Accept license' })
|
||||
).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('user actions', () => {
|
||||
it('pauses on click', async () => {
|
||||
const download = createDownload({ status: 'active' })
|
||||
mountRow(download)
|
||||
|
||||
await userEvent.click(screen.getByTitle('Pause'))
|
||||
expect(mockPause).toHaveBeenCalledWith(download)
|
||||
})
|
||||
|
||||
it('resumes on click', async () => {
|
||||
const download = createDownload({ status: 'paused' })
|
||||
mountRow(download)
|
||||
|
||||
await userEvent.click(screen.getByTitle('Resume'))
|
||||
expect(mockResume).toHaveBeenCalledWith(download)
|
||||
})
|
||||
|
||||
it('cancels on click', async () => {
|
||||
const download = createDownload({ status: 'active' })
|
||||
mountRow(download)
|
||||
|
||||
await userEvent.click(screen.getByTitle('Cancel'))
|
||||
expect(mockCancel).toHaveBeenCalledWith(download)
|
||||
})
|
||||
|
||||
it('raises priority by 1 on click', async () => {
|
||||
const download = createDownload({ status: 'queued', priority: 2 })
|
||||
mountRow(download)
|
||||
|
||||
await userEvent.click(screen.getByTitle('Raise priority'))
|
||||
expect(mockRaisePriority).toHaveBeenCalledWith(download, 1)
|
||||
})
|
||||
|
||||
it('removes the download on click', async () => {
|
||||
const download = createDownload({
|
||||
download_id: 'd1',
|
||||
status: 'completed'
|
||||
})
|
||||
mountRow(download)
|
||||
|
||||
await userEvent.click(screen.getByTitle('Remove from list'))
|
||||
expect(mockRemove).toHaveBeenCalledWith(download)
|
||||
})
|
||||
})
|
||||
})
|
||||
245
src/platform/modelManager/components/ModelDownloadRow.vue
Normal file
245
src/platform/modelManager/components/ModelDownloadRow.vue
Normal file
@@ -0,0 +1,245 @@
|
||||
<template>
|
||||
<div
|
||||
:class="
|
||||
cn(
|
||||
'relative flex flex-col gap-1 overflow-hidden rounded-lg border border-border-default bg-secondary-background px-3 py-2',
|
||||
isCancelled && 'opacity-60'
|
||||
)
|
||||
"
|
||||
>
|
||||
<div
|
||||
v-if="showProgressBar"
|
||||
:class="progressBarContainerClass"
|
||||
data-testid="progress-bar"
|
||||
>
|
||||
<div :class="progressBarPrimaryClass" :style="barStyle" />
|
||||
</div>
|
||||
|
||||
<div class="relative flex items-center gap-2">
|
||||
<div class="flex min-w-0 flex-1 flex-col">
|
||||
<span class="truncate text-sm font-medium text-base-foreground">
|
||||
{{ filename }}
|
||||
</span>
|
||||
<span class="truncate text-xs text-muted-foreground">
|
||||
{{ directory }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div class="flex shrink-0 items-center gap-0.5">
|
||||
<template v-if="canRaisePriority">
|
||||
<Button
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('modelManager.raisePriority')"
|
||||
@click="actions.raisePriority(download, 1)"
|
||||
>
|
||||
<i class="icon-[lucide--chevron-up] size-4" />
|
||||
</Button>
|
||||
</template>
|
||||
<Button
|
||||
v-if="canPause"
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('g.pause')"
|
||||
@click="actions.pause(download)"
|
||||
>
|
||||
<i class="icon-[lucide--pause] size-4" />
|
||||
</Button>
|
||||
<Button
|
||||
v-if="isAuthError"
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('modelManager.addCredentials')"
|
||||
@click="emit('openCredentials', host)"
|
||||
>
|
||||
<i class="icon-[lucide--key-round] size-4" />
|
||||
</Button>
|
||||
<Button
|
||||
v-if="canResume"
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('modelManager.resume')"
|
||||
@click="actions.resume(download)"
|
||||
>
|
||||
<i class="icon-[lucide--play] size-4" />
|
||||
</Button>
|
||||
<Button
|
||||
v-if="canCancel"
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('g.cancel')"
|
||||
@click="actions.cancel(download)"
|
||||
>
|
||||
<i class="icon-[lucide--x] size-4 text-red-400" />
|
||||
</Button>
|
||||
<Button
|
||||
v-if="isTerminal"
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('modelManager.removeFromList')"
|
||||
@click="actions.remove(download)"
|
||||
>
|
||||
<i class="icon-[lucide--x] size-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
class="relative flex items-center justify-between gap-2 text-xs text-muted-foreground"
|
||||
>
|
||||
<span class="flex items-center gap-1">
|
||||
<i v-if="isCancelled" class="icon-[lucide--ban] size-3.5" />
|
||||
{{ statusLabel }}
|
||||
</span>
|
||||
<span class="truncate" data-testid="meta-line">{{ metaLine }}</span>
|
||||
</div>
|
||||
|
||||
<p
|
||||
v-if="isGatedModel"
|
||||
class="relative text-xs wrap-break-word text-amber-400"
|
||||
>
|
||||
{{ $t('modelManager.gatedModelHint') }}
|
||||
<a
|
||||
v-if="modelPageUrl"
|
||||
:href="modelPageUrl"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
class="underline"
|
||||
>
|
||||
{{ $t('modelManager.openModelPage') }}
|
||||
</a>
|
||||
</p>
|
||||
<p
|
||||
v-else-if="isAuthError"
|
||||
class="relative text-xs wrap-break-word text-amber-400"
|
||||
>
|
||||
{{
|
||||
host
|
||||
? $t('modelManager.authErrorHint', { host })
|
||||
: $t('modelManager.authErrorHintNoHost')
|
||||
}}
|
||||
</p>
|
||||
<p
|
||||
v-else-if="isFailed && download.error"
|
||||
class="relative text-xs wrap-break-word text-red-400"
|
||||
>
|
||||
{{ download.error }}
|
||||
</p>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
import { computed } from 'vue'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { useProgressBarBackground } from '@/composables/useProgressBarBackground'
|
||||
import { formatSize } from '@/utils/formatUtil'
|
||||
|
||||
import { useModelDownloadActions } from '../composables/useModelDownloadActions'
|
||||
import { downloadProgressFraction } from '../stores/modelDownloadStore'
|
||||
import type { DownloadStatus } from '../types'
|
||||
import { directoryOf, filenameOf } from '../utils/modelId'
|
||||
|
||||
const { download } = defineProps<{ download: DownloadStatus }>()
|
||||
|
||||
const emit = defineEmits<{ openCredentials: [host: string] }>()
|
||||
|
||||
const { t } = useI18n()
|
||||
const actions = useModelDownloadActions()
|
||||
const {
|
||||
progressBarContainerClass,
|
||||
progressBarPrimaryClass,
|
||||
progressPercentStyle
|
||||
} = useProgressBarBackground()
|
||||
|
||||
const directory = computed(() => directoryOf(download.model_id))
|
||||
const filename = computed(() => filenameOf(download.model_id))
|
||||
|
||||
const host = computed(() => {
|
||||
try {
|
||||
return new URL(download.url).hostname
|
||||
} catch {
|
||||
return ''
|
||||
}
|
||||
})
|
||||
|
||||
const percent = computed(() => {
|
||||
const fraction = downloadProgressFraction(download)
|
||||
return fraction == null ? undefined : Math.round(fraction * 100)
|
||||
})
|
||||
|
||||
const barStyle = computed(() => progressPercentStyle(percent.value))
|
||||
|
||||
const isTerminal = computed(() =>
|
||||
['completed', 'cancelled'].includes(download.status)
|
||||
)
|
||||
const isCancelled = computed(() => download.status === 'cancelled')
|
||||
const showProgressBar = computed(
|
||||
() =>
|
||||
download.status !== 'failed' &&
|
||||
!isCancelled.value &&
|
||||
percent.value !== undefined
|
||||
)
|
||||
const canPause = computed(() => ['queued', 'active'].includes(download.status))
|
||||
const canResume = computed(() => ['paused', 'failed'].includes(download.status))
|
||||
const canCancel = computed(
|
||||
() => !isTerminal.value && download.status !== 'failed'
|
||||
)
|
||||
const canRaisePriority = computed(() => download.status === 'queued')
|
||||
|
||||
const GATED_MODEL_PATTERN = /gated|restricted|request access|must have access/i
|
||||
const AUTH_ERROR_PATTERN =
|
||||
/api key|credential|token|unauthor|forbidden|\b401\b|\b403\b/i
|
||||
|
||||
const isFailed = computed(() => download.status === 'failed')
|
||||
const isGatedModel = computed(
|
||||
() => isFailed.value && !!download.error?.match(GATED_MODEL_PATTERN)
|
||||
)
|
||||
const isAuthError = computed(
|
||||
() =>
|
||||
isFailed.value &&
|
||||
(isGatedModel.value || !!download.error?.match(AUTH_ERROR_PATTERN))
|
||||
)
|
||||
|
||||
const HF_MODEL_URL_PATTERN =
|
||||
/https?:\/\/huggingface\.co\/([^/\s?#]+)\/([^/\s?#]+)/i
|
||||
const modelPageUrl = computed(() => {
|
||||
const match = `${download.url} ${download.error ?? ''}`.match(
|
||||
HF_MODEL_URL_PATTERN
|
||||
)
|
||||
if (!match) return ''
|
||||
const [, owner, repo] = match
|
||||
return `https://huggingface.co/${owner}/${repo}`
|
||||
})
|
||||
|
||||
const statusLabel = computed(() => t(`modelManager.status.${download.status}`))
|
||||
|
||||
const metaLine = computed(() => {
|
||||
const parts: string[] = []
|
||||
if (percent.value !== undefined) parts.push(`${percent.value}%`)
|
||||
if (download.total_bytes != null) {
|
||||
parts.push(
|
||||
`${formatSize(download.bytes_done)} / ${formatSize(download.total_bytes)}`
|
||||
)
|
||||
}
|
||||
if (download.speed_bps) parts.push(`${formatSize(download.speed_bps)}/s`)
|
||||
if (download.eta_seconds != null && download.status === 'active') {
|
||||
parts.push(formatEta(download.eta_seconds))
|
||||
}
|
||||
return parts.join(' · ')
|
||||
})
|
||||
|
||||
function formatEta(seconds: number): string {
|
||||
const total = Math.max(0, Math.round(seconds))
|
||||
const hours = Math.floor(total / 3600)
|
||||
const minutes = Math.floor((total % 3600) / 60)
|
||||
const secs = total % 60
|
||||
const paddedSecs = secs.toString().padStart(2, '0')
|
||||
if (hours > 0) {
|
||||
return `${hours}:${minutes.toString().padStart(2, '0')}:${paddedSecs}`
|
||||
}
|
||||
return `${minutes}:${paddedSecs}`
|
||||
}
|
||||
</script>
|
||||
@@ -0,0 +1,178 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { reactive, ref } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import enMessages from '@/locales/en/main.json' with { type: 'json' }
|
||||
|
||||
import type { DownloadStatus } from '../types'
|
||||
import ModelManagerSidebarTab from './ModelManagerSidebarTab.vue'
|
||||
|
||||
const mockHydrate = vi.fn()
|
||||
const mockClearHistory = vi.fn()
|
||||
|
||||
const mockStore = reactive({
|
||||
downloadList: ref<DownloadStatus[]>([]),
|
||||
activeDownloads: ref<DownloadStatus[]>([]),
|
||||
historyDownloads: ref<DownloadStatus[]>([]),
|
||||
hydrate: mockHydrate
|
||||
})
|
||||
|
||||
vi.mock('../stores/modelDownloadStore', () => ({
|
||||
useModelDownloadStore: () => mockStore
|
||||
}))
|
||||
|
||||
vi.mock('../composables/useModelDownloadActions', () => ({
|
||||
useModelDownloadActions: () => ({ clearHistory: mockClearHistory })
|
||||
}))
|
||||
|
||||
vi.mock('./ModelDownloadRow.vue', () => ({
|
||||
default: {
|
||||
name: 'ModelDownloadRow',
|
||||
props: ['download'],
|
||||
emits: ['openCredentials'],
|
||||
template:
|
||||
'<div><span>{{ download.download_id }}</span>' +
|
||||
'<button @click="$emit(\'openCredentials\', download.model_id)">open-credentials</button></div>'
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('./AddModelByUrlDialog.vue', () => ({
|
||||
default: {
|
||||
name: 'AddModelByUrlDialog',
|
||||
props: ['open'],
|
||||
template: '<div data-testid="add-model-dialog">{{ open }}</div>'
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('./HostCredentialsDialog.vue', () => ({
|
||||
default: {
|
||||
name: 'HostCredentialsDialog',
|
||||
props: ['open', 'prefillHost'],
|
||||
template:
|
||||
'<div data-testid="credentials-dialog">{{ open }}:{{ prefillHost }}</div>'
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/components/sidebar/tabs/SidebarTabTemplate.vue', () => ({
|
||||
default: {
|
||||
name: 'SidebarTabTemplate',
|
||||
template: '<div><slot name="tool-buttons" /><slot name="body" /></div>'
|
||||
}
|
||||
}))
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: { en: enMessages },
|
||||
missingWarn: false,
|
||||
fallbackWarn: false
|
||||
})
|
||||
|
||||
function createDownload(
|
||||
overrides: Partial<DownloadStatus> = {}
|
||||
): DownloadStatus {
|
||||
return {
|
||||
download_id: 'd1',
|
||||
model_id: 'loras/x.safetensors',
|
||||
url: 'https://huggingface.co/org/x.safetensors',
|
||||
status: 'active',
|
||||
priority: 0,
|
||||
total_bytes: null,
|
||||
bytes_done: 0,
|
||||
progress: null,
|
||||
speed_bps: null,
|
||||
eta_seconds: null,
|
||||
segments: null,
|
||||
error: null,
|
||||
created_at: 0,
|
||||
updated_at: 0,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function mountTab() {
|
||||
return render(ModelManagerSidebarTab, { global: { plugins: [i18n] } })
|
||||
}
|
||||
|
||||
describe('ModelManagerSidebarTab', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockStore.downloadList = []
|
||||
mockStore.activeDownloads = []
|
||||
mockStore.historyDownloads = []
|
||||
})
|
||||
|
||||
it('hydrates the store on mount', () => {
|
||||
mountTab()
|
||||
expect(mockHydrate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows the empty state when there are no downloads', () => {
|
||||
mountTab()
|
||||
expect(screen.getByText('No downloads yet')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders active downloads under the Active section', () => {
|
||||
const download = createDownload({ download_id: 'd1' })
|
||||
mockStore.activeDownloads = [download]
|
||||
mockStore.downloadList = [download]
|
||||
mountTab()
|
||||
|
||||
expect(screen.getByText('Active')).toBeInTheDocument()
|
||||
expect(screen.getByText('d1')).toBeInTheDocument()
|
||||
expect(screen.queryByText('No downloads yet')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders history downloads under the History section with a clear action', async () => {
|
||||
const download = createDownload({ download_id: 'd2' })
|
||||
mockStore.historyDownloads = [download]
|
||||
mockStore.downloadList = [download]
|
||||
mountTab()
|
||||
|
||||
expect(screen.getByText('History')).toBeInTheDocument()
|
||||
expect(screen.getByText('d2')).toBeInTheDocument()
|
||||
|
||||
await userEvent.click(screen.getByText('Clear history'))
|
||||
expect(mockClearHistory).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('opens the add-model dialog from the toolbar button', async () => {
|
||||
mountTab()
|
||||
|
||||
expect(screen.getByTestId('add-model-dialog')).toHaveTextContent('false')
|
||||
await userEvent.click(screen.getByTitle('Add model'))
|
||||
expect(screen.getByTestId('add-model-dialog')).toHaveTextContent('true')
|
||||
})
|
||||
|
||||
it('opens the add-model dialog from the empty state button', async () => {
|
||||
mountTab()
|
||||
|
||||
await userEvent.click(screen.getByText('Add model'))
|
||||
expect(screen.getByTestId('add-model-dialog')).toHaveTextContent('true')
|
||||
})
|
||||
|
||||
it('opens the credentials dialog with an empty prefill from the toolbar button', async () => {
|
||||
mountTab()
|
||||
|
||||
await userEvent.click(screen.getByTitle('Credentials Manager'))
|
||||
expect(screen.getByTestId('credentials-dialog')).toHaveTextContent('true:')
|
||||
})
|
||||
|
||||
it('opens the credentials dialog prefilled with the row host', async () => {
|
||||
const download = createDownload({
|
||||
download_id: 'd1',
|
||||
model_id: 'loras/x.safetensors'
|
||||
})
|
||||
mockStore.activeDownloads = [download]
|
||||
mockStore.downloadList = [download]
|
||||
mountTab()
|
||||
|
||||
await userEvent.click(screen.getByText('open-credentials'))
|
||||
|
||||
expect(screen.getByTestId('credentials-dialog')).toHaveTextContent(
|
||||
'true:loras/x.safetensors'
|
||||
)
|
||||
})
|
||||
})
|
||||
103
src/platform/modelManager/components/ModelManagerSidebarTab.vue
Normal file
103
src/platform/modelManager/components/ModelManagerSidebarTab.vue
Normal file
@@ -0,0 +1,103 @@
|
||||
<template>
|
||||
<SidebarTabTemplate :title="$t('modelManager.title')">
|
||||
<template #tool-buttons>
|
||||
<Button
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('modelManager.credentials.title')"
|
||||
@click="openCredentials('')"
|
||||
>
|
||||
<i class="icon-[lucide--key-round] size-4" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="textonly"
|
||||
size="icon"
|
||||
:title="$t('modelManager.addModel')"
|
||||
@click="addOpen = true"
|
||||
>
|
||||
<i class="icon-[lucide--plus] size-4" />
|
||||
</Button>
|
||||
</template>
|
||||
|
||||
<template #body>
|
||||
<div class="flex flex-col gap-4 p-3">
|
||||
<section v-if="activeDownloads.length" class="flex flex-col gap-2">
|
||||
<h3 class="text-xs font-semibold text-muted-foreground uppercase">
|
||||
{{ $t('modelManager.active') }}
|
||||
</h3>
|
||||
<ModelDownloadRow
|
||||
v-for="download in activeDownloads"
|
||||
:key="download.download_id"
|
||||
:download
|
||||
@open-credentials="openCredentials"
|
||||
/>
|
||||
</section>
|
||||
|
||||
<section v-if="historyDownloads.length" class="flex flex-col gap-2">
|
||||
<div class="flex items-center justify-between">
|
||||
<h3 class="text-xs font-semibold text-muted-foreground uppercase">
|
||||
{{ $t('modelManager.history') }}
|
||||
</h3>
|
||||
<Button variant="link" size="sm" @click="actions.clearHistory()">
|
||||
{{ $t('modelManager.clearHistory') }}
|
||||
</Button>
|
||||
</div>
|
||||
<ModelDownloadRow
|
||||
v-for="download in historyDownloads"
|
||||
:key="download.download_id"
|
||||
:download
|
||||
@open-credentials="openCredentials"
|
||||
/>
|
||||
</section>
|
||||
|
||||
<div
|
||||
v-if="!store.downloadList.length"
|
||||
class="flex flex-col items-center gap-3 py-10 text-center text-muted-foreground"
|
||||
>
|
||||
<i class="icon-[lucide--download] size-8" />
|
||||
<p class="text-sm">{{ $t('modelManager.empty') }}</p>
|
||||
<Button variant="primary" size="sm" @click="addOpen = true">
|
||||
{{ $t('modelManager.addModel') }}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</SidebarTabTemplate>
|
||||
|
||||
<AddModelByUrlDialog v-model:open="addOpen" />
|
||||
<HostCredentialsDialog
|
||||
v-model:open="credentialsOpen"
|
||||
:prefill-host="prefillHost"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { storeToRefs } from 'pinia'
|
||||
import { onMounted, ref } from 'vue'
|
||||
|
||||
import SidebarTabTemplate from '@/components/sidebar/tabs/SidebarTabTemplate.vue'
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
|
||||
import AddModelByUrlDialog from './AddModelByUrlDialog.vue'
|
||||
import HostCredentialsDialog from './HostCredentialsDialog.vue'
|
||||
import ModelDownloadRow from './ModelDownloadRow.vue'
|
||||
import { useModelDownloadActions } from '../composables/useModelDownloadActions'
|
||||
import { useModelDownloadStore } from '../stores/modelDownloadStore'
|
||||
|
||||
const store = useModelDownloadStore()
|
||||
const actions = useModelDownloadActions()
|
||||
const { activeDownloads, historyDownloads } = storeToRefs(store)
|
||||
|
||||
const addOpen = ref(false)
|
||||
const credentialsOpen = ref(false)
|
||||
const prefillHost = ref('')
|
||||
|
||||
function openCredentials(host: string) {
|
||||
prefillHost.value = host
|
||||
credentialsOpen.value = true
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
void store.hydrate()
|
||||
})
|
||||
</script>
|
||||
@@ -0,0 +1,51 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import * as downloadApi from '../api/modelDownloadApi'
|
||||
import { useModelAvailability } from './useModelAvailability'
|
||||
|
||||
vi.mock('../api/modelDownloadApi', () => ({
|
||||
checkAvailability: vi.fn()
|
||||
}))
|
||||
|
||||
describe('useModelAvailability', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks()
|
||||
})
|
||||
|
||||
it('short-circuits without calling the API for an empty map', async () => {
|
||||
const { check, results } = useModelAvailability()
|
||||
|
||||
await check({})
|
||||
|
||||
expect(downloadApi.checkAvailability).not.toHaveBeenCalled()
|
||||
expect(results.value).toEqual({})
|
||||
})
|
||||
|
||||
it('stores the availability map from the response', async () => {
|
||||
vi.mocked(downloadApi.checkAvailability).mockResolvedValue({
|
||||
models: {
|
||||
'loras/x.safetensors': { state: 'available', url_allowed: true }
|
||||
}
|
||||
})
|
||||
const { check, results, isChecking } = useModelAvailability()
|
||||
|
||||
await check({ 'loras/x.safetensors': 'https://h.co/x' })
|
||||
|
||||
expect(results.value['loras/x.safetensors']).toEqual({
|
||||
state: 'available',
|
||||
url_allowed: true
|
||||
})
|
||||
expect(isChecking.value).toBe(false)
|
||||
})
|
||||
|
||||
it('captures errors and resets the checking flag', async () => {
|
||||
vi.mocked(downloadApi.checkAvailability).mockRejectedValue(
|
||||
new Error('boom')
|
||||
)
|
||||
const { check, error, isChecking } = useModelAvailability()
|
||||
|
||||
await expect(check({ 'loras/x.safetensors': 'u' })).rejects.toThrow('boom')
|
||||
expect(error.value).toBeInstanceOf(Error)
|
||||
expect(isChecking.value).toBe(false)
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,41 @@
|
||||
import { ref, shallowRef } from 'vue'
|
||||
|
||||
import { checkAvailability } from '../api/modelDownloadApi'
|
||||
import type { AvailabilityEntry } from '../types'
|
||||
|
||||
/**
|
||||
* Bulk-checks whether the models declared by a workflow are present,
|
||||
* downloading, or missing. Pass a `model_id -> source URL` map and badge
|
||||
* each model from the returned entries.
|
||||
*/
|
||||
export function useModelAvailability() {
|
||||
const results = shallowRef<Record<string, AvailabilityEntry>>({})
|
||||
const isChecking = ref(false)
|
||||
const error = ref<unknown>(null)
|
||||
|
||||
async function check(models: Record<string, string>) {
|
||||
if (Object.keys(models).length === 0) {
|
||||
results.value = {}
|
||||
return results.value
|
||||
}
|
||||
isChecking.value = true
|
||||
error.value = null
|
||||
try {
|
||||
const response = await checkAvailability(models)
|
||||
results.value = response.models
|
||||
return response.models
|
||||
} catch (e) {
|
||||
error.value = e
|
||||
throw e
|
||||
} finally {
|
||||
isChecking.value = false
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
results,
|
||||
isChecking,
|
||||
error,
|
||||
check
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
|
||||
|
||||
import type { DownloadStatus } from '../types'
|
||||
import { DownloadApiError } from '../types'
|
||||
import { useModelDownloadActions } from './useModelDownloadActions'
|
||||
|
||||
const mockStore = {
|
||||
pause: vi.fn(),
|
||||
resume: vi.fn(),
|
||||
cancel: vi.fn(),
|
||||
setPriority: vi.fn(),
|
||||
hydrate: vi.fn()
|
||||
}
|
||||
const mockToastAdd = vi.fn()
|
||||
const mockCloseDialog = vi.fn()
|
||||
|
||||
vi.mock('vue-i18n', () => ({
|
||||
useI18n: () => ({ t: (key: string) => key })
|
||||
}))
|
||||
|
||||
vi.mock('../stores/modelDownloadStore', () => ({
|
||||
useModelDownloadStore: () => mockStore
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: () => ({ add: mockToastAdd })
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/dialogStore', () => ({
|
||||
useDialogStore: () => ({ closeDialog: mockCloseDialog })
|
||||
}))
|
||||
|
||||
vi.mock('@/components/dialog/confirm/confirmDialog')
|
||||
|
||||
const mockShowConfirmDialog = vi.mocked(showConfirmDialog)
|
||||
|
||||
interface CapturedConfirmOptions {
|
||||
footerProps?: {
|
||||
confirmVariant?: string
|
||||
onConfirm?: () => void | Promise<void>
|
||||
onCancel?: () => void
|
||||
}
|
||||
}
|
||||
|
||||
function capturedOptions(): CapturedConfirmOptions {
|
||||
return mockShowConfirmDialog.mock
|
||||
.calls[0][0] as unknown as CapturedConfirmOptions
|
||||
}
|
||||
|
||||
function createDownload(
|
||||
overrides: Partial<DownloadStatus> = {}
|
||||
): DownloadStatus {
|
||||
return {
|
||||
download_id: 'd1',
|
||||
model_id: 'loras/x.safetensors',
|
||||
url: 'https://huggingface.co/x.safetensors',
|
||||
status: 'active',
|
||||
priority: 0,
|
||||
total_bytes: null,
|
||||
bytes_done: 0,
|
||||
progress: null,
|
||||
speed_bps: null,
|
||||
eta_seconds: null,
|
||||
segments: null,
|
||||
error: null,
|
||||
created_at: 0,
|
||||
updated_at: 0,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
describe('useModelDownloadActions', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockShowConfirmDialog.mockReturnValue(
|
||||
{} as ReturnType<typeof showConfirmDialog>
|
||||
)
|
||||
})
|
||||
|
||||
it('pauses a download by id', async () => {
|
||||
mockStore.pause.mockResolvedValue(undefined)
|
||||
const { pause } = useModelDownloadActions()
|
||||
|
||||
await pause(createDownload({ download_id: 'd1' }))
|
||||
|
||||
expect(mockStore.pause).toHaveBeenCalledWith('d1')
|
||||
})
|
||||
|
||||
it('resumes a download by id', async () => {
|
||||
mockStore.resume.mockResolvedValue(undefined)
|
||||
const { resume } = useModelDownloadActions()
|
||||
|
||||
await resume(createDownload({ download_id: 'd1' }))
|
||||
|
||||
expect(mockStore.resume).toHaveBeenCalledWith('d1')
|
||||
})
|
||||
|
||||
it('raises priority relative to the current value', async () => {
|
||||
mockStore.setPriority.mockResolvedValue(undefined)
|
||||
const { raisePriority } = useModelDownloadActions()
|
||||
|
||||
await raisePriority(createDownload({ download_id: 'd1', priority: 2 }), 1)
|
||||
|
||||
expect(mockStore.setPriority).toHaveBeenCalledWith('d1', 3)
|
||||
})
|
||||
|
||||
it('shows an error toast and re-hydrates the store when an action fails', async () => {
|
||||
mockStore.pause.mockRejectedValue(new Error('offline'))
|
||||
mockStore.hydrate.mockResolvedValue(undefined)
|
||||
const { pause } = useModelDownloadActions()
|
||||
|
||||
await pause(createDownload())
|
||||
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ severity: 'error', detail: 'offline' })
|
||||
)
|
||||
expect(mockStore.hydrate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('uses the DownloadApiError message in the toast', async () => {
|
||||
mockStore.pause.mockRejectedValue(
|
||||
new DownloadApiError('nope', 'URL_NOT_ALLOWED', 400)
|
||||
)
|
||||
mockStore.hydrate.mockResolvedValue(undefined)
|
||||
const { pause } = useModelDownloadActions()
|
||||
|
||||
await pause(createDownload())
|
||||
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ detail: 'nope' })
|
||||
)
|
||||
})
|
||||
|
||||
it('swallows a failed re-hydrate after an action error', async () => {
|
||||
mockStore.pause.mockRejectedValue(new Error('offline'))
|
||||
mockStore.hydrate.mockRejectedValue(new Error('still offline'))
|
||||
const { pause } = useModelDownloadActions()
|
||||
|
||||
await expect(pause(createDownload())).resolves.toBeUndefined()
|
||||
})
|
||||
|
||||
describe('cancel', () => {
|
||||
it('opens a destructive confirm dialog and only cancels on confirm', async () => {
|
||||
mockStore.cancel.mockResolvedValue(undefined)
|
||||
const { cancel } = useModelDownloadActions()
|
||||
|
||||
cancel(createDownload({ download_id: 'd1' }))
|
||||
expect(capturedOptions().footerProps?.confirmVariant).toBe('destructive')
|
||||
|
||||
await capturedOptions().footerProps?.onConfirm?.()
|
||||
|
||||
expect(mockStore.cancel).toHaveBeenCalledWith('d1')
|
||||
expect(mockCloseDialog).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('does not cancel when the confirm dialog is dismissed', () => {
|
||||
const { cancel } = useModelDownloadActions()
|
||||
|
||||
cancel(createDownload({ download_id: 'd1' }))
|
||||
capturedOptions().footerProps?.onCancel?.()
|
||||
|
||||
expect(mockStore.cancel).not.toHaveBeenCalled()
|
||||
expect(mockCloseDialog).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,97 @@
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
import { showConfirmDialog } from '@/components/dialog/confirm/confirmDialog'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
import { useDialogStore } from '@/stores/dialogStore'
|
||||
|
||||
import { useModelDownloadStore } from '../stores/modelDownloadStore'
|
||||
import type { DownloadStatus } from '../types'
|
||||
import { DownloadApiError } from '../types'
|
||||
|
||||
/**
|
||||
* Wraps download store mutations with user feedback: optimistic-friendly error
|
||||
* toasts and a confirmation prompt for the destructive cancel action (which
|
||||
* deletes the partial file).
|
||||
*/
|
||||
export function useModelDownloadActions() {
|
||||
const store = useModelDownloadStore()
|
||||
const dialogStore = useDialogStore()
|
||||
const { t } = useI18n()
|
||||
|
||||
function toastError(error: unknown) {
|
||||
const detail =
|
||||
error instanceof DownloadApiError || error instanceof Error
|
||||
? error.message
|
||||
: String(error)
|
||||
useToastStore().add({
|
||||
severity: 'error',
|
||||
summary: t('modelManager.actionFailed'),
|
||||
detail,
|
||||
life: 5000
|
||||
})
|
||||
}
|
||||
|
||||
async function run(action: () => Promise<void>) {
|
||||
try {
|
||||
await action()
|
||||
} catch (error) {
|
||||
toastError(error)
|
||||
// The store applies optimistic status/priority patches before the API
|
||||
// call; re-fetch the authoritative state so a failed mutation doesn't
|
||||
// leave the row stuck in the wrong local state.
|
||||
try {
|
||||
await store.hydrate()
|
||||
} catch {
|
||||
// Server unreachable; the next poll will reconcile.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const pause = (download: DownloadStatus) =>
|
||||
run(() => store.pause(download.download_id))
|
||||
|
||||
const resume = (download: DownloadStatus) =>
|
||||
run(() => store.resume(download.download_id))
|
||||
|
||||
const raisePriority = (download: DownloadStatus, delta: number) =>
|
||||
run(() =>
|
||||
store.setPriority(download.download_id, download.priority + delta)
|
||||
)
|
||||
|
||||
const remove = (download: DownloadStatus) =>
|
||||
run(() => store.remove(download.download_id))
|
||||
|
||||
const clearHistory = () => run(() => store.clearHistory())
|
||||
|
||||
function cancel(download: DownloadStatus) {
|
||||
const dialog = showConfirmDialog({
|
||||
headerProps: { title: t('modelManager.cancelConfirmTitle') },
|
||||
props: {
|
||||
promptText: t(
|
||||
'modelManager.cancelConfirmMessage',
|
||||
{ name: download.model_id },
|
||||
{ escapeParameter: false }
|
||||
)
|
||||
},
|
||||
footerProps: {
|
||||
confirmText: t('modelManager.cancelConfirm'),
|
||||
confirmVariant: 'destructive' as const,
|
||||
onCancel: () => dialogStore.closeDialog(dialog),
|
||||
onConfirm: async () => {
|
||||
dialogStore.closeDialog(dialog)
|
||||
await run(() => store.cancel(download.download_id))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
pause,
|
||||
resume,
|
||||
cancel,
|
||||
raisePriority,
|
||||
remove,
|
||||
clearHistory,
|
||||
toastError
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
import { ref } from 'vue'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { CompletedDownload } from '../stores/modelDownloadStore'
|
||||
import { useModelDownloadEffects } from './useModelDownloadEffects'
|
||||
|
||||
const mockLastCompletedDownload = ref<CompletedDownload | null>(null)
|
||||
const mockRefreshModelFolder = vi.fn()
|
||||
const mockRefreshMissingModels = vi.fn()
|
||||
|
||||
vi.mock('../stores/modelDownloadStore', () => ({
|
||||
useModelDownloadStore: () => ({
|
||||
get lastCompletedDownload() {
|
||||
return mockLastCompletedDownload.value
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/modelStore', () => ({
|
||||
useModelStore: () => ({ refreshModelFolder: mockRefreshModelFolder })
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/missingModel/missingModelStore', () => ({
|
||||
useMissingModelStore: () => ({
|
||||
refreshMissingModels: mockRefreshMissingModels
|
||||
})
|
||||
}))
|
||||
|
||||
describe('useModelDownloadEffects', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockLastCompletedDownload.value = null
|
||||
mockRefreshModelFolder.mockResolvedValue(undefined)
|
||||
})
|
||||
|
||||
it('does nothing until a download completes', () => {
|
||||
useModelDownloadEffects()
|
||||
|
||||
expect(mockRefreshModelFolder).not.toHaveBeenCalled()
|
||||
expect(mockRefreshMissingModels).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('refreshes the model folder and re-scans missing models on completion', async () => {
|
||||
useModelDownloadEffects()
|
||||
|
||||
mockLastCompletedDownload.value = {
|
||||
downloadId: 'd1',
|
||||
modelId: 'loras/x.safetensors',
|
||||
directory: 'loras',
|
||||
timestamp: 1
|
||||
}
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockRefreshModelFolder).toHaveBeenCalledWith('loras')
|
||||
})
|
||||
expect(mockRefreshMissingModels).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('skips the folder refresh when the directory is unknown', async () => {
|
||||
useModelDownloadEffects()
|
||||
|
||||
mockLastCompletedDownload.value = {
|
||||
downloadId: 'd1',
|
||||
modelId: 'unknown.safetensors',
|
||||
directory: '',
|
||||
timestamp: 1
|
||||
}
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockRefreshMissingModels).toHaveBeenCalled()
|
||||
})
|
||||
expect(mockRefreshModelFolder).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('still re-scans missing models when the folder refresh fails', async () => {
|
||||
const consoleWarn = vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
mockRefreshModelFolder.mockRejectedValue(new Error('boom'))
|
||||
useModelDownloadEffects()
|
||||
|
||||
mockLastCompletedDownload.value = {
|
||||
downloadId: 'd1',
|
||||
modelId: 'loras/x.safetensors',
|
||||
directory: 'loras',
|
||||
timestamp: 1
|
||||
}
|
||||
|
||||
await vi.waitFor(() => {
|
||||
expect(mockRefreshMissingModels).toHaveBeenCalled()
|
||||
})
|
||||
expect(consoleWarn).toHaveBeenCalled()
|
||||
consoleWarn.mockRestore()
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,45 @@
|
||||
import { watch } from 'vue'
|
||||
|
||||
import { useMissingModelStore } from '@/platform/missingModel/missingModelStore'
|
||||
import { useModelStore } from '@/stores/modelStore'
|
||||
|
||||
import { useModelDownloadStore } from '../stores/modelDownloadStore'
|
||||
|
||||
/**
|
||||
* Side effects that run when a server-side download finishes: refresh the
|
||||
* affected model folder so the file is recognized, and re-scan missing models
|
||||
* so node errors for the now-present model clear automatically.
|
||||
*
|
||||
* Mounted once at the app root so it runs regardless of whether the Model
|
||||
* Manager panel is open.
|
||||
*/
|
||||
export function useModelDownloadEffects() {
|
||||
const store = useModelDownloadStore()
|
||||
|
||||
watch(
|
||||
() => store.lastCompletedDownload,
|
||||
async (completed) => {
|
||||
if (!completed) return
|
||||
|
||||
if (completed.directory) {
|
||||
try {
|
||||
await useModelStore().refreshModelFolder(completed.directory)
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
'[ModelManager] Failed to refresh model folder after download',
|
||||
error
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
await useMissingModelStore().refreshMissingModels()
|
||||
} catch (error) {
|
||||
console.warn(
|
||||
'[ModelManager] Failed to re-scan missing models after download',
|
||||
error
|
||||
)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const mockActiveDownloadCount = { value: 0 }
|
||||
|
||||
vi.mock('../stores/modelDownloadStore', () => ({
|
||||
useModelDownloadStore: () => ({
|
||||
get activeDownloadCount() {
|
||||
return mockActiveDownloadCount.value
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('../components/ModelManagerSidebarTab.vue', () => ({
|
||||
default: { name: 'ModelManagerSidebarTab' }
|
||||
}))
|
||||
|
||||
import { useModelManagerSidebarTab } from './useModelManagerSidebarTab'
|
||||
|
||||
describe('useModelManagerSidebarTab', () => {
|
||||
it('returns the expected sidebar tab extension shape', () => {
|
||||
const tab = useModelManagerSidebarTab()
|
||||
|
||||
expect(tab.id).toBe('model-manager')
|
||||
expect(tab.type).toBe('vue')
|
||||
expect(tab.title).toBe('modelManager.title')
|
||||
expect(tab.tooltip).toBe('modelManager.title')
|
||||
expect(tab.label).toBe('modelManager.title')
|
||||
})
|
||||
|
||||
it('shows no badge when there are no active downloads', () => {
|
||||
mockActiveDownloadCount.value = 0
|
||||
const tab = useModelManagerSidebarTab()
|
||||
|
||||
expect(typeof tab.iconBadge).toBe('function')
|
||||
expect((tab.iconBadge as () => string | null)()).toBeNull()
|
||||
})
|
||||
|
||||
it('shows the active download count as a badge', () => {
|
||||
mockActiveDownloadCount.value = 3
|
||||
const tab = useModelManagerSidebarTab()
|
||||
|
||||
expect((tab.iconBadge as () => string | null)()).toBe('3')
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,22 @@
|
||||
import { markRaw } from 'vue'
|
||||
|
||||
import type { SidebarTabExtension } from '@/types/extensionTypes'
|
||||
|
||||
import ModelManagerSidebarTab from '../components/ModelManagerSidebarTab.vue'
|
||||
import { useModelDownloadStore } from '../stores/modelDownloadStore'
|
||||
|
||||
export function useModelManagerSidebarTab(): SidebarTabExtension {
|
||||
return {
|
||||
id: 'model-manager',
|
||||
icon: 'icon-[lucide--download]',
|
||||
title: 'modelManager.title',
|
||||
tooltip: 'modelManager.title',
|
||||
label: 'modelManager.title',
|
||||
component: markRaw(ModelManagerSidebarTab),
|
||||
type: 'vue',
|
||||
iconBadge: () => {
|
||||
const count = useModelDownloadStore().activeDownloadCount
|
||||
return count > 0 ? count.toString() : null
|
||||
}
|
||||
}
|
||||
}
|
||||
191
src/platform/modelManager/stores/hostCredentialsStore.test.ts
Normal file
191
src/platform/modelManager/stores/hostCredentialsStore.test.ts
Normal file
@@ -0,0 +1,191 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import * as downloadApi from '../api/modelDownloadApi'
|
||||
import type { HostCredentialUpsert, HostCredentialView } from '../types'
|
||||
import { useHostCredentialsStore } from './hostCredentialsStore'
|
||||
|
||||
vi.mock('../api/modelDownloadApi', () => ({
|
||||
listCredentials: vi.fn(),
|
||||
upsertCredential: vi.fn(),
|
||||
deleteCredential: vi.fn()
|
||||
}))
|
||||
|
||||
function createCredential(
|
||||
overrides: Partial<HostCredentialView> = {}
|
||||
): HostCredentialView {
|
||||
return {
|
||||
id: 'c1',
|
||||
host: 'huggingface.co',
|
||||
auth_scheme: 'bearer',
|
||||
header_name: null,
|
||||
query_param: null,
|
||||
label: null,
|
||||
match_subdomains: false,
|
||||
enabled: true,
|
||||
secret_last4: '1234',
|
||||
created_at: 0,
|
||||
updated_at: 0,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
describe('useHostCredentialsStore', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
vi.resetAllMocks()
|
||||
})
|
||||
|
||||
describe('fetchCredentials', () => {
|
||||
it('toggles isLoading and stores the result', async () => {
|
||||
const credential = createCredential()
|
||||
vi.mocked(downloadApi.listCredentials).mockResolvedValue([credential])
|
||||
const store = useHostCredentialsStore()
|
||||
|
||||
const promise = store.fetchCredentials()
|
||||
expect(store.isLoading).toBe(true)
|
||||
await promise
|
||||
|
||||
expect(store.isLoading).toBe(false)
|
||||
expect(store.credentials).toEqual([credential])
|
||||
})
|
||||
|
||||
it('clears isLoading even when the request fails', async () => {
|
||||
vi.mocked(downloadApi.listCredentials).mockRejectedValue(
|
||||
new Error('boom')
|
||||
)
|
||||
const store = useHostCredentialsStore()
|
||||
|
||||
await expect(store.fetchCredentials()).rejects.toThrow('boom')
|
||||
expect(store.isLoading).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('upsert', () => {
|
||||
it('normalizes the host and appends a new credential', async () => {
|
||||
const created = createCredential({ id: 'new', host: 'huggingface.co' })
|
||||
vi.mocked(downloadApi.upsertCredential).mockResolvedValue(created)
|
||||
const store = useHostCredentialsStore()
|
||||
const body: HostCredentialUpsert = {
|
||||
host: ' HuggingFace.co ',
|
||||
secret: 's3cret'
|
||||
}
|
||||
|
||||
const result = await store.upsert(body)
|
||||
|
||||
expect(downloadApi.upsertCredential).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ host: 'huggingface.co', secret: 's3cret' })
|
||||
)
|
||||
expect(result).toEqual(created)
|
||||
expect(store.credentials).toEqual([created])
|
||||
})
|
||||
|
||||
it('replaces an existing credential by id', async () => {
|
||||
const original = createCredential({ id: 'c1', label: 'Old' })
|
||||
const updated = createCredential({ id: 'c1', label: 'New' })
|
||||
vi.mocked(downloadApi.upsertCredential).mockResolvedValue(updated)
|
||||
const store = useHostCredentialsStore()
|
||||
store.credentials.push(original)
|
||||
|
||||
await store.upsert({ host: 'huggingface.co', secret: 's' })
|
||||
|
||||
expect(store.credentials).toEqual([updated])
|
||||
})
|
||||
})
|
||||
|
||||
describe('remove', () => {
|
||||
it('deletes and filters out the credential by id', async () => {
|
||||
vi.mocked(downloadApi.deleteCredential).mockResolvedValue(undefined)
|
||||
const store = useHostCredentialsStore()
|
||||
store.credentials.push(
|
||||
createCredential({ id: 'c1' }),
|
||||
createCredential({ id: 'c2' })
|
||||
)
|
||||
|
||||
await store.remove('c1')
|
||||
|
||||
expect(downloadApi.deleteCredential).toHaveBeenCalledWith('c1')
|
||||
expect(store.credentials.map((c) => c.id)).toEqual(['c2'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('enabledCredentialForHost', () => {
|
||||
it('matches an exact enabled host case-insensitively', () => {
|
||||
const store = useHostCredentialsStore()
|
||||
store.credentials.push(
|
||||
createCredential({ host: 'HuggingFace.co', enabled: true })
|
||||
)
|
||||
|
||||
expect(store.enabledCredentialForHost('huggingface.co')?.host).toBe(
|
||||
'HuggingFace.co'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns undefined for a disabled exact match', () => {
|
||||
const store = useHostCredentialsStore()
|
||||
store.credentials.push(
|
||||
createCredential({ host: 'huggingface.co', enabled: false })
|
||||
)
|
||||
|
||||
expect(store.enabledCredentialForHost('huggingface.co')).toBeUndefined()
|
||||
})
|
||||
|
||||
it('falls back to a subdomain match only when match_subdomains is set', () => {
|
||||
const store = useHostCredentialsStore()
|
||||
store.credentials.push(
|
||||
createCredential({
|
||||
host: 'huggingface.co',
|
||||
match_subdomains: true,
|
||||
enabled: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(store.enabledCredentialForHost('cdn.huggingface.co')?.host).toBe(
|
||||
'huggingface.co'
|
||||
)
|
||||
})
|
||||
|
||||
it('skips a disabled subdomain credential and returns a later enabled match', () => {
|
||||
const store = useHostCredentialsStore()
|
||||
store.credentials.push(
|
||||
createCredential({
|
||||
id: 'disabled',
|
||||
host: 'huggingface.co',
|
||||
match_subdomains: true,
|
||||
enabled: false
|
||||
}),
|
||||
createCredential({
|
||||
id: 'enabled',
|
||||
host: 'huggingface.co',
|
||||
match_subdomains: true,
|
||||
enabled: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(store.enabledCredentialForHost('cdn.huggingface.co')?.id).toBe(
|
||||
'enabled'
|
||||
)
|
||||
})
|
||||
|
||||
it('does not subdomain-match when match_subdomains is false', () => {
|
||||
const store = useHostCredentialsStore()
|
||||
store.credentials.push(
|
||||
createCredential({
|
||||
host: 'huggingface.co',
|
||||
match_subdomains: false,
|
||||
enabled: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(
|
||||
store.enabledCredentialForHost('cdn.huggingface.co')
|
||||
).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined when no host matches', () => {
|
||||
const store = useHostCredentialsStore()
|
||||
expect(store.enabledCredentialForHost('example.com')).toBeUndefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
75
src/platform/modelManager/stores/hostCredentialsStore.ts
Normal file
75
src/platform/modelManager/stores/hostCredentialsStore.ts
Normal file
@@ -0,0 +1,75 @@
|
||||
import { defineStore } from 'pinia'
|
||||
import { computed, ref } from 'vue'
|
||||
|
||||
import {
|
||||
deleteCredential,
|
||||
listCredentials,
|
||||
upsertCredential
|
||||
} from '../api/modelDownloadApi'
|
||||
import type { HostCredentialUpsert, HostCredentialView } from '../types'
|
||||
|
||||
function normalizeHost(host: string): string {
|
||||
return host.trim().toLowerCase()
|
||||
}
|
||||
|
||||
export const useHostCredentialsStore = defineStore('hostCredentials', () => {
|
||||
const credentials = ref<HostCredentialView[]>([])
|
||||
const isLoading = ref(false)
|
||||
|
||||
const credentialsByHost = computed(
|
||||
() => new Map(credentials.value.map((c) => [normalizeHost(c.host), c]))
|
||||
)
|
||||
|
||||
async function fetchCredentials() {
|
||||
isLoading.value = true
|
||||
try {
|
||||
credentials.value = await listCredentials()
|
||||
} finally {
|
||||
isLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function upsert(
|
||||
body: HostCredentialUpsert
|
||||
): Promise<HostCredentialView> {
|
||||
const view = await upsertCredential({
|
||||
...body,
|
||||
host: normalizeHost(body.host)
|
||||
})
|
||||
const index = credentials.value.findIndex((c) => c.id === view.id)
|
||||
credentials.value =
|
||||
index === -1
|
||||
? [...credentials.value, view]
|
||||
: credentials.value.map((c) => (c.id === view.id ? view : c))
|
||||
return view
|
||||
}
|
||||
|
||||
async function remove(id: string) {
|
||||
await deleteCredential(id)
|
||||
credentials.value = credentials.value.filter((c) => c.id !== id)
|
||||
}
|
||||
|
||||
function enabledCredentialForHost(
|
||||
host: string
|
||||
): HostCredentialView | undefined {
|
||||
const normalized = normalizeHost(host)
|
||||
const exact = credentialsByHost.value.get(normalized)
|
||||
if (exact) return exact.enabled ? exact : undefined
|
||||
|
||||
return credentials.value.find(
|
||||
(c) =>
|
||||
c.enabled &&
|
||||
c.match_subdomains &&
|
||||
normalized.endsWith(`.${normalizeHost(c.host)}`)
|
||||
)
|
||||
}
|
||||
|
||||
return {
|
||||
credentials,
|
||||
isLoading,
|
||||
fetchCredentials,
|
||||
upsert,
|
||||
remove,
|
||||
enabledCredentialForHost
|
||||
}
|
||||
})
|
||||
388
src/platform/modelManager/stores/modelDownloadStore.test.ts
Normal file
388
src/platform/modelManager/stores/modelDownloadStore.test.ts
Normal file
@@ -0,0 +1,388 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const flushPromises = () => vi.advanceTimersByTimeAsync(0)
|
||||
|
||||
import * as downloadApi from '../api/modelDownloadApi'
|
||||
import type { DownloadState, DownloadStatus } from '../types'
|
||||
import {
|
||||
downloadProgressFraction,
|
||||
useModelDownloadStore
|
||||
} from './modelDownloadStore'
|
||||
|
||||
type ProgressHandler = (e: CustomEvent<DownloadStatus>) => void
|
||||
|
||||
const eventHandler = vi.hoisted(() => {
|
||||
const state: { current: ProgressHandler | null } = { current: null }
|
||||
return state
|
||||
})
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
addEventListener: vi.fn((_event: string, handler: ProgressHandler) => {
|
||||
eventHandler.current = handler
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../api/modelDownloadApi', () => ({
|
||||
enqueueDownload: vi.fn(),
|
||||
listDownloads: vi.fn(),
|
||||
pauseDownload: vi.fn(),
|
||||
resumeDownload: vi.fn(),
|
||||
cancelDownload: vi.fn(),
|
||||
setDownloadPriority: vi.fn(),
|
||||
deleteDownload: vi.fn(),
|
||||
clearDownloads: vi.fn()
|
||||
}))
|
||||
|
||||
function createStatus(overrides: Partial<DownloadStatus> = {}): DownloadStatus {
|
||||
return {
|
||||
download_id: 'd1',
|
||||
model_id: 'loras/x.safetensors',
|
||||
url: 'https://huggingface.co/x.safetensors',
|
||||
status: 'active',
|
||||
priority: 0,
|
||||
total_bytes: 1000,
|
||||
bytes_done: 250,
|
||||
progress: 0.25,
|
||||
speed_bps: 100,
|
||||
eta_seconds: 10,
|
||||
segments: null,
|
||||
error: null,
|
||||
created_at: 1,
|
||||
updated_at: 2,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function dispatch(status: DownloadStatus) {
|
||||
if (!eventHandler.current) throw new Error('handler not registered')
|
||||
eventHandler.current(new CustomEvent('download_progress', { detail: status }))
|
||||
}
|
||||
|
||||
describe('useModelDownloadStore', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
vi.useFakeTimers({ shouldAdvanceTime: false })
|
||||
vi.resetAllMocks()
|
||||
eventHandler.current = null
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('upserts rows from download_progress events', () => {
|
||||
const store = useModelDownloadStore()
|
||||
|
||||
dispatch(createStatus({ download_id: 'd1', bytes_done: 100 }))
|
||||
dispatch(createStatus({ download_id: 'd1', bytes_done: 900 }))
|
||||
dispatch(createStatus({ download_id: 'd2', status: 'queued' }))
|
||||
|
||||
expect(store.downloadList).toHaveLength(2)
|
||||
expect(
|
||||
store.downloadList.find((d) => d.download_id === 'd1')?.bytes_done
|
||||
).toBe(900)
|
||||
})
|
||||
|
||||
it('splits active and terminal downloads', () => {
|
||||
const store = useModelDownloadStore()
|
||||
const states: DownloadState[] = [
|
||||
'queued',
|
||||
'active',
|
||||
'paused',
|
||||
'verifying',
|
||||
'completed',
|
||||
'failed',
|
||||
'cancelled'
|
||||
]
|
||||
states.forEach((status, idx) =>
|
||||
dispatch(createStatus({ download_id: `d${idx}`, status }))
|
||||
)
|
||||
|
||||
expect(store.activeDownloads.map((d) => d.status)).toEqual([
|
||||
'queued',
|
||||
'active',
|
||||
'paused',
|
||||
'verifying'
|
||||
])
|
||||
expect(store.historyDownloads.map((d) => d.status)).toEqual([
|
||||
'completed',
|
||||
'failed',
|
||||
'cancelled'
|
||||
])
|
||||
expect(store.activeDownloadCount).toBe(4)
|
||||
})
|
||||
|
||||
it('inserts an optimistic queued row on enqueue', async () => {
|
||||
vi.mocked(downloadApi.enqueueDownload).mockResolvedValue({
|
||||
download_id: 'new-id',
|
||||
accepted: true
|
||||
})
|
||||
const store = useModelDownloadStore()
|
||||
|
||||
const result = await store.enqueue({
|
||||
url: 'https://huggingface.co/x.safetensors',
|
||||
model_id: 'loras/x.safetensors'
|
||||
})
|
||||
|
||||
expect(result.download_id).toBe('new-id')
|
||||
const row = store.downloadList.find((d) => d.download_id === 'new-id')
|
||||
expect(row?.status).toBe('queued')
|
||||
expect(row?.model_id).toBe('loras/x.safetensors')
|
||||
})
|
||||
|
||||
it('optimistically updates status when pausing', async () => {
|
||||
vi.mocked(downloadApi.pauseDownload).mockResolvedValue()
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
|
||||
await store.pause('d1')
|
||||
|
||||
expect(store.downloadList.find((d) => d.download_id === 'd1')?.status).toBe(
|
||||
'paused'
|
||||
)
|
||||
expect(downloadApi.pauseDownload).toHaveBeenCalledWith('d1')
|
||||
})
|
||||
|
||||
it('optimistically updates status when resuming', async () => {
|
||||
vi.mocked(downloadApi.resumeDownload).mockResolvedValue()
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'paused' }))
|
||||
|
||||
await store.resume('d1')
|
||||
|
||||
expect(store.downloadList.find((d) => d.download_id === 'd1')?.status).toBe(
|
||||
'queued'
|
||||
)
|
||||
expect(downloadApi.resumeDownload).toHaveBeenCalledWith('d1')
|
||||
})
|
||||
|
||||
it('marks a download cancelled after the API call resolves', async () => {
|
||||
vi.mocked(downloadApi.cancelDownload).mockResolvedValue()
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
|
||||
await store.cancel('d1')
|
||||
|
||||
expect(store.downloadList.find((d) => d.download_id === 'd1')?.status).toBe(
|
||||
'cancelled'
|
||||
)
|
||||
expect(downloadApi.cancelDownload).toHaveBeenCalledWith('d1')
|
||||
})
|
||||
|
||||
it('optimistically updates priority and calls the API', async () => {
|
||||
vi.mocked(downloadApi.setDownloadPriority).mockResolvedValue()
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'queued', priority: 0 }))
|
||||
|
||||
await store.setPriority('d1', 5)
|
||||
|
||||
expect(
|
||||
store.downloadList.find((d) => d.download_id === 'd1')?.priority
|
||||
).toBe(5)
|
||||
expect(downloadApi.setDownloadPriority).toHaveBeenCalledWith('d1', 5)
|
||||
})
|
||||
|
||||
it('is a no-op when patching priority for an unknown id', async () => {
|
||||
vi.mocked(downloadApi.setDownloadPriority).mockResolvedValue()
|
||||
const store = useModelDownloadStore()
|
||||
|
||||
await store.setPriority('missing', 5)
|
||||
|
||||
expect(store.downloadList).toHaveLength(0)
|
||||
expect(downloadApi.setDownloadPriority).toHaveBeenCalledWith('missing', 5)
|
||||
})
|
||||
|
||||
it('finds a download by model id', () => {
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(
|
||||
createStatus({ download_id: 'd1', model_id: 'loras/x.safetensors' })
|
||||
)
|
||||
|
||||
expect(store.findByModelId('loras/x.safetensors')?.download_id).toBe('d1')
|
||||
expect(store.findByModelId('loras/missing.safetensors')).toBeUndefined()
|
||||
})
|
||||
|
||||
describe('hydrate', () => {
|
||||
it('replaces the download map with the fetched list', async () => {
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'stale', status: 'active' }))
|
||||
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
|
||||
createStatus({ download_id: 'fresh', status: 'active' })
|
||||
])
|
||||
|
||||
await store.hydrate()
|
||||
|
||||
expect(store.downloadList.map((d) => d.download_id)).toEqual(['fresh'])
|
||||
})
|
||||
|
||||
it('records a completion when a refreshed row transitions to completed', async () => {
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
|
||||
createStatus({
|
||||
download_id: 'd1',
|
||||
model_id: 'loras/x.safetensors',
|
||||
status: 'completed'
|
||||
})
|
||||
])
|
||||
|
||||
await store.hydrate()
|
||||
|
||||
expect(store.lastCompletedDownload).toMatchObject({
|
||||
downloadId: 'd1',
|
||||
modelId: 'loras/x.safetensors',
|
||||
directory: 'loras'
|
||||
})
|
||||
})
|
||||
|
||||
it('does not re-record a completion for a row that was already completed', async () => {
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'completed' }))
|
||||
const firstTimestamp = store.lastCompletedDownload?.timestamp
|
||||
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
|
||||
createStatus({ download_id: 'd1', status: 'completed' })
|
||||
])
|
||||
|
||||
await store.hydrate()
|
||||
|
||||
expect(store.lastCompletedDownload?.timestamp).toBe(firstTimestamp)
|
||||
})
|
||||
})
|
||||
|
||||
it('records the last completed download once on the completing transition', () => {
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
expect(store.lastCompletedDownload).toBeNull()
|
||||
|
||||
dispatch(
|
||||
createStatus({
|
||||
download_id: 'd1',
|
||||
model_id: 'loras/x.safetensors',
|
||||
status: 'completed'
|
||||
})
|
||||
)
|
||||
|
||||
expect(store.lastCompletedDownload).toMatchObject({
|
||||
downloadId: 'd1',
|
||||
modelId: 'loras/x.safetensors',
|
||||
directory: 'loras'
|
||||
})
|
||||
|
||||
const firstTimestamp = store.lastCompletedDownload?.timestamp
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'completed' }))
|
||||
expect(store.lastCompletedDownload?.timestamp).toBe(firstTimestamp)
|
||||
})
|
||||
|
||||
it('refetches when a download fails without an error message', async () => {
|
||||
vi.mocked(downloadApi.listDownloads).mockResolvedValue([
|
||||
createStatus({ download_id: 'd1', status: 'failed', error: 'gated' })
|
||||
])
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'failed', error: null }))
|
||||
await flushPromises()
|
||||
|
||||
expect(downloadApi.listDownloads).toHaveBeenCalledOnce()
|
||||
expect(store.downloadList.find((d) => d.download_id === 'd1')?.error).toBe(
|
||||
'gated'
|
||||
)
|
||||
})
|
||||
|
||||
it('does not refetch when the failure event already carries an error', async () => {
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
|
||||
dispatch(
|
||||
createStatus({ download_id: 'd1', status: 'failed', error: 'disk full' })
|
||||
)
|
||||
await flushPromises()
|
||||
|
||||
expect(downloadApi.listDownloads).not.toHaveBeenCalled()
|
||||
expect(store.downloadList.find((d) => d.download_id === 'd1')?.error).toBe(
|
||||
'disk full'
|
||||
)
|
||||
})
|
||||
|
||||
it('deletes a row through the backend so it stays gone', async () => {
|
||||
vi.mocked(downloadApi.deleteDownload).mockResolvedValue()
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
|
||||
|
||||
await store.remove('d2')
|
||||
|
||||
expect(downloadApi.deleteDownload).toHaveBeenCalledWith('d2')
|
||||
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d1'])
|
||||
})
|
||||
|
||||
it('clears every history row in one backend call, leaving active downloads', async () => {
|
||||
vi.mocked(downloadApi.clearDownloads).mockResolvedValue(2)
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
|
||||
dispatch(createStatus({ download_id: 'd3', status: 'failed' }))
|
||||
|
||||
await store.clearHistory()
|
||||
|
||||
expect(downloadApi.clearDownloads).toHaveBeenCalledOnce()
|
||||
expect(downloadApi.deleteDownload).not.toHaveBeenCalled()
|
||||
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d1'])
|
||||
})
|
||||
|
||||
it('keeps history rows locally when the bulk clear fails', async () => {
|
||||
vi.mocked(downloadApi.clearDownloads).mockRejectedValue(new Error('boom'))
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
|
||||
|
||||
await expect(store.clearHistory()).rejects.toThrow('boom')
|
||||
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d2'])
|
||||
})
|
||||
|
||||
it('keeps the row locally when the backend delete fails', async () => {
|
||||
vi.mocked(downloadApi.deleteDownload).mockRejectedValue(new Error('boom'))
|
||||
const store = useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd2', status: 'completed' }))
|
||||
|
||||
await expect(store.remove('d2')).rejects.toThrow('boom')
|
||||
expect(store.downloadList.map((d) => d.download_id)).toEqual(['d2'])
|
||||
})
|
||||
|
||||
it('polls the list when active downloads go stale', async () => {
|
||||
vi.mocked(downloadApi.listDownloads).mockResolvedValue([])
|
||||
useModelDownloadStore()
|
||||
dispatch(createStatus({ download_id: 'd1', status: 'active' }))
|
||||
|
||||
await vi.advanceTimersByTimeAsync(15_000)
|
||||
|
||||
expect(downloadApi.listDownloads).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
describe('downloadProgressFraction', () => {
|
||||
it('uses live progress when present', () => {
|
||||
expect(downloadProgressFraction(createStatus({ progress: 0.4 }))).toBe(
|
||||
0.4
|
||||
)
|
||||
})
|
||||
|
||||
it('derives from bytes when progress is null', () => {
|
||||
expect(
|
||||
downloadProgressFraction(
|
||||
createStatus({ progress: null, bytes_done: 500, total_bytes: 1000 })
|
||||
)
|
||||
).toBe(0.5)
|
||||
})
|
||||
|
||||
it('returns null when total is unknown', () => {
|
||||
expect(
|
||||
downloadProgressFraction(
|
||||
createStatus({ progress: null, total_bytes: null })
|
||||
)
|
||||
).toBeNull()
|
||||
})
|
||||
})
|
||||
})
|
||||
229
src/platform/modelManager/stores/modelDownloadStore.ts
Normal file
229
src/platform/modelManager/stores/modelDownloadStore.ts
Normal file
@@ -0,0 +1,229 @@
|
||||
import { useIntervalFn } from '@vueuse/core'
|
||||
import { defineStore } from 'pinia'
|
||||
import { computed, ref } from 'vue'
|
||||
|
||||
import { api } from '@/scripts/api'
|
||||
|
||||
import {
|
||||
cancelDownload,
|
||||
clearDownloads,
|
||||
deleteDownload,
|
||||
enqueueDownload,
|
||||
listDownloads,
|
||||
pauseDownload,
|
||||
resumeDownload,
|
||||
setDownloadPriority
|
||||
} from '../api/modelDownloadApi'
|
||||
import type {
|
||||
DownloadState,
|
||||
DownloadStatus,
|
||||
EnqueueRequest,
|
||||
EnqueueResponse
|
||||
} from '../types'
|
||||
import { directoryOf } from '../utils/modelId'
|
||||
|
||||
const ACTIVE_STATES: ReadonlySet<DownloadState> = new Set([
|
||||
'queued',
|
||||
'active',
|
||||
'paused',
|
||||
'verifying'
|
||||
])
|
||||
|
||||
const POLL_INTERVAL_MS = 5_000
|
||||
const STALE_THRESHOLD_MS = 8_000
|
||||
|
||||
/**
|
||||
* Static completion fraction (0..1) for a download. Live values (`progress`)
|
||||
* are only populated while a worker is running; for paused/terminal rows we
|
||||
* derive it from `bytes_done / total_bytes`.
|
||||
*/
|
||||
export function downloadProgressFraction(
|
||||
download: DownloadStatus
|
||||
): number | null {
|
||||
if (download.progress != null) return download.progress
|
||||
if (download.total_bytes && download.total_bytes > 0) {
|
||||
return download.bytes_done / download.total_bytes
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
function optimisticRow(
|
||||
downloadId: string,
|
||||
request: EnqueueRequest
|
||||
): DownloadStatus {
|
||||
const now = Math.floor(Date.now() / 1000)
|
||||
return {
|
||||
download_id: downloadId,
|
||||
model_id: request.model_id,
|
||||
url: request.url,
|
||||
status: 'queued',
|
||||
priority: request.priority ?? 0,
|
||||
total_bytes: null,
|
||||
bytes_done: 0,
|
||||
progress: null,
|
||||
speed_bps: null,
|
||||
eta_seconds: null,
|
||||
segments: null,
|
||||
error: null,
|
||||
created_at: now,
|
||||
updated_at: now
|
||||
}
|
||||
}
|
||||
|
||||
export interface CompletedDownload {
|
||||
downloadId: string
|
||||
modelId: string
|
||||
directory: string
|
||||
timestamp: number
|
||||
}
|
||||
|
||||
export const useModelDownloadStore = defineStore('modelDownload', () => {
|
||||
const downloads = ref<Map<string, DownloadStatus>>(new Map())
|
||||
const lastWsUpdate = ref(0)
|
||||
const lastCompletedDownload = ref<CompletedDownload | null>(null)
|
||||
|
||||
const downloadList = computed(() => Array.from(downloads.value.values()))
|
||||
const activeDownloads = computed(() =>
|
||||
downloadList.value.filter((d) => ACTIVE_STATES.has(d.status))
|
||||
)
|
||||
const historyDownloads = computed(() =>
|
||||
downloadList.value.filter((d) => !ACTIVE_STATES.has(d.status))
|
||||
)
|
||||
const hasActiveDownloads = computed(() => activeDownloads.value.length > 0)
|
||||
const activeDownloadCount = computed(() => activeDownloads.value.length)
|
||||
|
||||
function upsert(status: DownloadStatus) {
|
||||
downloads.value.set(status.download_id, status)
|
||||
}
|
||||
|
||||
function findByModelId(modelId: string): DownloadStatus | undefined {
|
||||
return downloadList.value.find((d) => d.model_id === modelId)
|
||||
}
|
||||
|
||||
function handleProgress(e: CustomEvent<DownloadStatus>) {
|
||||
lastWsUpdate.value = Date.now()
|
||||
const previous = downloads.value.get(e.detail.download_id)
|
||||
upsert(e.detail)
|
||||
if (e.detail.status === 'completed' && previous?.status !== 'completed') {
|
||||
lastCompletedDownload.value = {
|
||||
downloadId: e.detail.download_id,
|
||||
modelId: e.detail.model_id,
|
||||
directory: directoryOf(e.detail.model_id),
|
||||
timestamp: Date.now()
|
||||
}
|
||||
}
|
||||
// The live failure event carries the terminal status but not the error
|
||||
// text, and polling stops once nothing is active. Refetch so the failure
|
||||
// reason surfaces without the user reopening the panel.
|
||||
if (
|
||||
e.detail.status === 'failed' &&
|
||||
previous?.status !== 'failed' &&
|
||||
!e.detail.error
|
||||
) {
|
||||
void hydrate().catch(() => {})
|
||||
}
|
||||
}
|
||||
|
||||
async function hydrate() {
|
||||
const previous = downloads.value
|
||||
const list = await listDownloads()
|
||||
downloads.value = new Map(list.map((d) => [d.download_id, d]))
|
||||
for (const download of list) {
|
||||
const prior = previous.get(download.download_id)
|
||||
if (
|
||||
download.status === 'completed' &&
|
||||
prior &&
|
||||
prior.status !== 'completed'
|
||||
) {
|
||||
lastCompletedDownload.value = {
|
||||
downloadId: download.download_id,
|
||||
modelId: download.model_id,
|
||||
directory: directoryOf(download.model_id),
|
||||
timestamp: Date.now()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function enqueue(request: EnqueueRequest): Promise<EnqueueResponse> {
|
||||
const response = await enqueueDownload(request)
|
||||
upsert(optimisticRow(response.download_id, request))
|
||||
return response
|
||||
}
|
||||
|
||||
function patchStatus(id: string, status: DownloadState) {
|
||||
const existing = downloads.value.get(id)
|
||||
if (existing) {
|
||||
downloads.value.set(id, { ...existing, status })
|
||||
}
|
||||
}
|
||||
|
||||
async function pause(id: string) {
|
||||
patchStatus(id, 'paused')
|
||||
await pauseDownload(id)
|
||||
}
|
||||
|
||||
async function resume(id: string) {
|
||||
patchStatus(id, 'queued')
|
||||
await resumeDownload(id)
|
||||
}
|
||||
|
||||
async function cancel(id: string) {
|
||||
await cancelDownload(id)
|
||||
patchStatus(id, 'cancelled')
|
||||
}
|
||||
|
||||
async function setPriority(id: string, priority: number) {
|
||||
const existing = downloads.value.get(id)
|
||||
if (existing) {
|
||||
downloads.value.set(id, { ...existing, priority })
|
||||
}
|
||||
await setDownloadPriority(id, priority)
|
||||
}
|
||||
|
||||
async function remove(id: string) {
|
||||
await deleteDownload(id)
|
||||
downloads.value.delete(id)
|
||||
}
|
||||
|
||||
async function clearHistory() {
|
||||
const clearedIds = historyDownloads.value.map((d) => d.download_id)
|
||||
await clearDownloads()
|
||||
for (const id of clearedIds) {
|
||||
downloads.value.delete(id)
|
||||
}
|
||||
}
|
||||
|
||||
async function pollIfStale() {
|
||||
if (!hasActiveDownloads.value) return
|
||||
if (Date.now() - lastWsUpdate.value < STALE_THRESHOLD_MS) return
|
||||
try {
|
||||
await hydrate()
|
||||
} catch {
|
||||
// Server unreachable; retry on next interval
|
||||
}
|
||||
}
|
||||
|
||||
useIntervalFn(() => void pollIfStale(), POLL_INTERVAL_MS)
|
||||
|
||||
api.addEventListener('download_progress', handleProgress)
|
||||
|
||||
return {
|
||||
downloadList,
|
||||
activeDownloads,
|
||||
historyDownloads,
|
||||
hasActiveDownloads,
|
||||
activeDownloadCount,
|
||||
lastCompletedDownload,
|
||||
upsert,
|
||||
findByModelId,
|
||||
hydrate,
|
||||
enqueue,
|
||||
pause,
|
||||
resume,
|
||||
cancel,
|
||||
setPriority,
|
||||
remove,
|
||||
clearHistory
|
||||
}
|
||||
})
|
||||
130
src/platform/modelManager/types.ts
Normal file
130
src/platform/modelManager/types.ts
Normal file
@@ -0,0 +1,130 @@
|
||||
import { z } from 'zod'
|
||||
|
||||
import type { DownloadState, DownloadStatus } from '@/schemas/apiSchema'
|
||||
|
||||
export type { DownloadState, DownloadStatus }
|
||||
|
||||
/**
|
||||
* Known model file extensions accepted by the backend without
|
||||
* `allow_any_extension`. Mirrors the server's extension allowlist and is
|
||||
* used for instant client-side validation only — the server stays the source
|
||||
* of truth.
|
||||
*/
|
||||
export const MODEL_EXTENSIONS = [
|
||||
'.safetensors',
|
||||
'.sft',
|
||||
'.ckpt',
|
||||
'.pth',
|
||||
'.pt',
|
||||
'.gguf',
|
||||
'.bin'
|
||||
] as const
|
||||
|
||||
/**
|
||||
* Hosts the backend allows out of the box. Admins can extend this
|
||||
* server-side, so this list is only for optimistic client-side hints; rely on
|
||||
* the server's `URL_NOT_ALLOWED` / `url_allowed` as the source of truth.
|
||||
*/
|
||||
export const DEFAULT_ALLOWED_HOSTS = [
|
||||
'huggingface.co',
|
||||
'civitai.com',
|
||||
'localhost',
|
||||
'127.0.0.1'
|
||||
] as const
|
||||
|
||||
export interface EnqueueRequest {
|
||||
url: string
|
||||
model_id: string
|
||||
priority?: number
|
||||
expected_sha256?: string | null
|
||||
allow_any_extension?: boolean
|
||||
credential_id?: string | null
|
||||
}
|
||||
|
||||
export interface EnqueueResponse {
|
||||
download_id: string
|
||||
accepted: boolean
|
||||
}
|
||||
|
||||
export const AUTH_SCHEMES = ['bearer', 'header', 'query'] as const
|
||||
export type AuthScheme = (typeof AUTH_SCHEMES)[number]
|
||||
|
||||
const zHostCredentialView = z.object({
|
||||
id: z.string(),
|
||||
host: z.string(),
|
||||
auth_scheme: z.enum(AUTH_SCHEMES),
|
||||
header_name: z.string().nullable(),
|
||||
query_param: z.string().nullable(),
|
||||
label: z.string().nullable(),
|
||||
match_subdomains: z.boolean(),
|
||||
enabled: z.boolean(),
|
||||
secret_last4: z.string().nullable(),
|
||||
created_at: z.number(),
|
||||
updated_at: z.number()
|
||||
})
|
||||
export type HostCredentialView = z.infer<typeof zHostCredentialView>
|
||||
|
||||
export interface HostCredentialUpsert {
|
||||
host: string
|
||||
secret: string
|
||||
auth_scheme?: AuthScheme
|
||||
header_name?: string | null
|
||||
query_param?: string | null
|
||||
label?: string | null
|
||||
match_subdomains?: boolean
|
||||
enabled?: boolean
|
||||
}
|
||||
|
||||
interface AvailabilityBase {
|
||||
url_allowed: boolean
|
||||
}
|
||||
|
||||
export type AvailabilityEntry =
|
||||
| (AvailabilityBase & { state: 'available' })
|
||||
| (AvailabilityBase & { state: 'missing' })
|
||||
| (AvailabilityBase & {
|
||||
state: 'downloading'
|
||||
download_id: string
|
||||
progress: number | null
|
||||
bytes_done: number
|
||||
total_bytes: number | null
|
||||
speed_bps: number | null
|
||||
})
|
||||
|
||||
export interface AvailabilityResponse {
|
||||
models: Record<string, AvailabilityEntry>
|
||||
}
|
||||
|
||||
const DOWNLOAD_ERROR_CODES = [
|
||||
'INVALID_JSON',
|
||||
'INVALID_BODY',
|
||||
'URL_NOT_ALLOWED',
|
||||
'INVALID_MODEL_ID',
|
||||
'INVALID_CREDENTIAL',
|
||||
'ALREADY_AVAILABLE',
|
||||
'ALREADY_DOWNLOADING',
|
||||
'DOWNLOAD_ACTIVE',
|
||||
'NOT_FOUND'
|
||||
] as const
|
||||
export type DownloadErrorCode = (typeof DOWNLOAD_ERROR_CODES)[number]
|
||||
|
||||
/**
|
||||
* Error envelope returned by every download-manager endpoint on failure.
|
||||
* `code` is the stable machine-readable discriminator; `message` is
|
||||
* user-facing; `details` is an open object (do not assume a shape).
|
||||
*/
|
||||
export class DownloadApiError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public readonly code: string,
|
||||
public readonly status: number,
|
||||
public readonly details?: Record<string, unknown>
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'DownloadApiError'
|
||||
}
|
||||
|
||||
is(code: DownloadErrorCode): boolean {
|
||||
return this.code === code
|
||||
}
|
||||
}
|
||||
99
src/platform/modelManager/utils/modelId.test.ts
Normal file
99
src/platform/modelManager/utils/modelId.test.ts
Normal file
@@ -0,0 +1,99 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import {
|
||||
buildModelId,
|
||||
filenameFromUrl,
|
||||
hasModelExtension,
|
||||
hostFromUrl,
|
||||
isLikelyAllowedHost,
|
||||
isValidPathSegment
|
||||
} from './modelId'
|
||||
|
||||
describe('modelId utils', () => {
|
||||
describe('isValidPathSegment', () => {
|
||||
it('accepts valid filenames', () => {
|
||||
expect(isValidPathSegment('my_lora.safetensors')).toBe(true)
|
||||
expect(isValidPathSegment('sdxl-base.ckpt')).toBe(true)
|
||||
})
|
||||
|
||||
it('rejects traversal, leading dots, and slashes', () => {
|
||||
expect(isValidPathSegment('..')).toBe(false)
|
||||
expect(isValidPathSegment('.hidden')).toBe(false)
|
||||
expect(isValidPathSegment('nested/path')).toBe(false)
|
||||
expect(isValidPathSegment('')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('hasModelExtension', () => {
|
||||
it('matches known extensions case-insensitively', () => {
|
||||
expect(hasModelExtension('model.SAFETENSORS')).toBe(true)
|
||||
expect(hasModelExtension('weights.gguf')).toBe(true)
|
||||
})
|
||||
|
||||
it('rejects unknown extensions', () => {
|
||||
expect(hasModelExtension('readme.txt')).toBe(false)
|
||||
expect(hasModelExtension('archive.zip')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('buildModelId', () => {
|
||||
it('joins directory and filename with a single slash', () => {
|
||||
expect(buildModelId('loras', 'x.safetensors')).toBe('loras/x.safetensors')
|
||||
})
|
||||
|
||||
it('collapses redundant slashes at the join boundary', () => {
|
||||
expect(buildModelId('loras/', 'x.safetensors')).toBe(
|
||||
'loras/x.safetensors'
|
||||
)
|
||||
expect(buildModelId('loras', '/x.safetensors')).toBe(
|
||||
'loras/x.safetensors'
|
||||
)
|
||||
expect(buildModelId('loras//', '//x.safetensors')).toBe(
|
||||
'loras/x.safetensors'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('hostFromUrl', () => {
|
||||
it('extracts the lowercased host', () => {
|
||||
expect(hostFromUrl('https://HuggingFace.co/a/b')).toBe('huggingface.co')
|
||||
})
|
||||
|
||||
it('returns null for invalid URLs', () => {
|
||||
expect(hostFromUrl('not a url')).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('isLikelyAllowedHost', () => {
|
||||
it('allows built-in hosts and their subdomains', () => {
|
||||
expect(isLikelyAllowedHost('https://huggingface.co/x')).toBe(true)
|
||||
expect(isLikelyAllowedHost('https://cdn.huggingface.co/x')).toBe(true)
|
||||
expect(isLikelyAllowedHost('https://civitai.com/api/download/1')).toBe(
|
||||
true
|
||||
)
|
||||
})
|
||||
|
||||
it('flags unknown hosts', () => {
|
||||
expect(isLikelyAllowedHost('https://example.com/x')).toBe(false)
|
||||
expect(isLikelyAllowedHost('garbage')).toBe(false)
|
||||
})
|
||||
|
||||
it('does not treat a fake subdomain of an allowed IP literal as allowed', () => {
|
||||
expect(isLikelyAllowedHost('https://127.0.0.1/x')).toBe(true)
|
||||
expect(isLikelyAllowedHost('https://evil.127.0.0.1/x')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('filenameFromUrl', () => {
|
||||
it('returns the decoded trailing path segment', () => {
|
||||
expect(filenameFromUrl('https://h.co/a/my%20model.safetensors')).toBe(
|
||||
'my model.safetensors'
|
||||
)
|
||||
})
|
||||
|
||||
it('returns empty string when no segment is present', () => {
|
||||
expect(filenameFromUrl('https://h.co/')).toBe('')
|
||||
expect(filenameFromUrl('bad')).toBe('')
|
||||
})
|
||||
})
|
||||
})
|
||||
79
src/platform/modelManager/utils/modelId.ts
Normal file
79
src/platform/modelManager/utils/modelId.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
import { DEFAULT_ALLOWED_HOSTS, MODEL_EXTENSIONS } from '../types'
|
||||
|
||||
const SEGMENT_PATTERN = /^[A-Za-z0-9][A-Za-z0-9._-]*$/
|
||||
|
||||
export function isValidPathSegment(segment: string): boolean {
|
||||
return SEGMENT_PATTERN.test(segment)
|
||||
}
|
||||
|
||||
export function hasModelExtension(filename: string): boolean {
|
||||
const lower = filename.toLowerCase()
|
||||
return MODEL_EXTENSIONS.some((ext) => lower.endsWith(ext))
|
||||
}
|
||||
|
||||
export function buildModelId(directory: string, filename: string): string {
|
||||
return `${directory.replace(/\/+$/, '')}/${filename.replace(/^\/+/, '')}`
|
||||
}
|
||||
|
||||
/**
|
||||
* Directory portion of a `directory/filename` model id, or an empty string when
|
||||
* the id has no directory prefix.
|
||||
*/
|
||||
export function directoryOf(modelId: string): string {
|
||||
const slash = modelId.indexOf('/')
|
||||
return slash === -1 ? '' : modelId.slice(0, slash)
|
||||
}
|
||||
|
||||
/**
|
||||
* Filename portion of a `directory/filename` model id, falling back to the full
|
||||
* id when there is no directory prefix.
|
||||
*/
|
||||
export function filenameOf(modelId: string): string {
|
||||
const slash = modelId.indexOf('/')
|
||||
return slash === -1 ? modelId : modelId.slice(slash + 1)
|
||||
}
|
||||
|
||||
/**
|
||||
* Lowercased host of a URL, or `null` when the URL is unparseable.
|
||||
*/
|
||||
export function hostFromUrl(url: string): string | null {
|
||||
try {
|
||||
return new URL(url).hostname.toLowerCase()
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
const IPV4_PATTERN = /^\d{1,3}(?:\.\d{1,3}){3}$/
|
||||
|
||||
function isIpLiteral(host: string): boolean {
|
||||
return IPV4_PATTERN.test(host) || host.includes(':')
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimistic client-side allowlist hint. The server can extend the
|
||||
* allowlist, so a `false` here is advisory — defer to `URL_NOT_ALLOWED`.
|
||||
*/
|
||||
export function isLikelyAllowedHost(url: string): boolean {
|
||||
const host = hostFromUrl(url)
|
||||
if (!host) return false
|
||||
return DEFAULT_ALLOWED_HOSTS.some(
|
||||
(allowed) =>
|
||||
host === allowed ||
|
||||
(!isIpLiteral(allowed) && host.endsWith(`.${allowed}`))
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Best-effort filename guess from a URL path, for prefilling the model_id
|
||||
* filename field. May be empty when the URL has no usable trailing segment.
|
||||
*/
|
||||
export function filenameFromUrl(url: string): string {
|
||||
try {
|
||||
const { pathname } = new URL(url)
|
||||
const last = pathname.split('/').filter(Boolean).pop() ?? ''
|
||||
return decodeURIComponent(last)
|
||||
} catch {
|
||||
return ''
|
||||
}
|
||||
}
|
||||
@@ -154,6 +154,39 @@ const zAssetDownloadWsMessage = z.object({
|
||||
error: z.string().optional()
|
||||
})
|
||||
|
||||
const DOWNLOAD_STATES = [
|
||||
'queued',
|
||||
'active',
|
||||
'paused',
|
||||
'verifying',
|
||||
'completed',
|
||||
'failed',
|
||||
'cancelled'
|
||||
] as const
|
||||
|
||||
const zDownloadSegment = z.object({
|
||||
idx: z.number().int().nonnegative(),
|
||||
bytes_done: z.number().int().nonnegative(),
|
||||
length: z.number().int().nonnegative()
|
||||
})
|
||||
|
||||
const zDownloadStatus = z.object({
|
||||
download_id: z.string(),
|
||||
model_id: z.string(),
|
||||
url: z.string(),
|
||||
status: z.enum(DOWNLOAD_STATES),
|
||||
priority: z.number().int(),
|
||||
total_bytes: z.number().int().nonnegative().nullable(),
|
||||
bytes_done: z.number().int().nonnegative(),
|
||||
progress: z.number().nullable(),
|
||||
speed_bps: z.number().nonnegative().nullable(),
|
||||
eta_seconds: z.number().nonnegative().nullable(),
|
||||
segments: z.array(zDownloadSegment).nullable(),
|
||||
error: z.string().nullable(),
|
||||
created_at: z.number(),
|
||||
updated_at: z.number()
|
||||
})
|
||||
|
||||
const zAssetExportWsMessage = z.object({
|
||||
task_id: z.string(),
|
||||
export_name: z.string().optional(),
|
||||
@@ -188,6 +221,8 @@ export type ProgressStateWsMessage = z.infer<typeof zProgressStateWsMessage>
|
||||
export type FeatureFlagsWsMessage = z.infer<typeof zFeatureFlagsWsMessage>
|
||||
export type AssetDownloadWsMessage = z.infer<typeof zAssetDownloadWsMessage>
|
||||
export type AssetExportWsMessage = z.infer<typeof zAssetExportWsMessage>
|
||||
export type DownloadState = (typeof DOWNLOAD_STATES)[number]
|
||||
export type DownloadStatus = z.infer<typeof zDownloadStatus>
|
||||
// End of ws messages
|
||||
|
||||
export type NotificationWsMessage = z.infer<typeof zNotificationWsMessage>
|
||||
|
||||
@@ -34,6 +34,7 @@ import type {
|
||||
AssetDownloadWsMessage,
|
||||
AssetExportWsMessage,
|
||||
CustomNodesI18n,
|
||||
DownloadStatus,
|
||||
EmbeddingsResponse,
|
||||
ExecutedWsMessage,
|
||||
ExecutingWsMessage,
|
||||
@@ -186,6 +187,7 @@ interface BackendApiCalls {
|
||||
feature_flags: FeatureFlagsWsMessage
|
||||
asset_download: AssetDownloadWsMessage
|
||||
asset_export: AssetExportWsMessage
|
||||
download_progress: DownloadStatus
|
||||
}
|
||||
|
||||
/** Dictionary of all api calls */
|
||||
|
||||
@@ -113,6 +113,27 @@ describe('commandStore', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('unregisterCommand', () => {
|
||||
it('removes a registered command', () => {
|
||||
const store = useCommandStore()
|
||||
store.registerCommand({ id: 'to.remove', function: vi.fn() })
|
||||
|
||||
store.unregisterCommand('to.remove')
|
||||
|
||||
expect(store.isRegistered('to.remove')).toBe(false)
|
||||
})
|
||||
|
||||
it('is a no-op for an unregistered id', () => {
|
||||
const store = useCommandStore()
|
||||
store.registerCommand({ id: 'keep.me', function: vi.fn() })
|
||||
|
||||
store.unregisterCommand('nonexistent')
|
||||
|
||||
expect(store.isRegistered('keep.me')).toBe(true)
|
||||
expect(store.commands).toHaveLength(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isRegistered', () => {
|
||||
it('returns false for unregistered command', () => {
|
||||
const store = useCommandStore()
|
||||
|
||||
@@ -88,6 +88,12 @@ export const useCommandStore = defineStore('command', () => {
|
||||
}
|
||||
}
|
||||
|
||||
const unregisterCommand = (commandId: string) => {
|
||||
if (!commandsById.value[commandId]) return
|
||||
const { [commandId]: _removed, ...rest } = commandsById.value
|
||||
commandsById.value = rest
|
||||
}
|
||||
|
||||
const getCommand = (command: string) => {
|
||||
return commandsById.value[command]
|
||||
}
|
||||
@@ -139,6 +145,7 @@ export const useCommandStore = defineStore('command', () => {
|
||||
getCommand,
|
||||
registerCommand,
|
||||
registerCommands,
|
||||
unregisterCommand,
|
||||
isRegistered,
|
||||
loadExtensionCommands,
|
||||
formatKeySequence
|
||||
|
||||
@@ -5,12 +5,19 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { useSidebarTabStore } from '@/stores/workspace/sidebarTabStore'
|
||||
|
||||
const { mockGetSetting, mockRegisterCommand, mockRegisterCommands } =
|
||||
vi.hoisted(() => ({
|
||||
mockGetSetting: vi.fn(),
|
||||
mockRegisterCommand: vi.fn(),
|
||||
mockRegisterCommands: vi.fn()
|
||||
}))
|
||||
const {
|
||||
mockGetSetting,
|
||||
mockRegisterCommand,
|
||||
mockRegisterCommands,
|
||||
mockUnregisterCommand,
|
||||
mockServerSideModelDownloads
|
||||
} = vi.hoisted(() => ({
|
||||
mockGetSetting: vi.fn(),
|
||||
mockRegisterCommand: vi.fn(),
|
||||
mockRegisterCommands: vi.fn(),
|
||||
mockUnregisterCommand: vi.fn(),
|
||||
mockServerSideModelDownloads: { value: false }
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: () => ({
|
||||
@@ -21,10 +28,43 @@ vi.mock('@/platform/settings/settingStore', () => ({
|
||||
vi.mock('@/stores/commandStore', () => ({
|
||||
useCommandStore: () => ({
|
||||
registerCommand: mockRegisterCommand,
|
||||
unregisterCommand: mockUnregisterCommand,
|
||||
commands: []
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/useFeatureFlags', async () => {
|
||||
const { ref } = await import('vue')
|
||||
const serverSideModelDownloadsRef = ref(mockServerSideModelDownloads.value)
|
||||
Object.defineProperty(mockServerSideModelDownloads, 'value', {
|
||||
get: () => serverSideModelDownloadsRef.value,
|
||||
set: (value: boolean) => {
|
||||
serverSideModelDownloadsRef.value = value
|
||||
}
|
||||
})
|
||||
return {
|
||||
useFeatureFlags: () => ({
|
||||
flags: {
|
||||
get serverSideModelDownloads() {
|
||||
return mockServerSideModelDownloads.value
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock(
|
||||
'@/platform/modelManager/composables/useModelManagerSidebarTab',
|
||||
() => ({
|
||||
useModelManagerSidebarTab: () => ({
|
||||
id: 'model-manager',
|
||||
title: 'model-manager',
|
||||
type: 'vue',
|
||||
component: {}
|
||||
})
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock('@/stores/menuItemStore', () => ({
|
||||
useMenuItemStore: () => ({
|
||||
registerCommands: mockRegisterCommands
|
||||
@@ -99,6 +139,8 @@ describe('useSidebarTabStore', () => {
|
||||
mockGetSetting.mockReset()
|
||||
mockRegisterCommand.mockClear()
|
||||
mockRegisterCommands.mockClear()
|
||||
mockUnregisterCommand.mockClear()
|
||||
mockServerSideModelDownloads.value = false
|
||||
})
|
||||
|
||||
it('registers the job history tab when QPO V2 is enabled', () => {
|
||||
@@ -160,4 +202,57 @@ describe('useSidebarTabStore', () => {
|
||||
])
|
||||
expect(mockRegisterCommand).toHaveBeenCalledTimes(6)
|
||||
})
|
||||
|
||||
it('registers the model-manager tab when the feature flag starts enabled', () => {
|
||||
mockServerSideModelDownloads.value = true
|
||||
|
||||
const store = useSidebarTabStore()
|
||||
store.registerCoreSidebarTabs()
|
||||
|
||||
expect(store.sidebarTabs.map((tab) => tab.id)).toContain('model-manager')
|
||||
})
|
||||
|
||||
it('does not register the model-manager tab when the feature flag is disabled', () => {
|
||||
const store = useSidebarTabStore()
|
||||
store.registerCoreSidebarTabs()
|
||||
|
||||
expect(store.sidebarTabs.map((tab) => tab.id)).not.toContain(
|
||||
'model-manager'
|
||||
)
|
||||
})
|
||||
|
||||
it('registers the model-manager tab when the feature flag turns on later', async () => {
|
||||
const store = useSidebarTabStore()
|
||||
store.registerCoreSidebarTabs()
|
||||
expect(store.sidebarTabs.map((tab) => tab.id)).not.toContain(
|
||||
'model-manager'
|
||||
)
|
||||
|
||||
mockServerSideModelDownloads.value = true
|
||||
await nextTick()
|
||||
|
||||
expect(store.sidebarTabs.map((tab) => tab.id)).toContain('model-manager')
|
||||
expect(mockRegisterCommand).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
id: 'Workspace.ToggleSidebarTab.model-manager'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('unregisters the model-manager tab and its command when the flag turns off', async () => {
|
||||
mockServerSideModelDownloads.value = true
|
||||
const store = useSidebarTabStore()
|
||||
store.registerCoreSidebarTabs()
|
||||
expect(store.sidebarTabs.map((tab) => tab.id)).toContain('model-manager')
|
||||
|
||||
mockServerSideModelDownloads.value = false
|
||||
await nextTick()
|
||||
|
||||
expect(store.sidebarTabs.map((tab) => tab.id)).not.toContain(
|
||||
'model-manager'
|
||||
)
|
||||
expect(mockUnregisterCommand).toHaveBeenCalledWith(
|
||||
'Workspace.ToggleSidebarTab.model-manager'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -5,7 +5,9 @@ import { useAssetsSidebarTab } from '@/composables/sidebarTabs/useAssetsSidebarT
|
||||
import { useJobHistorySidebarTab } from '@/composables/sidebarTabs/useJobHistorySidebarTab'
|
||||
import { useModelLibrarySidebarTab } from '@/composables/sidebarTabs/useModelLibrarySidebarTab'
|
||||
import { useNodeLibrarySidebarTab } from '@/composables/sidebarTabs/useNodeLibrarySidebarTab'
|
||||
import { useFeatureFlags } from '@/composables/useFeatureFlags'
|
||||
import { t, te } from '@/i18n'
|
||||
import { useModelManagerSidebarTab } from '@/platform/modelManager/composables/useModelManagerSidebarTab'
|
||||
import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import { useAppsSidebarTab } from '@/platform/workflow/management/composables/useAppsSidebarTab'
|
||||
import { useWorkflowsSidebarTab } from '@/platform/workflow/management/composables/useWorkflowsSidebarTab'
|
||||
@@ -53,7 +55,8 @@ export const useSidebarTabStore = defineStore('sidebarTab', () => {
|
||||
'model-library': 'sideToolbar.modelLibrary',
|
||||
workflows: 'sideToolbar.workflows',
|
||||
assets: 'sideToolbar.assets',
|
||||
'job-history': 'queue.jobHistory'
|
||||
'job-history': 'queue.jobHistory',
|
||||
'model-manager': 'modelManager.title'
|
||||
}
|
||||
|
||||
const key = menubarLabelKeys[tab.id]
|
||||
@@ -132,6 +135,27 @@ export const useSidebarTabStore = defineStore('sidebarTab', () => {
|
||||
(enabled) => syncJobHistoryTab(enabled)
|
||||
)
|
||||
|
||||
const modelManagerTabId = 'model-manager'
|
||||
const { flags } = useFeatureFlags()
|
||||
const syncModelManagerTab = (enabled: boolean) => {
|
||||
const hasTab = sidebarTabs.value.some(
|
||||
(tab) => tab.id === modelManagerTabId
|
||||
)
|
||||
if (enabled && !hasTab) {
|
||||
registerSidebarTab(useModelManagerSidebarTab())
|
||||
} else if (!enabled && hasTab) {
|
||||
unregisterSidebarTab(modelManagerTabId)
|
||||
useCommandStore().unregisterCommand(
|
||||
`Workspace.ToggleSidebarTab.${modelManagerTabId}`
|
||||
)
|
||||
}
|
||||
}
|
||||
watch(
|
||||
() => flags.serverSideModelDownloads,
|
||||
(enabled) => syncModelManagerTab(enabled),
|
||||
{ immediate: true }
|
||||
)
|
||||
|
||||
registerSidebarTab(useAssetsSidebarTab())
|
||||
registerSidebarTab(useNodeLibrarySidebarTab())
|
||||
registerSidebarTab(useModelLibrarySidebarTab())
|
||||
|
||||
@@ -58,6 +58,7 @@ import { useQueuePolling } from '@/platform/remote/comfyui/useQueuePolling'
|
||||
import { useErrorHandling } from '@/composables/useErrorHandling'
|
||||
import { useReconnectQueueRefresh } from '@/composables/useReconnectQueueRefresh'
|
||||
import { useReconnectingNotification } from '@/composables/useReconnectingNotification'
|
||||
import { useModelDownloadEffects } from '@/platform/modelManager/composables/useModelDownloadEffects'
|
||||
import { useProgressFavicon } from '@/composables/useProgressFavicon'
|
||||
import { SERVER_CONFIG_ITEMS } from '@/constants/serverConfig'
|
||||
import type { ServerConfig, ServerConfigValue } from '@/constants/serverConfig'
|
||||
@@ -227,6 +228,7 @@ useMenuItemStore().registerCoreMenuCommands()
|
||||
useKeybindingService().registerCoreKeybindings()
|
||||
useSidebarTabStore().registerCoreSidebarTabs()
|
||||
void useBottomPanelStore().registerCoreBottomPanelTabs()
|
||||
useModelDownloadEffects()
|
||||
|
||||
useQueuePolling()
|
||||
const queuePendingTaskCountStore = useQueuePendingTaskCountStore()
|
||||
|
||||
Reference in New Issue
Block a user