mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-02-19 22:34:15 +00:00
## Summary
Refactors asset download state management and fixes asset deletion UI
issues.
## Changes
### assetDownloadStore simplification
- Replace `pendingModelTypes` Map with `modelType` stored directly on
`AssetDownload`
- Replace `completedDownloads` array with single `lastCompletedDownload`
ref
- `trackDownload()` now creates a placeholder entry immediately
- Use VueUse `whenever` instead of `watch` for cleaner null handling
### Asset refresh on download completion
- Refresh all relevant caches when a download completes:
  - Node type caches (e.g., "CheckpointLoaderSimple")
  - Tag caches (e.g., "tag:checkpoints")
  - "All Models" cache ("tag:models")
### Asset deletion fix
- Remove local `deletedLocal` state that caused blank grid cells
- Emit `deleted` event from AssetCard → AssetGrid → AssetBrowserModal
- Trigger store refresh on deletion to properly remove the asset from
the grid
## Testing
- Added test for out-of-order websocket message handling
- All existing tests pass
┆Issue is synchronized with this [Notion
page](https://www.notion.so/PR-7974-refactor-simplify-asset-download-state-and-fix-deletion-UI-2e76d73d365081c69bcde9150a0d460c)
by [Unito](https://www.unito.io)
---------
Co-authored-by: Amp <amp@ampcode.com>
240 lines
7.4 KiB
TypeScript
240 lines
7.4 KiB
TypeScript
import { createTestingPinia } from '@pinia/testing'
|
|
import { setActivePinia } from 'pinia'
|
|
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
|
|
|
import type { TaskResponse } from '@/platform/tasks/services/taskService'
|
|
import { taskService } from '@/platform/tasks/services/taskService'
|
|
import type { AssetDownloadWsMessage } from '@/schemas/apiSchema'
|
|
import { useAssetDownloadStore } from '@/stores/assetDownloadStore'
|
|
|
|
/** Listener signature for events whose `detail` is an asset-download message. */
type DownloadEventHandler = (
  event: CustomEvent<AssetDownloadWsMessage>
) => void
|
|
|
|
/**
 * Mutable slot for the listener that the mocked `api.addEventListener`
 * receives; `current` stays `null` until a listener has been registered.
 *
 * Built with `vi.hoisted` so the slot can be referenced from inside the
 * `vi.mock` factory for `@/scripts/api`.
 */
const eventHandler = vi.hoisted(() => {
  const state: { current: DownloadEventHandler | null } = { current: null }
  return state
})
|
|
|
|
// Stub the API client. `addEventListener` stores whatever handler it is given
// in `eventHandler.current` (the event name is ignored), so tests can invoke
// the listener directly. `removeEventListener` is a bare spy.
vi.mock('@/scripts/api', () => ({
  api: {
    addEventListener: vi.fn((_event: string, handler: DownloadEventHandler) => {
      eventHandler.current = handler
    }),
    removeEventListener: vi.fn()
  }
}))
|
|
|
|
// Stub the task service. `getTask` is a bare spy; tests that need a response
// configure it with `vi.mocked(taskService.getTask)`.
vi.mock('@/platform/tasks/services/taskService', () => ({
  taskService: {
    getTask: vi.fn()
  }
}))
|
|
|
|
/**
 * Builds an asset-download message for use in tests.
 *
 * The defaults describe task `task-123` in the `running` state at 50%
 * (500 of 1000 bytes downloaded).
 *
 * @param overrides - Fields that replace the defaults; spread last, so any
 *   key present here wins.
 * @returns A new message object.
 */
function createDownloadMessage(
  overrides: Partial<AssetDownloadWsMessage> = {}
): AssetDownloadWsMessage {
  return {
    task_id: 'task-123',
    asset_id: 'asset-456',
    asset_name: 'model.safetensors',
    bytes_total: 1000,
    bytes_downloaded: 500,
    progress: 50,
    status: 'running',
    ...overrides
  }
}
|
|
|
|
/**
 * Delivers `msg` to the captured listener, wrapped as the `detail` of an
 * `asset_download` `CustomEvent`.
 *
 * @param msg - The message to deliver.
 * @throws Error when `eventHandler.current` is not set, i.e. no listener has
 *   been registered yet.
 */
function dispatch(msg: AssetDownloadWsMessage): void {
  if (!eventHandler.current) {
    throw new Error(
      'Event handler not registered. Call useAssetDownloadStore() first.'
    )
  }
  eventHandler.current(new CustomEvent('asset_download', { detail: msg }))
}
|
|
|
|
/**
 * Unit tests for `useAssetDownloadStore`.
 *
 * Each test creates the store and then feeds it synthetic download messages
 * through `dispatch`, which calls the listener captured from the mocked API.
 */
describe('useAssetDownloadStore', () => {
  beforeEach(() => {
    setActivePinia(createTestingPinia({ stubActions: false }))
    vi.useFakeTimers({ shouldAdvanceTime: false })
    vi.resetAllMocks()
    // Forget any listener captured by an earlier test, so `dispatch` throws
    // if a test has not registered a fresh one.
    eventHandler.current = null
  })

  afterEach(() => {
    vi.useRealTimers()
  })

  describe('handleAssetDownload', () => {
    it('tracks running downloads', () => {
      const store = useAssetDownloadStore()

      dispatch(createDownloadMessage())

      expect(store.activeDownloads).toHaveLength(1)
      expect(store.activeDownloads[0].taskId).toBe('task-123')
      expect(store.activeDownloads[0].progress).toBe(50)
    })

    it('moves download to finished when completed', () => {
      const store = useAssetDownloadStore()

      dispatch(createDownloadMessage({ status: 'running' }))
      expect(store.activeDownloads).toHaveLength(1)

      dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))

      expect(store.activeDownloads).toHaveLength(0)
      expect(store.finishedDownloads).toHaveLength(1)
      expect(store.finishedDownloads[0].status).toBe('completed')
    })

    it('moves download to finished when failed', () => {
      const store = useAssetDownloadStore()

      dispatch(createDownloadMessage({ status: 'running' }))
      dispatch(
        createDownloadMessage({ status: 'failed', error: 'Network error' })
      )

      expect(store.activeDownloads).toHaveLength(0)
      expect(store.finishedDownloads).toHaveLength(1)
      expect(store.finishedDownloads[0].status).toBe('failed')
      expect(store.finishedDownloads[0].error).toBe('Network error')
    })

    it('ignores duplicate terminal state messages', () => {
      const store = useAssetDownloadStore()

      // The same terminal message twice must yield a single finished entry.
      dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
      dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))

      expect(store.finishedDownloads).toHaveLength(1)
    })
  })

  describe('trackDownload', () => {
    it('associates task with model type for completion tracking', () => {
      const store = useAssetDownloadStore()

      store.trackDownload('task-123', 'checkpoints', 'model.safetensors')
      dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))

      expect(store.lastCompletedDownload).toMatchObject({
        taskId: 'task-123',
        modelType: 'checkpoints'
      })
    })

    it('handles out-of-order messages where completed arrives before progress', () => {
      const store = useAssetDownloadStore()

      store.trackDownload('task-123', 'checkpoints', 'model.safetensors')

      dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))

      // A late `running` message for the same task must not reopen it.
      dispatch(createDownloadMessage({ status: 'running', progress: 50 }))

      expect(store.activeDownloads).toHaveLength(0)
      expect(store.finishedDownloads).toHaveLength(1)
      expect(store.finishedDownloads[0].status).toBe('completed')
      expect(store.lastCompletedDownload?.modelType).toBe('checkpoints')
    })
  })

  describe('stale download polling', () => {
    /**
     * Builds a `TaskResponse` for task `task-123`; by default it has status
     * `completed` and a successful result. Keys in `overrides` replace the
     * defaults.
     */
    function createTaskResponse(
      overrides: Partial<TaskResponse> = {}
    ): TaskResponse {
      return {
        id: 'task-123',
        idempotency_key: 'key-123',
        task_name: 'task:download_file',
        payload: {},
        status: 'completed',
        create_time: new Date().toISOString(),
        update_time: new Date().toISOString(),
        result: {
          success: true,
          asset_id: 'asset-456',
          filename: 'model.safetensors',
          bytes_downloaded: 1000
        },
        ...overrides
      }
    }

    it('polls and completes stale downloads', async () => {
      const store = useAssetDownloadStore()

      vi.mocked(taskService.getTask).mockResolvedValue(createTaskResponse())

      dispatch(createDownloadMessage({ status: 'running' }))
      expect(store.activeDownloads).toHaveLength(1)

      // Advance fake time with no further messages. NOTE(review): 45 s is
      // presumably past the store's staleness threshold — confirm in the store.
      await vi.advanceTimersByTimeAsync(45_000)

      expect(taskService.getTask).toHaveBeenCalledWith('task-123')
      expect(store.activeDownloads).toHaveLength(0)
      expect(store.finishedDownloads[0].status).toBe('completed')
    })

    it('polls and marks failed downloads', async () => {
      const store = useAssetDownloadStore()

      vi.mocked(taskService.getTask).mockResolvedValue(
        createTaskResponse({
          status: 'failed',
          error_message: 'Download failed',
          result: { success: false, error: 'Network error' }
        })
      )

      dispatch(createDownloadMessage({ status: 'running' }))
      await vi.advanceTimersByTimeAsync(45_000)

      expect(store.activeDownloads).toHaveLength(0)
      expect(store.finishedDownloads[0].status).toBe('failed')
      // The surfaced error is the task's `error_message`, not `result.error`.
      expect(store.finishedDownloads[0].error).toBe('Download failed')
    })

    it('does not complete if task still running', async () => {
      const store = useAssetDownloadStore()

      vi.mocked(taskService.getTask).mockResolvedValue(
        createTaskResponse({ status: 'running', result: undefined })
      )

      dispatch(createDownloadMessage({ status: 'running' }))
      await vi.advanceTimersByTimeAsync(45_000)

      expect(taskService.getTask).toHaveBeenCalled()
      expect(store.activeDownloads).toHaveLength(1)
    })

    it('continues tracking on polling error', async () => {
      const store = useAssetDownloadStore()

      vi.mocked(taskService.getTask).mockRejectedValue(new Error('Not found'))
      dispatch(createDownloadMessage({ status: 'running' }))

      await vi.advanceTimersByTimeAsync(45_000)

      // A rejected poll must leave the download in the active list.
      expect(store.activeDownloads).toHaveLength(1)
    })
  })

  describe('clearFinishedDownloads', () => {
    it('removes all finished downloads', () => {
      const store = useAssetDownloadStore()

      dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
      expect(store.finishedDownloads).toHaveLength(1)

      store.clearFinishedDownloads()

      expect(store.finishedDownloads).toHaveLength(0)
    })
  })
})
|