feat: add polling fallback for stale asset downloads (#7926)

## Summary

Adds a polling fallback mechanism to recover from dropped WebSocket
messages during model downloads.

## Problem

When downloading models via the asset download service, status updates
are received over WebSocket. Sometimes these messages are dropped
(network issues, reconnection, etc.), causing downloads to appear
"stuck" even when they've completed on the backend.

## Solution

Periodically poll for stale downloads using the existing REST API:

- Track `lastUpdate` timestamp on each download
- Downloads without updates for 10s are considered "stale"
- Poll stale downloads every 10s via `GET /tasks/{task_id}` to check the
task's status
- If the task has reached a terminal state (`completed` or `failed`),
mark the download as completed or failed accordingly

## Implementation

- Added `lastUpdate` field to `AssetDownload` interface
- Use VueUse's `useIntervalFn` with a `watch` to auto start/stop polling
based on active downloads
- Reuse existing `handleAssetDownload` for completion (synthetic event)
- Added 9 unit tests covering the polling behavior

## Testing

- All existing tests pass
- New tests cover:
  - Basic download tracking
  - Completion/failure handling  
  - Duplicate message prevention
  - Stale download polling
  - Polling error handling

┆Issue is synchronized with this [Notion
page](https://www.notion.so/PR-7926-feat-add-polling-fallback-for-stale-asset-downloads-2e36d73d3650810ea966f5480f08b60c)
by [Unito](https://www.unito.io)

---------

Co-authored-by: Amp <amp@ampcode.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
Alexander Brown
2026-01-09 16:23:12 -08:00
committed by GitHub
parent 5029a0b32c
commit 41ffb7c627
8 changed files with 509 additions and 43 deletions

View File

@@ -17,6 +17,7 @@ function createMockJob(overrides: Partial<AssetDownload> = {}): AssetDownload {
bytesDownloaded: 0,
progress: 0,
status: 'created',
lastUpdate: Date.now(),
...overrides
}
}

View File

@@ -29,6 +29,7 @@ function createMockJob(overrides: Partial<AssetDownload> = {}): AssetDownload {
bytesDownloaded: 0,
progress: 0,
status: 'created',
lastUpdate: Date.now(),
...overrides
}
}

View File

@@ -0,0 +1,70 @@
/**
* Task Service for polling background task status.
*
* CAVEAT: The `payload` and `result` schemas below are specific to
* `task:download_file` tasks. Other task types may have different
* payload/result structures. We are not generalizing this until
* additional use cases arise.
*/
import { z } from 'zod'
import { fromZodError } from 'zod-validation-error'
import { api } from '@/scripts/api'
/** Base path of the task endpoint; the task ID is appended per request. */
const TASKS_ENDPOINT = '/tasks'

/** Lifecycle states a task can report. */
const zTaskStatus = z.enum(['created', 'running', 'completed', 'failed'])

/** Datetime string schema shared by every timestamp field. */
const zIsoDateTime = z.string().datetime()

/** Optional string schema shared by the free-form result fields. */
const zOptionalString = z.string().optional()

/**
 * Result of a `task:download_file` task. Only `success` is required;
 * every other field may be absent.
 */
const zDownloadFileResult = z.object({
  success: z.boolean(),
  file_path: zOptionalString,
  bytes_downloaded: z.number().optional(),
  content_type: zOptionalString,
  hash: zOptionalString,
  filename: zOptionalString,
  asset_id: zOptionalString,
  metadata: z.record(z.unknown()).optional(),
  error: zOptionalString
})

/** Response body of the task endpoint for a single task. */
const zTaskResponse = z.object({
  id: z.string().uuid(),
  idempotency_key: z.string(),
  task_name: z.string(),
  payload: z.record(z.unknown()),
  status: zTaskStatus,
  result: zDownloadFileResult.optional(),
  error_message: zOptionalString,
  create_time: zIsoDateTime,
  update_time: zIsoDateTime,
  started_at: zIsoDateTime.optional(),
  completed_at: zIsoDateTime.optional()
})

/** Validated task object as returned by `getTask`. */
export type TaskResponse = z.infer<typeof zTaskResponse>
/**
 * Builds the task service, which exposes read access to background tasks.
 */
function createTaskService() {
  /**
   * Fetches a single task by ID and validates the response body against
   * `zTaskResponse`.
   *
   * @param taskId - ID of the task to look up.
   * @returns The validated task.
   * @throws Error when the request is not OK (a 404 is reported as
   * "Task not found") or when the body does not match the schema.
   */
  async function getTask(taskId: string): Promise<TaskResponse> {
    // Encode the ID so characters such as `/`, `?` or `#` cannot change
    // which path is requested.
    const res = await api.fetchApi(
      `${TASKS_ENDPOINT}/${encodeURIComponent(taskId)}`
    )
    if (!res.ok) {
      if (res.status === 404) {
        throw new Error(`Task not found: ${taskId}`)
      }
      throw new Error(`Failed to get task ${taskId}: ${res.status}`)
    }

    const data: unknown = await res.json()
    const result = zTaskResponse.safeParse(data)
    if (!result.success) {
      // Convert the Zod error into a readable message.
      throw new Error(fromZodError(result.error).message)
    }
    return result.data
  }

  return { getTask }
}

/** Shared task service instance. */
export const taskService = createTaskService()

View File

@@ -137,12 +137,12 @@ const zFeatureFlagsWsMessage = z.record(z.string(), z.any())
// Schema for the payload of an asset download WebSocket message.
const zAssetDownloadWsMessage = z.object({
// ID of the download task this message refers to.
task_id: z.string(),
asset_id: z.string(),
asset_name: z.string(),
bytes_total: z.number(),
bytes_downloaded: z.number(),
progress: z.number(),
// Same four states as the task status enum used by the task service.
status: z.enum(['created', 'running', 'completed', 'failed']),
// NOTE(review): `asset_id` appears twice in this excerpt (required above,
// optional here); presumably one of the two is the pre-change line of the
// diff — confirm which one applies.
asset_id: z.string().optional(),
// Optional error text.
error: z.string().optional()
})

View File

@@ -0,0 +1,225 @@
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import type { TaskResponse } from '@/platform/tasks/services/taskService'
import { taskService } from '@/platform/tasks/services/taskService'
import type { AssetDownloadWsMessage } from '@/schemas/apiSchema'
import { useAssetDownloadStore } from '@/stores/assetDownloadStore'
/** Signature of the listener the store registers for `asset_download` events. */
type DownloadEventHandler = (e: CustomEvent<AssetDownloadWsMessage>) => void
// Holder for the listener captured by the mocked `api.addEventListener`.
// It is created through `vi.hoisted` so that the `vi.mock` factory below can
// reference it.
const eventHandler = vi.hoisted(() => {
const state: { current: DownloadEventHandler | null } = { current: null }
return state
})
// Replace the API module: `addEventListener` stores the handler it is given
// so tests can invoke it directly through `dispatch`.
vi.mock('@/scripts/api', () => ({
api: {
addEventListener: vi.fn((_event: string, handler: DownloadEventHandler) => {
eventHandler.current = handler
}),
removeEventListener: vi.fn()
}
}))
// Replace the task service so each test controls what `getTask` resolves to.
vi.mock('@/platform/tasks/services/taskService', () => ({
taskService: {
getTask: vi.fn()
}
}))
/**
 * Builds an asset download message for tests.
 *
 * Starts from a half-finished running download of `model.safetensors`
 * (task `task-123`) and applies `overrides` on top of it.
 *
 * @param overrides - Fields that replace or extend the defaults.
 * @returns The merged message.
 */
function createDownloadMessage(
  overrides: Partial<AssetDownloadWsMessage> = {}
): AssetDownloadWsMessage {
  const defaults: AssetDownloadWsMessage = {
    task_id: 'task-123',
    asset_id: 'asset-456',
    asset_name: 'model.safetensors',
    bytes_total: 1000,
    bytes_downloaded: 500,
    progress: 50,
    status: 'running'
  }
  return { ...defaults, ...overrides }
}
/**
 * Delivers `msg` to the listener captured by the mocked
 * `api.addEventListener`, wrapped in an `asset_download` CustomEvent.
 *
 * @param msg - Message to place in the event's `detail`.
 * @throws Error when no listener has been captured yet.
 */
function dispatch(msg: AssetDownloadWsMessage) {
  const handler = eventHandler.current
  if (handler) {
    handler(new CustomEvent('asset_download', { detail: msg }))
    return
  }
  throw new Error(
    'Event handler not registered. Call useAssetDownloadStore() first.'
  )
}
// Test suite for the asset download store: event handling, model type
// tracking, polling of stale downloads, and clearing finished downloads.
describe('useAssetDownloadStore', () => {
beforeEach(() => {
// Fresh Pinia per test; `stubActions: false` keeps the real store actions.
setActivePinia(createTestingPinia({ stubActions: false }))
// Fake timers let the tests drive the store's polling interval manually.
vi.useFakeTimers({ shouldAdvanceTime: false })
// NOTE(review): this also resets the `api.addEventListener` mock; the suite
// presumably relies on the reset keeping the implementation passed to
// `vi.fn(...)` so the handler is still captured — confirm against the
// Vitest version in use.
vi.resetAllMocks()
// Forget the handler captured by a previous test's store instance.
eventHandler.current = null
})
afterEach(() => {
vi.useRealTimers()
})
describe('handleAssetDownload', () => {
it('tracks running downloads', () => {
const store = useAssetDownloadStore()
dispatch(createDownloadMessage())
expect(store.activeDownloads).toHaveLength(1)
expect(store.activeDownloads[0].taskId).toBe('task-123')
expect(store.activeDownloads[0].progress).toBe(50)
})
it('moves download to finished when completed', () => {
const store = useAssetDownloadStore()
dispatch(createDownloadMessage({ status: 'running' }))
expect(store.activeDownloads).toHaveLength(1)
dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
expect(store.activeDownloads).toHaveLength(0)
expect(store.finishedDownloads).toHaveLength(1)
expect(store.finishedDownloads[0].status).toBe('completed')
})
it('moves download to finished when failed', () => {
const store = useAssetDownloadStore()
dispatch(createDownloadMessage({ status: 'running' }))
dispatch(
createDownloadMessage({ status: 'failed', error: 'Network error' })
)
expect(store.activeDownloads).toHaveLength(0)
expect(store.finishedDownloads).toHaveLength(1)
expect(store.finishedDownloads[0].status).toBe('failed')
expect(store.finishedDownloads[0].error).toBe('Network error')
})
// The same terminal message delivered twice must yield a single entry.
it('ignores duplicate terminal state messages', () => {
const store = useAssetDownloadStore()
dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
expect(store.finishedDownloads).toHaveLength(1)
})
})
describe('trackDownload', () => {
it('associates task with model type for completion tracking', () => {
const store = useAssetDownloadStore()
store.trackDownload('task-123', 'checkpoints')
dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
expect(store.completedDownloads).toHaveLength(1)
expect(store.completedDownloads[0]).toMatchObject({
taskId: 'task-123',
modelType: 'checkpoints'
})
})
})
// Each test below dispatches one running message and then advances the fake
// clock by 45 s, several multiples of the store's 10 s stale threshold and
// 10 s poll interval (STALE_THRESHOLD_MS / POLL_INTERVAL_MS).
describe('stale download polling', () => {
// Builds a task as returned by the mocked `taskService.getTask`; defaults to
// a successfully completed download matching `createDownloadMessage`.
function createTaskResponse(
overrides: Partial<TaskResponse> = {}
): TaskResponse {
return {
id: 'task-123',
idempotency_key: 'key-123',
task_name: 'task:download_file',
payload: {},
status: 'completed',
create_time: new Date().toISOString(),
update_time: new Date().toISOString(),
result: {
success: true,
asset_id: 'asset-456',
filename: 'model.safetensors',
bytes_downloaded: 1000
},
...overrides
}
}
it('polls and completes stale downloads', async () => {
const store = useAssetDownloadStore()
vi.mocked(taskService.getTask).mockResolvedValue(createTaskResponse())
dispatch(createDownloadMessage({ status: 'running' }))
expect(store.activeDownloads).toHaveLength(1)
await vi.advanceTimersByTimeAsync(45_000)
expect(taskService.getTask).toHaveBeenCalledWith('task-123')
expect(store.activeDownloads).toHaveLength(0)
expect(store.finishedDownloads[0].status).toBe('completed')
})
it('polls and marks failed downloads', async () => {
const store = useAssetDownloadStore()
vi.mocked(taskService.getTask).mockResolvedValue(
createTaskResponse({
status: 'failed',
error_message: 'Download failed',
result: { success: false, error: 'Network error' }
})
)
dispatch(createDownloadMessage({ status: 'running' }))
await vi.advanceTimersByTimeAsync(45_000)
expect(store.activeDownloads).toHaveLength(0)
expect(store.finishedDownloads[0].status).toBe('failed')
// The task's `error_message` takes precedence over `result.error`.
expect(store.finishedDownloads[0].error).toBe('Download failed')
})
it('does not complete if task still running', async () => {
const store = useAssetDownloadStore()
vi.mocked(taskService.getTask).mockResolvedValue(
createTaskResponse({ status: 'running', result: undefined })
)
dispatch(createDownloadMessage({ status: 'running' }))
await vi.advanceTimersByTimeAsync(45_000)
expect(taskService.getTask).toHaveBeenCalled()
expect(store.activeDownloads).toHaveLength(1)
})
// A rejected poll must leave the download in the active list.
it('continues tracking on polling error', async () => {
const store = useAssetDownloadStore()
vi.mocked(taskService.getTask).mockRejectedValue(new Error('Not found'))
dispatch(createDownloadMessage({ status: 'running' }))
await vi.advanceTimersByTimeAsync(45_000)
expect(store.activeDownloads).toHaveLength(1)
})
})
describe('clearFinishedDownloads', () => {
it('removes all finished downloads', () => {
const store = useAssetDownloadStore()
dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
expect(store.finishedDownloads).toHaveLength(1)
store.clearFinishedDownloads()
expect(store.finishedDownloads).toHaveLength(0)
})
})
})

View File

@@ -1,17 +1,20 @@
import { useIntervalFn } from '@vueuse/core'
import { defineStore } from 'pinia'
import { computed, ref } from 'vue'
import { computed, ref, watch } from 'vue'
import { taskService } from '@/platform/tasks/services/taskService'
import type { AssetDownloadWsMessage } from '@/schemas/apiSchema'
import { api } from '@/scripts/api'
/** Progress record the store keeps for one download task. */
export interface AssetDownload {
/** ID of the download task; the store keys its `downloads` map by it. */
taskId: string
assetId: string
assetName: string
bytesTotal: number
bytesDownloaded: number
progress: number
status: 'created' | 'running' | 'completed' | 'failed'
/**
 * `Date.now()` value taken when the record was last written by the event
 * handler; compared against the stale threshold when polling.
 */
lastUpdate: number
// NOTE(review): `assetId` appears twice in this excerpt (required above,
// optional here); presumably one of the two is the pre-change line of the
// diff — confirm which one applies.
assetId?: string
/** Error text copied from the download message, when present. */
error?: string
}
@@ -21,20 +24,13 @@ interface CompletedDownload {
timestamp: number
}
const PROCESSED_TASK_CLEANUP_MS = 60000
const MAX_COMPLETED_DOWNLOADS = 10
const STALE_THRESHOLD_MS = 10_000
const POLL_INTERVAL_MS = 10_000
export const useAssetDownloadStore = defineStore('assetDownload', () => {
/** Map of task IDs to their download progress data */
const downloads = ref<Map<string, AssetDownload>>(new Map())
/** Map of task IDs to model types, used to track which model type to refresh after download completes */
const pendingModelTypes = new Map<string, string>()
/** Set of task IDs that have reached a terminal state (completed/failed), prevents duplicate processing */
const processedTaskIds = new Set<string>()
/** Reactive signal for completed downloads */
const completedDownloads = ref<CompletedDownload[]>([])
const downloadList = computed(() => Array.from(downloads.value.values()))
@@ -51,24 +47,17 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => {
const hasActiveDownloads = computed(() => activeDownloads.value.length > 0)
const hasDownloads = computed(() => downloads.value.size > 0)
/**
* Associates a download task with its model type for later use when the download completes.
* Intended for external callers (e.g., useUploadModelWizard) to register async downloads.
*/
function trackDownload(taskId: string, modelType: string) {
pendingModelTypes.set(taskId, modelType)
}
/**
* Handles asset download WebSocket events. Updates download progress, manages toast notifications,
* and tracks completed downloads. Prevents duplicate processing of terminal states (completed/failed).
*/
function handleAssetDownload(e: CustomEvent<AssetDownloadWsMessage>) {
const data = e.detail
const existing = downloads.value.get(data.task_id)
if (data.status === 'completed' || data.status === 'failed') {
if (processedTaskIds.has(data.task_id)) return
processedTaskIds.add(data.task_id)
// Skip if already in terminal state
if (existing?.status === 'completed' || existing?.status === 'failed') {
return
}
const download: AssetDownload = {
@@ -79,7 +68,8 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => {
bytesDownloaded: data.bytes_downloaded,
progress: data.progress,
status: data.status,
error: data.error
error: data.error,
lastUpdate: Date.now()
}
downloads.value.set(data.task_id, download)
@@ -87,33 +77,72 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => {
if (data.status === 'completed') {
const modelType = pendingModelTypes.get(data.task_id)
if (modelType) {
const newDownload: CompletedDownload = {
taskId: data.task_id,
modelType,
timestamp: Date.now()
}
const updated = [...completedDownloads.value, newDownload]
if (updated.length > MAX_COMPLETED_DOWNLOADS) {
updated.shift()
}
const updated = [
...completedDownloads.value,
{ taskId: data.task_id, modelType, timestamp: Date.now() }
]
if (updated.length > MAX_COMPLETED_DOWNLOADS) updated.shift()
completedDownloads.value = updated
pendingModelTypes.delete(data.task_id)
}
setTimeout(
() => processedTaskIds.delete(data.task_id),
PROCESSED_TASK_CLEANUP_MS
)
} else if (data.status === 'failed') {
pendingModelTypes.delete(data.task_id)
setTimeout(
() => processedTaskIds.delete(data.task_id),
PROCESSED_TASK_CLEANUP_MS
)
}
}
/**
 * Polls the task service for every active download that has not been
 * updated for at least `STALE_THRESHOLD_MS`. When a task is reported as
 * `completed` or `failed`, that state is replayed through
 * `handleAssetDownload` as a synthetic `asset_download` event, so it takes
 * the same path as a WebSocket message.
 *
 * Errors from individual polls are ignored; the download stays active and
 * is polled again on the next run.
 */
async function pollStaleDownloads() {
  const now = Date.now()
  const staleDownloads = activeDownloads.value.filter(
    (d) => now - d.lastUpdate >= STALE_THRESHOLD_MS
  )
  if (staleDownloads.length === 0) return

  /** Polls one download and replays a terminal task state, if any. */
  async function pollSingleDownload(download: AssetDownload) {
    try {
      const task = await taskService.getTask(download.taskId)
      // Only terminal states are replayed; a task still in progress leaves
      // the download untouched.
      if (task.status !== 'completed' && task.status !== 'failed') return

      const result = task.result
      const isCompleted = task.status === 'completed'
      // Without a reported byte count, a completed task is assumed to have
      // transferred everything, while a failed one keeps the last known
      // count instead of claiming the full size.
      const fallbackBytes = isCompleted
        ? download.bytesTotal
        : download.bytesDownloaded

      handleAssetDownload(
        new CustomEvent('asset_download', {
          detail: {
            task_id: download.taskId,
            asset_id: result?.asset_id ?? download.assetId,
            asset_name: result?.filename ?? download.assetName,
            bytes_total: download.bytesTotal,
            bytes_downloaded: result?.bytes_downloaded ?? fallbackBytes,
            progress: isCompleted ? 100 : download.progress,
            status: task.status,
            error: task.error_message ?? result?.error
          }
        })
      )
    } catch {
      // Task not ready or not found; deliberately best-effort, the next
      // polling run tries again.
    }
  }

  await Promise.all(staleDownloads.map(pollSingleDownload))
}
// Run the stale-download poll every POLL_INTERVAL_MS. `immediate: false`
// leaves the interval stopped until `resume()` is called by the watcher below.
const { pause, resume } = useIntervalFn(
// `void` marks the returned promise as intentionally not awaited.
() => void pollStaleDownloads(),
POLL_INTERVAL_MS,
{ immediate: false }
)
// Poll only while there is something to poll: resume when a download becomes
// active, pause when none are left. `immediate: true` applies the current
// state once at setup.
watch(
hasActiveDownloads,
(hasActive) => {
if (hasActive) resume()
else pause()
},
{ immediate: true }
)
api.addEventListener('asset_download', handleAssetDownload)
function clearFinishedDownloads() {