Files
ComfyUI_frontend/src/stores/assetDownloadStore.test.ts
Alexander Brown a166ec91a6 refactor: simplify asset download state and fix deletion UI (#7974)
## Summary

Refactors asset download state management and fixes asset deletion UI
issues.

## Changes

### assetDownloadStore simplification
- Replace `pendingModelTypes` Map with `modelType` stored directly on
`AssetDownload`
- Replace `completedDownloads` array with single `lastCompletedDownload`
ref
- `trackDownload()` now creates a placeholder entry immediately
- Use VueUse `whenever` instead of `watch` for cleaner null handling

### Asset refresh on download completion
- Refresh all relevant caches when a download completes:
  - Node type caches (e.g., "CheckpointLoaderSimple")
  - Tag caches (e.g., "tag:checkpoints")
  - "All Models" cache ("tag:models")

### Asset deletion fix
- Remove local `deletedLocal` state that caused blank grid cells
- Emit `deleted` event from AssetCard → AssetGrid → AssetBrowserModal
- Trigger store refresh on deletion to properly remove the asset from
the grid

## Testing

- Added test for out-of-order websocket message handling
- All existing tests pass

┆Issue is synchronized with this [Notion
page](https://www.notion.so/PR-7974-refactor-simplify-asset-download-state-and-fix-deletion-UI-2e76d73d365081c69bcde9150a0d460c)
by [Unito](https://www.unito.io)

---------

Co-authored-by: Amp <amp@ampcode.com>
2026-01-12 17:57:25 -08:00

240 lines
7.4 KiB
TypeScript

import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import type { TaskResponse } from '@/platform/tasks/services/taskService'
import { taskService } from '@/platform/tasks/services/taskService'
import type { AssetDownloadWsMessage } from '@/schemas/apiSchema'
import { useAssetDownloadStore } from '@/stores/assetDownloadStore'
/**
 * Shape of the listener captured by the mocked `api.addEventListener`: it
 * receives a `CustomEvent` whose `detail` is the asset-download message.
 */
type DownloadEventHandler = (
  event: CustomEvent<AssetDownloadWsMessage>
) => void
// Mutable slot created via `vi.hoisted` so the `vi.mock` factory below can
// write to it: it holds the listener most recently handed to the mocked
// `api.addEventListener`, or null when none has been registered.
const eventHandler = vi.hoisted(
  (): { current: DownloadEventHandler | null } => ({ current: null })
)
// Stand-in for the `api` module: `addEventListener` records the supplied
// handler in `eventHandler.current` (the event name is ignored), and
// `removeEventListener` is an inert spy.
vi.mock('@/scripts/api', () => {
  const addEventListener = vi.fn(
    (_event: string, handler: DownloadEventHandler) => {
      eventHandler.current = handler
    }
  )
  const removeEventListener = vi.fn()

  return { api: { addEventListener, removeEventListener } }
})
// Stand-in for the task service exposing only `getTask`; each test that needs
// it programs the spy's resolved or rejected value.
vi.mock('@/platform/tasks/services/taskService', () => {
  const getTask = vi.fn()

  return { taskService: { getTask } }
})
/**
 * Builds an asset-download websocket message for tests.
 *
 * Starts from a half-finished `running` download for `task-123` and lets the
 * caller replace any subset of fields.
 *
 * @param overrides - Fields that take precedence over the defaults.
 * @returns A new message object combining defaults and overrides.
 */
function createDownloadMessage(
  overrides: Partial<AssetDownloadWsMessage> = {}
): AssetDownloadWsMessage {
  const defaults: AssetDownloadWsMessage = {
    task_id: 'task-123',
    asset_id: 'asset-456',
    asset_name: 'model.safetensors',
    bytes_total: 1000,
    bytes_downloaded: 500,
    progress: 50,
    status: 'running'
  }

  // Later spread wins, so caller-supplied fields replace the defaults.
  return { ...defaults, ...overrides }
}
/**
 * Delivers `msg` to the captured listener as the `detail` of an
 * `asset_download` CustomEvent.
 *
 * @param msg - Message to place in the event's `detail`.
 * @throws Error when no listener has been captured yet.
 */
function dispatch(msg: AssetDownloadWsMessage) {
  const listener = eventHandler.current
  if (!listener) {
    throw new Error(
      'Event handler not registered. Call useAssetDownloadStore() first.'
    )
  }

  const event = new CustomEvent('asset_download', { detail: msg })
  listener(event)
}
// Test suite for the asset download store. Each test instantiates the store
// (which registers its listener on the mocked `api`) and then feeds it
// messages through `dispatch`.
describe('useAssetDownloadStore', () => {
  beforeEach(() => {
    // Fresh Pinia per test; actions are not stubbed so store logic really runs.
    setActivePinia(createTestingPinia({ stubActions: false }))
    // Time only moves when a test advances it explicitly.
    vi.useFakeTimers({ shouldAdvanceTime: false })
    // NOTE(review): the tests rely on the mocked `api.addEventListener` still
    // capturing handlers after this reset — confirm the installed Vitest
    // version keeps the implementation passed to `vi.fn(...)` on reset.
    vi.resetAllMocks()
    // Drop any listener left over from the previous test so `dispatch` throws
    // unless the current test has created the store.
    eventHandler.current = null
  })

  afterEach(() => {
    vi.useRealTimers()
  })

  describe('handleAssetDownload', () => {
    it('tracks running downloads', () => {
      const store = useAssetDownloadStore()

      dispatch(createDownloadMessage())

      expect(store.activeDownloads).toHaveLength(1)
      expect(store.activeDownloads[0].taskId).toBe('task-123')
      expect(store.activeDownloads[0].progress).toBe(50)
    })

    it('moves download to finished when completed', () => {
      const store = useAssetDownloadStore()

      dispatch(createDownloadMessage({ status: 'running' }))
      expect(store.activeDownloads).toHaveLength(1)

      dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))

      expect(store.activeDownloads).toHaveLength(0)
      expect(store.finishedDownloads).toHaveLength(1)
      expect(store.finishedDownloads[0].status).toBe('completed')
    })

    it('moves download to finished when failed', () => {
      const store = useAssetDownloadStore()

      dispatch(createDownloadMessage({ status: 'running' }))
      dispatch(
        createDownloadMessage({ status: 'failed', error: 'Network error' })
      )

      expect(store.activeDownloads).toHaveLength(0)
      expect(store.finishedDownloads).toHaveLength(1)
      expect(store.finishedDownloads[0].status).toBe('failed')
      expect(store.finishedDownloads[0].error).toBe('Network error')
    })

    it('ignores duplicate terminal state messages', () => {
      const store = useAssetDownloadStore()

      // Same task reported as completed twice must yield a single entry.
      dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
      dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))

      expect(store.finishedDownloads).toHaveLength(1)
    })
  })

  describe('trackDownload', () => {
    it('associates task with model type for completion tracking', () => {
      const store = useAssetDownloadStore()

      store.trackDownload('task-123', 'checkpoints', 'model.safetensors')
      dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))

      expect(store.lastCompletedDownload).toMatchObject({
        taskId: 'task-123',
        modelType: 'checkpoints'
      })
    })

    it('handles out-of-order messages where completed arrives before progress', () => {
      const store = useAssetDownloadStore()

      store.trackDownload('task-123', 'checkpoints', 'model.safetensors')
      dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
      // A late progress message must not move the download back to active.
      dispatch(createDownloadMessage({ status: 'running', progress: 50 }))

      expect(store.activeDownloads).toHaveLength(0)
      expect(store.finishedDownloads).toHaveLength(1)
      expect(store.finishedDownloads[0].status).toBe('completed')
      expect(store.lastCompletedDownload?.modelType).toBe('checkpoints')
    })
  })

  describe('stale download polling', () => {
    // Amount of fake time each test advances before inspecting the store.
    // NOTE(review): presumably chosen to exceed the store's stale threshold
    // plus one polling interval — the store's constants are not visible here.
    const POLL_ADVANCE_MS = 45_000

    /**
     * Builds a task-service response for `task-123`; defaults describe a
     * successfully completed download and can be overridden per test.
     */
    function createTaskResponse(
      overrides: Partial<TaskResponse> = {}
    ): TaskResponse {
      return {
        id: 'task-123',
        idempotency_key: 'key-123',
        task_name: 'task:download_file',
        payload: {},
        status: 'completed',
        create_time: new Date().toISOString(),
        update_time: new Date().toISOString(),
        result: {
          success: true,
          asset_id: 'asset-456',
          filename: 'model.safetensors',
          bytes_downloaded: 1000
        },
        ...overrides
      }
    }

    it('polls and completes stale downloads', async () => {
      const store = useAssetDownloadStore()
      vi.mocked(taskService.getTask).mockResolvedValue(createTaskResponse())

      dispatch(createDownloadMessage({ status: 'running' }))
      expect(store.activeDownloads).toHaveLength(1)

      await vi.advanceTimersByTimeAsync(POLL_ADVANCE_MS)

      expect(taskService.getTask).toHaveBeenCalledWith('task-123')
      expect(store.activeDownloads).toHaveLength(0)
      expect(store.finishedDownloads[0].status).toBe('completed')
    })

    it('polls and marks failed downloads', async () => {
      const store = useAssetDownloadStore()
      vi.mocked(taskService.getTask).mockResolvedValue(
        createTaskResponse({
          status: 'failed',
          error_message: 'Download failed',
          result: { success: false, error: 'Network error' }
        })
      )

      dispatch(createDownloadMessage({ status: 'running' }))
      await vi.advanceTimersByTimeAsync(POLL_ADVANCE_MS)

      expect(store.activeDownloads).toHaveLength(0)
      expect(store.finishedDownloads[0].status).toBe('failed')
      // The task-level `error_message` is what the store surfaces, not
      // `result.error`.
      expect(store.finishedDownloads[0].error).toBe('Download failed')
    })

    it('does not complete if task still running', async () => {
      const store = useAssetDownloadStore()
      vi.mocked(taskService.getTask).mockResolvedValue(
        createTaskResponse({ status: 'running', result: undefined })
      )

      dispatch(createDownloadMessage({ status: 'running' }))
      await vi.advanceTimersByTimeAsync(POLL_ADVANCE_MS)

      expect(taskService.getTask).toHaveBeenCalled()
      expect(store.activeDownloads).toHaveLength(1)
    })

    it('continues tracking on polling error', async () => {
      const store = useAssetDownloadStore()
      vi.mocked(taskService.getTask).mockRejectedValue(new Error('Not found'))

      dispatch(createDownloadMessage({ status: 'running' }))
      await vi.advanceTimersByTimeAsync(POLL_ADVANCE_MS)

      // Without this check the test would also pass if polling never ran and
      // the rejection path was never exercised.
      expect(taskService.getTask).toHaveBeenCalled()
      expect(store.activeDownloads).toHaveLength(1)
    })
  })

  describe('clearFinishedDownloads', () => {
    it('removes all finished downloads', () => {
      const store = useAssetDownloadStore()

      dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
      expect(store.finishedDownloads).toHaveLength(1)

      store.clearFinishedDownloads()

      expect(store.finishedDownloads).toHaveLength(0)
    })
  })
})