[backport cloud/1.36] feat: add polling fallback for stale asset downloads (#7981)

Backport of #7926 to cloud/1.36

┆Issue is synchronized with this [Notion
page](https://www.notion.so/PR-7981-backport-cloud-1-36-feat-add-polling-fallback-for-stale-asset-downloads-2e76d73d365081a983a4e5a8683ae2c9)
by [Unito](https://www.unito.io)

Co-authored-by: Amp <amp@ampcode.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
Alexander Brown
2026-01-12 19:41:57 -08:00
committed by GitHub
parent 2d04cf4757
commit 7f83af391c
8 changed files with 509 additions and 43 deletions

View File

@@ -161,6 +161,8 @@ The project uses **Nx** for build orchestration and task management
## Testing Guidelines
See @docs/testing/*.md for detailed patterns.
- Frameworks:
- Vitest (unit/component, happy-dom)
- Playwright (E2E)

View File

@@ -17,6 +17,7 @@ function createMockJob(overrides: Partial<AssetDownload> = {}): AssetDownload {
bytesDownloaded: 0,
progress: 0,
status: 'created',
lastUpdate: Date.now(),
...overrides
}
}

View File

@@ -29,6 +29,7 @@ function createMockJob(overrides: Partial<AssetDownload> = {}): AssetDownload {
bytesDownloaded: 0,
progress: 0,
status: 'created',
lastUpdate: Date.now(),
...overrides
}
}

View File

@@ -0,0 +1,70 @@
/**
* Task Service for polling background task status.
*
* CAVEAT: The `payload` and `result` schemas below are specific to
* `task:download_file` tasks. Other task types may have different
* payload/result structures. We are not generalizing this until
* additional use cases arise.
*/
import { z } from 'zod'
import { fromZodError } from 'zod-validation-error'
import { api } from '@/scripts/api'
const TASKS_ENDPOINT = '/tasks'
const zTaskStatus = z.enum(['created', 'running', 'completed', 'failed'])
const zDownloadFileResult = z.object({
success: z.boolean(),
file_path: z.string().optional(),
bytes_downloaded: z.number().optional(),
content_type: z.string().optional(),
hash: z.string().optional(),
filename: z.string().optional(),
asset_id: z.string().optional(),
metadata: z.record(z.unknown()).optional(),
error: z.string().optional()
})
const zTaskResponse = z.object({
id: z.string().uuid(),
idempotency_key: z.string(),
task_name: z.string(),
payload: z.record(z.unknown()),
status: zTaskStatus,
result: zDownloadFileResult.optional(),
error_message: z.string().optional(),
create_time: z.string().datetime(),
update_time: z.string().datetime(),
started_at: z.string().datetime().optional(),
completed_at: z.string().datetime().optional()
})
export type TaskResponse = z.infer<typeof zTaskResponse>
function createTaskService() {
async function getTask(taskId: string): Promise<TaskResponse> {
const res = await api.fetchApi(`${TASKS_ENDPOINT}/${taskId}`)
if (!res.ok) {
if (res.status === 404) {
throw new Error(`Task not found: ${taskId}`)
}
throw new Error(`Failed to get task ${taskId}: ${res.status}`)
}
const data = await res.json()
const result = zTaskResponse.safeParse(data)
if (!result.success) {
throw new Error(fromZodError(result.error).message)
}
return result.data
}
return { getTask }
}
export const taskService = createTaskService()

View File

@@ -137,12 +137,12 @@ const zFeatureFlagsWsMessage = z.record(z.string(), z.any())
const zAssetDownloadWsMessage = z.object({
task_id: z.string(),
asset_id: z.string(),
asset_name: z.string(),
bytes_total: z.number(),
bytes_downloaded: z.number(),
progress: z.number(),
status: z.enum(['created', 'running', 'completed', 'failed']),
asset_id: z.string().optional(),
error: z.string().optional()
})

View File

@@ -0,0 +1,225 @@
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import type { TaskResponse } from '@/platform/tasks/services/taskService'
import { taskService } from '@/platform/tasks/services/taskService'
import type { AssetDownloadWsMessage } from '@/schemas/apiSchema'
import { useAssetDownloadStore } from '@/stores/assetDownloadStore'
type DownloadEventHandler = (e: CustomEvent<AssetDownloadWsMessage>) => void
const eventHandler = vi.hoisted(() => {
const state: { current: DownloadEventHandler | null } = { current: null }
return state
})
vi.mock('@/scripts/api', () => ({
api: {
addEventListener: vi.fn((_event: string, handler: DownloadEventHandler) => {
eventHandler.current = handler
}),
removeEventListener: vi.fn()
}
}))
vi.mock('@/platform/tasks/services/taskService', () => ({
taskService: {
getTask: vi.fn()
}
}))
function createDownloadMessage(
overrides: Partial<AssetDownloadWsMessage> = {}
): AssetDownloadWsMessage {
return {
task_id: 'task-123',
asset_id: 'asset-456',
asset_name: 'model.safetensors',
bytes_total: 1000,
bytes_downloaded: 500,
progress: 50,
status: 'running',
...overrides
}
}
function dispatch(msg: AssetDownloadWsMessage) {
if (!eventHandler.current) {
throw new Error(
'Event handler not registered. Call useAssetDownloadStore() first.'
)
}
eventHandler.current(new CustomEvent('asset_download', { detail: msg }))
}
describe('useAssetDownloadStore', () => {
beforeEach(() => {
setActivePinia(createTestingPinia({ stubActions: false }))
vi.useFakeTimers({ shouldAdvanceTime: false })
vi.resetAllMocks()
eventHandler.current = null
})
afterEach(() => {
vi.useRealTimers()
})
describe('handleAssetDownload', () => {
it('tracks running downloads', () => {
const store = useAssetDownloadStore()
dispatch(createDownloadMessage())
expect(store.activeDownloads).toHaveLength(1)
expect(store.activeDownloads[0].taskId).toBe('task-123')
expect(store.activeDownloads[0].progress).toBe(50)
})
it('moves download to finished when completed', () => {
const store = useAssetDownloadStore()
dispatch(createDownloadMessage({ status: 'running' }))
expect(store.activeDownloads).toHaveLength(1)
dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
expect(store.activeDownloads).toHaveLength(0)
expect(store.finishedDownloads).toHaveLength(1)
expect(store.finishedDownloads[0].status).toBe('completed')
})
it('moves download to finished when failed', () => {
const store = useAssetDownloadStore()
dispatch(createDownloadMessage({ status: 'running' }))
dispatch(
createDownloadMessage({ status: 'failed', error: 'Network error' })
)
expect(store.activeDownloads).toHaveLength(0)
expect(store.finishedDownloads).toHaveLength(1)
expect(store.finishedDownloads[0].status).toBe('failed')
expect(store.finishedDownloads[0].error).toBe('Network error')
})
it('ignores duplicate terminal state messages', () => {
const store = useAssetDownloadStore()
dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
expect(store.finishedDownloads).toHaveLength(1)
})
})
describe('trackDownload', () => {
it('associates task with model type for completion tracking', () => {
const store = useAssetDownloadStore()
store.trackDownload('task-123', 'checkpoints')
dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
expect(store.completedDownloads).toHaveLength(1)
expect(store.completedDownloads[0]).toMatchObject({
taskId: 'task-123',
modelType: 'checkpoints'
})
})
})
describe('stale download polling', () => {
function createTaskResponse(
overrides: Partial<TaskResponse> = {}
): TaskResponse {
return {
id: 'task-123',
idempotency_key: 'key-123',
task_name: 'task:download_file',
payload: {},
status: 'completed',
create_time: new Date().toISOString(),
update_time: new Date().toISOString(),
result: {
success: true,
asset_id: 'asset-456',
filename: 'model.safetensors',
bytes_downloaded: 1000
},
...overrides
}
}
it('polls and completes stale downloads', async () => {
const store = useAssetDownloadStore()
vi.mocked(taskService.getTask).mockResolvedValue(createTaskResponse())
dispatch(createDownloadMessage({ status: 'running' }))
expect(store.activeDownloads).toHaveLength(1)
await vi.advanceTimersByTimeAsync(45_000)
expect(taskService.getTask).toHaveBeenCalledWith('task-123')
expect(store.activeDownloads).toHaveLength(0)
expect(store.finishedDownloads[0].status).toBe('completed')
})
it('polls and marks failed downloads', async () => {
const store = useAssetDownloadStore()
vi.mocked(taskService.getTask).mockResolvedValue(
createTaskResponse({
status: 'failed',
error_message: 'Download failed',
result: { success: false, error: 'Network error' }
})
)
dispatch(createDownloadMessage({ status: 'running' }))
await vi.advanceTimersByTimeAsync(45_000)
expect(store.activeDownloads).toHaveLength(0)
expect(store.finishedDownloads[0].status).toBe('failed')
expect(store.finishedDownloads[0].error).toBe('Download failed')
})
it('does not complete if task still running', async () => {
const store = useAssetDownloadStore()
vi.mocked(taskService.getTask).mockResolvedValue(
createTaskResponse({ status: 'running', result: undefined })
)
dispatch(createDownloadMessage({ status: 'running' }))
await vi.advanceTimersByTimeAsync(45_000)
expect(taskService.getTask).toHaveBeenCalled()
expect(store.activeDownloads).toHaveLength(1)
})
it('continues tracking on polling error', async () => {
const store = useAssetDownloadStore()
vi.mocked(taskService.getTask).mockRejectedValue(new Error('Not found'))
dispatch(createDownloadMessage({ status: 'running' }))
await vi.advanceTimersByTimeAsync(45_000)
expect(store.activeDownloads).toHaveLength(1)
})
})
describe('clearFinishedDownloads', () => {
it('removes all finished downloads', () => {
const store = useAssetDownloadStore()
dispatch(createDownloadMessage({ status: 'completed', progress: 100 }))
expect(store.finishedDownloads).toHaveLength(1)
store.clearFinishedDownloads()
expect(store.finishedDownloads).toHaveLength(0)
})
})
})

View File

@@ -1,17 +1,20 @@
import { useIntervalFn } from '@vueuse/core'
import { defineStore } from 'pinia'
import { computed, ref } from 'vue'
import { computed, ref, watch } from 'vue'
import { taskService } from '@/platform/tasks/services/taskService'
import type { AssetDownloadWsMessage } from '@/schemas/apiSchema'
import { api } from '@/scripts/api'
export interface AssetDownload {
taskId: string
assetId: string
assetName: string
bytesTotal: number
bytesDownloaded: number
progress: number
status: 'created' | 'running' | 'completed' | 'failed'
lastUpdate: number
assetId?: string
error?: string
}
@@ -21,20 +24,13 @@ interface CompletedDownload {
timestamp: number
}
const PROCESSED_TASK_CLEANUP_MS = 60000
const MAX_COMPLETED_DOWNLOADS = 10
const STALE_THRESHOLD_MS = 10_000
const POLL_INTERVAL_MS = 10_000
export const useAssetDownloadStore = defineStore('assetDownload', () => {
/** Map of task IDs to their download progress data */
const downloads = ref<Map<string, AssetDownload>>(new Map())
/** Map of task IDs to model types, used to track which model type to refresh after download completes */
const pendingModelTypes = new Map<string, string>()
/** Set of task IDs that have reached a terminal state (completed/failed), prevents duplicate processing */
const processedTaskIds = new Set<string>()
/** Reactive signal for completed downloads */
const completedDownloads = ref<CompletedDownload[]>([])
const downloadList = computed(() => Array.from(downloads.value.values()))
@@ -51,24 +47,17 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => {
const hasActiveDownloads = computed(() => activeDownloads.value.length > 0)
const hasDownloads = computed(() => downloads.value.size > 0)
/**
* Associates a download task with its model type for later use when the download completes.
* Intended for external callers (e.g., useUploadModelWizard) to register async downloads.
*/
function trackDownload(taskId: string, modelType: string) {
pendingModelTypes.set(taskId, modelType)
}
/**
* Handles asset download WebSocket events. Updates download progress, manages toast notifications,
* and tracks completed downloads. Prevents duplicate processing of terminal states (completed/failed).
*/
function handleAssetDownload(e: CustomEvent<AssetDownloadWsMessage>) {
const data = e.detail
const existing = downloads.value.get(data.task_id)
if (data.status === 'completed' || data.status === 'failed') {
if (processedTaskIds.has(data.task_id)) return
processedTaskIds.add(data.task_id)
// Skip if already in terminal state
if (existing?.status === 'completed' || existing?.status === 'failed') {
return
}
const download: AssetDownload = {
@@ -79,7 +68,8 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => {
bytesDownloaded: data.bytes_downloaded,
progress: data.progress,
status: data.status,
error: data.error
error: data.error,
lastUpdate: Date.now()
}
downloads.value.set(data.task_id, download)
@@ -87,33 +77,72 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => {
if (data.status === 'completed') {
const modelType = pendingModelTypes.get(data.task_id)
if (modelType) {
const newDownload: CompletedDownload = {
taskId: data.task_id,
modelType,
timestamp: Date.now()
}
const updated = [...completedDownloads.value, newDownload]
if (updated.length > MAX_COMPLETED_DOWNLOADS) {
updated.shift()
}
const updated = [
...completedDownloads.value,
{ taskId: data.task_id, modelType, timestamp: Date.now() }
]
if (updated.length > MAX_COMPLETED_DOWNLOADS) updated.shift()
completedDownloads.value = updated
pendingModelTypes.delete(data.task_id)
}
setTimeout(
() => processedTaskIds.delete(data.task_id),
PROCESSED_TASK_CLEANUP_MS
)
} else if (data.status === 'failed') {
pendingModelTypes.delete(data.task_id)
setTimeout(
() => processedTaskIds.delete(data.task_id),
PROCESSED_TASK_CLEANUP_MS
)
}
}
async function pollStaleDownloads() {
const now = Date.now()
const staleDownloads = activeDownloads.value.filter(
(d) => now - d.lastUpdate >= STALE_THRESHOLD_MS
)
if (staleDownloads.length === 0) return
async function pollSingleDownload(download: AssetDownload) {
try {
const task = await taskService.getTask(download.taskId)
if (task.status === 'completed' || task.status === 'failed') {
const result = task.result
handleAssetDownload(
new CustomEvent('asset_download', {
detail: {
task_id: download.taskId,
asset_id: result?.asset_id ?? download.assetId,
asset_name: result?.filename ?? download.assetName,
bytes_total: download.bytesTotal,
bytes_downloaded:
result?.bytes_downloaded ?? download.bytesTotal,
progress: task.status === 'completed' ? 100 : download.progress,
status: task.status,
error: task.error_message ?? result?.error
}
})
)
}
} catch {
// Task not ready or not found
}
}
await Promise.all(staleDownloads.map(pollSingleDownload))
}
const { pause, resume } = useIntervalFn(
() => void pollStaleDownloads(),
POLL_INTERVAL_MS,
{ immediate: false }
)
watch(
hasActiveDownloads,
(hasActive) => {
if (hasActive) resume()
else pause()
},
{ immediate: true }
)
api.addEventListener('asset_download', handleAssetDownload)
function clearFinishedDownloads() {

138
tests-ui/vitest-patterns.md Normal file
View File

@@ -0,0 +1,138 @@
---
globs:
- '**/*.test.ts'
- '**/*.spec.ts'
---
# Vitest Patterns
## Setup
Use `createTestingPinia` from `@pinia/testing`, not `createPinia`:
```typescript
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
describe('MyStore', () => {
beforeEach(() => {
setActivePinia(createTestingPinia({ stubActions: false }))
vi.useFakeTimers()
vi.resetAllMocks()
})
afterEach(() => {
vi.useRealTimers()
})
})
```
**Why `stubActions: false`?** By default, testing pinia stubs all actions. Set to `false` when testing actual store behavior.
## Mock Patterns
### Reset all mocks at once
```typescript
beforeEach(() => {
vi.resetAllMocks() // Not individual mock.mockReset() calls
})
```
### Module mocks with vi.mock()
```typescript
vi.mock('@/scripts/api', () => ({
api: {
addEventListener: vi.fn(),
fetchData: vi.fn()
}
}))
vi.mock('@/services/myService', () => ({
myService: {
doThing: vi.fn()
}
}))
```
### Configure mocks in tests
```typescript
import { api } from '@/scripts/api'
import { myService } from '@/services/myService'
it('handles success', () => {
vi.mocked(myService.doThing).mockResolvedValue({ data: 'test' })
// ... test code
})
```
## Testing Event Listeners
When a store registers event listeners at module load time:
```typescript
function getEventHandler() {
const call = vi.mocked(api.addEventListener).mock.calls.find(
([event]) => event === 'my_event'
)
return call?.[1] as (e: CustomEvent<MyEventType>) => void
}
function dispatch(data: MyEventType) {
const handler = getEventHandler()
handler(new CustomEvent('my_event', { detail: data }))
}
it('handles events', () => {
const store = useMyStore()
dispatch({ field: 'value' })
expect(store.items).toHaveLength(1)
})
```
## Testing with Fake Timers
For stores with intervals, timeouts, or polling:
```typescript
beforeEach(() => {
vi.useFakeTimers()
})
afterEach(() => {
vi.useRealTimers()
})
it('polls after delay', async () => {
const store = useMyStore()
store.startPolling()
await vi.advanceTimersByTimeAsync(30000)
expect(mockService.fetch).toHaveBeenCalled()
})
```
## Assertion Style
Prefer `.toHaveLength()` over `.length.toBe()`:
```typescript
// Good
expect(store.items).toHaveLength(1)
// Avoid
expect(store.items.length).toBe(1)
```
Use `.toMatchObject()` for partial matching:
```typescript
expect(store.completedItems[0]).toMatchObject({
id: 'task-123',
status: 'done'
})
```