From 41ffb7c627630fa402011a8ba7b3379442a6c67a Mon Sep 17 00:00:00 2001 From: Alexander Brown Date: Fri, 9 Jan 2026 16:23:12 -0800 Subject: [PATCH] feat: add polling fallback for stale asset downloads (#7926) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Adds a polling fallback mechanism to recover from dropped WebSocket messages during model downloads. ## Problem When downloading models via the asset download service, status updates are received over WebSocket. Sometimes these messages are dropped (network issues, reconnection, etc.), causing downloads to appear "stuck" even when they've completed on the backend. ## Solution Periodically poll for stale downloads using the existing REST API: - Track `lastUpdate` timestamp on each download - Downloads without updates for 10s are considered "stale" - Poll stale downloads every 10s via `GET /tasks/{task_id}` to check if the asset exists - If the asset exists with size > 0, mark the download as completed ## Implementation - Added `lastUpdate` field to `AssetDownload` interface - Use VueUse's `useIntervalFn` with a `watch` to auto start/stop polling based on active downloads - Reuse existing `handleAssetDownload` for completion (synthetic event) - Added 9 unit tests covering the polling behavior ## Testing - All existing tests pass - New tests cover: - Basic download tracking - Completion/failure handling - Duplicate message prevention - Stale download polling - Polling error handling ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-7926-feat-add-polling-fallback-for-stale-asset-downloads-2e36d73d3650810ea966f5480f08b60c) by [Unito](https://www.unito.io) --------- Co-authored-by: Amp Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- AGENTS.md | 2 + docs/testing/vitest-patterns.md | 138 +++++++++++ .../honeyToast/HoneyToast.stories.ts | 1 + .../toast/ProgressToastItem.stories.ts | 1 + src/platform/tasks/services/taskService.ts | 70 ++++++ src/schemas/apiSchema.ts | 2 +- src/stores/assetDownloadStore.test.ts | 225 ++++++++++++++++++ src/stores/assetDownloadStore.ts | 113 +++++---- 8 files changed, 509 insertions(+), 43 deletions(-) create mode 100644 docs/testing/vitest-patterns.md create mode 100644 src/platform/tasks/services/taskService.ts create mode 100644 src/stores/assetDownloadStore.test.ts diff --git a/AGENTS.md b/AGENTS.md index b68246d4f..743572be3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -161,6 +161,8 @@ The project uses **Nx** for build orchestration and task management ## Testing Guidelines +See @docs/testing/*.md for detailed patterns. + - Frameworks: - Vitest (unit/component, happy-dom) - Playwright (E2E) diff --git a/docs/testing/vitest-patterns.md b/docs/testing/vitest-patterns.md new file mode 100644 index 000000000..2eb7c8e09 --- /dev/null +++ b/docs/testing/vitest-patterns.md @@ -0,0 +1,138 @@ +--- +globs: + - '**/*.test.ts' + - '**/*.spec.ts' +--- + +# Vitest Patterns + +## Setup + +Use `createTestingPinia` from `@pinia/testing`, not `createPinia`: + +```typescript +import { createTestingPinia } from '@pinia/testing' +import { setActivePinia } from 'pinia' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +describe('MyStore', () => { + beforeEach(() => { + setActivePinia(createTestingPinia({ stubActions: false })) + vi.useFakeTimers() + vi.resetAllMocks() + }) + + afterEach(() => { + vi.useRealTimers() + }) +}) +``` + +**Why `stubActions: false`?** By default, testing pinia stubs all actions. Set to `false` when testing actual store behavior. + +## Mock Patterns + +### Reset all mocks at once + +```typescript +beforeEach(() => { + vi.resetAllMocks() // Not individual mock.mockReset() calls +}) +``` + +### Module mocks with vi.mock() + +```typescript +vi.mock('@/scripts/api', () => ({ + api: { + addEventListener: vi.fn(), + fetchData: vi.fn() + } +})) + +vi.mock('@/services/myService', () => ({ + myService: { + doThing: vi.fn() + } +})) +``` + +### Configure mocks in tests + +```typescript +import { api } from '@/scripts/api' +import { myService } from '@/services/myService' + +it('handles success', () => { + vi.mocked(myService.doThing).mockResolvedValue({ data: 'test' }) + // ... test code +}) +``` + +## Testing Event Listeners + +When a store registers event listeners at module load time: + +```typescript +function getEventHandler() { + const call = vi.mocked(api.addEventListener).mock.calls.find( + ([event]) => event === 'my_event' + ) + return call?.[1] as (e: CustomEvent) => void +} + +function dispatch(data: MyEventType) { + const handler = getEventHandler() + handler(new CustomEvent('my_event', { detail: data })) +} + +it('handles events', () => { + const store = useMyStore() + dispatch({ field: 'value' }) + expect(store.items).toHaveLength(1) +}) +``` + +## Testing with Fake Timers + +For stores with intervals, timeouts, or polling: + +```typescript +beforeEach(() => { + vi.useFakeTimers() +}) + +afterEach(() => { + vi.useRealTimers() +}) + +it('polls after delay', async () => { + const store = useMyStore() + store.startPolling() + + await vi.advanceTimersByTimeAsync(30000) + + expect(mockService.fetch).toHaveBeenCalled() +}) +``` + +## Assertion Style + +Prefer `.toHaveLength()` over `.length.toBe()`: + +```typescript +// Good +expect(store.items).toHaveLength(1) + +// Avoid +expect(store.items.length).toBe(1) +``` + +Use `.toMatchObject()` for partial matching: + +```typescript +expect(store.completedItems[0]).toMatchObject({ + id: 'task-123', + status: 'done' +}) +``` diff --git a/src/components/honeyToast/HoneyToast.stories.ts b/src/components/honeyToast/HoneyToast.stories.ts index 98ae59070..74331d49f 100644 --- a/src/components/honeyToast/HoneyToast.stories.ts +++ b/src/components/honeyToast/HoneyToast.stories.ts @@ -17,6 +17,7 @@ function createMockJob(overrides: Partial = {}): AssetDownload { bytesDownloaded: 0, progress: 0, status: 'created', + lastUpdate: Date.now(), ...overrides } } diff --git a/src/components/toast/ProgressToastItem.stories.ts b/src/components/toast/ProgressToastItem.stories.ts index 2ad376a72..cdfa8e28e 100644 --- a/src/components/toast/ProgressToastItem.stories.ts +++ b/src/components/toast/ProgressToastItem.stories.ts @@ -29,6 +29,7 @@ function createMockJob(overrides: Partial = {}): AssetDownload { bytesDownloaded: 0, progress: 0, status: 'created', + lastUpdate: Date.now(), ...overrides } } diff --git a/src/platform/tasks/services/taskService.ts b/src/platform/tasks/services/taskService.ts new file mode 100644 index 000000000..7dc1e62db --- /dev/null +++ b/src/platform/tasks/services/taskService.ts @@ -0,0 +1,70 @@ +/** + * Task Service for polling background task status. + * + * CAVEAT: The `payload` and `result` schemas below are specific to + * `task:download_file` tasks. Other task types may have different + * payload/result structures. We are not generalizing this until + * additional use cases arise. + */ +import { z } from 'zod' +import { fromZodError } from 'zod-validation-error' + +import { api } from '@/scripts/api' + +const TASKS_ENDPOINT = '/tasks' + +const zTaskStatus = z.enum(['created', 'running', 'completed', 'failed']) + +const zDownloadFileResult = z.object({ + success: z.boolean(), + file_path: z.string().optional(), + bytes_downloaded: z.number().optional(), + content_type: z.string().optional(), + hash: z.string().optional(), + filename: z.string().optional(), + asset_id: z.string().optional(), + metadata: z.record(z.unknown()).optional(), + error: z.string().optional() +}) + +const zTaskResponse = z.object({ + id: z.string().uuid(), + idempotency_key: z.string(), + task_name: z.string(), + payload: z.record(z.unknown()), + status: zTaskStatus, + result: zDownloadFileResult.optional(), + error_message: z.string().optional(), + create_time: z.string().datetime(), + update_time: z.string().datetime(), + started_at: z.string().datetime().optional(), + completed_at: z.string().datetime().optional() +}) + +export type TaskResponse = z.infer + +function createTaskService() { + async function getTask(taskId: string): Promise { + const res = await api.fetchApi(`${TASKS_ENDPOINT}/${taskId}`) + + if (!res.ok) { + if (res.status === 404) { + throw new Error(`Task not found: ${taskId}`) + } + throw new Error(`Failed to get task ${taskId}: ${res.status}`) + } + + const data = await res.json() + const result = zTaskResponse.safeParse(data) + + if (!result.success) { + throw new Error(fromZodError(result.error).message) + } + + return result.data + } + + return { getTask } +} + +export const taskService = createTaskService() diff --git a/src/schemas/apiSchema.ts b/src/schemas/apiSchema.ts index fc7597c34..6e01d6f4e 100644 --- a/src/schemas/apiSchema.ts +++ b/src/schemas/apiSchema.ts @@ -137,12 +137,12 @@ const zFeatureFlagsWsMessage = z.record(z.string(), z.any()) const zAssetDownloadWsMessage = z.object({ task_id: z.string(), - asset_id: z.string(), asset_name: z.string(), bytes_total: z.number(), bytes_downloaded: z.number(), progress: z.number(), status: z.enum(['created', 'running', 'completed', 'failed']), + asset_id: z.string().optional(), error: z.string().optional() }) diff --git a/src/stores/assetDownloadStore.test.ts b/src/stores/assetDownloadStore.test.ts new file mode 100644 index 000000000..28d58b9c1 --- /dev/null +++ b/src/stores/assetDownloadStore.test.ts @@ -0,0 +1,225 @@ +import { createTestingPinia } from '@pinia/testing' +import { setActivePinia } from 'pinia' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +import type { TaskResponse } from '@/platform/tasks/services/taskService' +import { taskService } from '@/platform/tasks/services/taskService' +import type { AssetDownloadWsMessage } from '@/schemas/apiSchema' +import { useAssetDownloadStore } from '@/stores/assetDownloadStore' + +type DownloadEventHandler = (e: CustomEvent) => void + +const eventHandler = vi.hoisted(() => { + const state: { current: DownloadEventHandler | null } = { current: null } + return state +}) + +vi.mock('@/scripts/api', () => ({ + api: { + addEventListener: vi.fn((_event: string, handler: DownloadEventHandler) => { + eventHandler.current = handler + }), + removeEventListener: vi.fn() + } +})) + +vi.mock('@/platform/tasks/services/taskService', () => ({ + taskService: { + getTask: vi.fn() + } +})) + +function createDownloadMessage( + overrides: Partial = {} +): AssetDownloadWsMessage { + return { + task_id: 'task-123', + asset_id: 'asset-456', + asset_name: 'model.safetensors', + bytes_total: 1000, + bytes_downloaded: 500, + progress: 50, + status: 'running', + ...overrides + } +} + +function dispatch(msg: AssetDownloadWsMessage) { + if (!eventHandler.current) { + throw new Error( + 'Event handler not registered. Call useAssetDownloadStore() first.' + ) + } + eventHandler.current(new CustomEvent('asset_download', { detail: msg })) +} + +describe('useAssetDownloadStore', () => { + beforeEach(() => { + setActivePinia(createTestingPinia({ stubActions: false })) + vi.useFakeTimers({ shouldAdvanceTime: false }) + vi.resetAllMocks() + eventHandler.current = null + }) + + afterEach(() => { + vi.useRealTimers() + }) + + describe('handleAssetDownload', () => { + it('tracks running downloads', () => { + const store = useAssetDownloadStore() + + dispatch(createDownloadMessage()) + + expect(store.activeDownloads).toHaveLength(1) + expect(store.activeDownloads[0].taskId).toBe('task-123') + expect(store.activeDownloads[0].progress).toBe(50) + }) + + it('moves download to finished when completed', () => { + const store = useAssetDownloadStore() + + dispatch(createDownloadMessage({ status: 'running' })) + expect(store.activeDownloads).toHaveLength(1) + + dispatch(createDownloadMessage({ status: 'completed', progress: 100 })) + + expect(store.activeDownloads).toHaveLength(0) + expect(store.finishedDownloads).toHaveLength(1) + expect(store.finishedDownloads[0].status).toBe('completed') + }) + + it('moves download to finished when failed', () => { + const store = useAssetDownloadStore() + + dispatch(createDownloadMessage({ status: 'running' })) + dispatch( + createDownloadMessage({ status: 'failed', error: 'Network error' }) + ) + + expect(store.activeDownloads).toHaveLength(0) + expect(store.finishedDownloads).toHaveLength(1) + expect(store.finishedDownloads[0].status).toBe('failed') + expect(store.finishedDownloads[0].error).toBe('Network error') + }) + + it('ignores duplicate terminal state messages', () => { + const store = useAssetDownloadStore() + + dispatch(createDownloadMessage({ status: 'completed', progress: 100 })) + dispatch(createDownloadMessage({ status: 'completed', progress: 100 })) + + expect(store.finishedDownloads).toHaveLength(1) + }) + }) + + describe('trackDownload', () => { + it('associates task with model type for completion tracking', () => { + const store = useAssetDownloadStore() + + store.trackDownload('task-123', 'checkpoints') + dispatch(createDownloadMessage({ status: 'completed', progress: 100 })) + + expect(store.completedDownloads).toHaveLength(1) + expect(store.completedDownloads[0]).toMatchObject({ + taskId: 'task-123', + modelType: 'checkpoints' + }) + }) + }) + + describe('stale download polling', () => { + function createTaskResponse( + overrides: Partial = {} + ): TaskResponse { + return { + id: 'task-123', + idempotency_key: 'key-123', + task_name: 'task:download_file', + payload: {}, + status: 'completed', + create_time: new Date().toISOString(), + update_time: new Date().toISOString(), + result: { + success: true, + asset_id: 'asset-456', + filename: 'model.safetensors', + bytes_downloaded: 1000 + }, + ...overrides + } + } + + it('polls and completes stale downloads', async () => { + const store = useAssetDownloadStore() + + vi.mocked(taskService.getTask).mockResolvedValue(createTaskResponse()) + + dispatch(createDownloadMessage({ status: 'running' })) + expect(store.activeDownloads).toHaveLength(1) + + await vi.advanceTimersByTimeAsync(45_000) + + expect(taskService.getTask).toHaveBeenCalledWith('task-123') + expect(store.activeDownloads).toHaveLength(0) + expect(store.finishedDownloads[0].status).toBe('completed') + }) + + it('polls and marks failed downloads', async () => { + const store = useAssetDownloadStore() + + vi.mocked(taskService.getTask).mockResolvedValue( + createTaskResponse({ + status: 'failed', + error_message: 'Download failed', + result: { success: false, error: 'Network error' } + }) + ) + + dispatch(createDownloadMessage({ status: 'running' })) + await vi.advanceTimersByTimeAsync(45_000) + + expect(store.activeDownloads).toHaveLength(0) + expect(store.finishedDownloads[0].status).toBe('failed') + expect(store.finishedDownloads[0].error).toBe('Download failed') + }) + + it('does not complete if task still running', async () => { + const store = useAssetDownloadStore() + + vi.mocked(taskService.getTask).mockResolvedValue( + createTaskResponse({ status: 'running', result: undefined }) + ) + + dispatch(createDownloadMessage({ status: 'running' })) + await vi.advanceTimersByTimeAsync(45_000) + + expect(taskService.getTask).toHaveBeenCalled() + expect(store.activeDownloads).toHaveLength(1) + }) + + it('continues tracking on polling error', async () => { + const store = useAssetDownloadStore() + + vi.mocked(taskService.getTask).mockRejectedValue(new Error('Not found')) + dispatch(createDownloadMessage({ status: 'running' })) + + await vi.advanceTimersByTimeAsync(45_000) + + expect(store.activeDownloads).toHaveLength(1) + }) + }) + + describe('clearFinishedDownloads', () => { + it('removes all finished downloads', () => { + const store = useAssetDownloadStore() + + dispatch(createDownloadMessage({ status: 'completed', progress: 100 })) + expect(store.finishedDownloads).toHaveLength(1) + + store.clearFinishedDownloads() + + expect(store.finishedDownloads).toHaveLength(0) + }) + }) +}) diff --git a/src/stores/assetDownloadStore.ts b/src/stores/assetDownloadStore.ts index 91bf69015..8aa5788bd 100644 --- a/src/stores/assetDownloadStore.ts +++ b/src/stores/assetDownloadStore.ts @@ -1,17 +1,20 @@ +import { useIntervalFn } from '@vueuse/core' import { defineStore } from 'pinia' -import { computed, ref } from 'vue' +import { computed, ref, watch } from 'vue' +import { taskService } from '@/platform/tasks/services/taskService' import type { AssetDownloadWsMessage } from '@/schemas/apiSchema' import { api } from '@/scripts/api' export interface AssetDownload { taskId: string - assetId: string assetName: string bytesTotal: number bytesDownloaded: number progress: number status: 'created' | 'running' | 'completed' | 'failed' + lastUpdate: number + assetId?: string error?: string } @@ -21,20 +24,13 @@ interface CompletedDownload { timestamp: number } -const PROCESSED_TASK_CLEANUP_MS = 60000 const MAX_COMPLETED_DOWNLOADS = 10 +const STALE_THRESHOLD_MS = 10_000 +const POLL_INTERVAL_MS = 10_000 export const useAssetDownloadStore = defineStore('assetDownload', () => { - /** Map of task IDs to their download progress data */ const downloads = ref>(new Map()) - - /** Map of task IDs to model types, used to track which model type to refresh after download completes */ const pendingModelTypes = new Map() - - /** Set of task IDs that have reached a terminal state (completed/failed), prevents duplicate processing */ - const processedTaskIds = new Set() - - /** Reactive signal for completed downloads */ const completedDownloads = ref([]) const downloadList = computed(() => Array.from(downloads.value.values())) @@ -51,24 +47,17 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => { const hasActiveDownloads = computed(() => activeDownloads.value.length > 0) const hasDownloads = computed(() => downloads.value.size > 0) - /** - * Associates a download task with its model type for later use when the download completes. - * Intended for external callers (e.g., useUploadModelWizard) to register async downloads. - */ function trackDownload(taskId: string, modelType: string) { pendingModelTypes.set(taskId, modelType) } - /** - * Handles asset download WebSocket events. Updates download progress, manages toast notifications, - * and tracks completed downloads. Prevents duplicate processing of terminal states (completed/failed). - */ function handleAssetDownload(e: CustomEvent) { const data = e.detail + const existing = downloads.value.get(data.task_id) - if (data.status === 'completed' || data.status === 'failed') { - if (processedTaskIds.has(data.task_id)) return - processedTaskIds.add(data.task_id) + // Skip if already in terminal state + if (existing?.status === 'completed' || existing?.status === 'failed') { + return } const download: AssetDownload = { @@ -79,7 +68,8 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => { bytesDownloaded: data.bytes_downloaded, progress: data.progress, status: data.status, - error: data.error + error: data.error, + lastUpdate: Date.now() } downloads.value.set(data.task_id, download) @@ -87,33 +77,72 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => { if (data.status === 'completed') { const modelType = pendingModelTypes.get(data.task_id) if (modelType) { - const newDownload: CompletedDownload = { - taskId: data.task_id, - modelType, - timestamp: Date.now() - } - - const updated = [...completedDownloads.value, newDownload] - if (updated.length > MAX_COMPLETED_DOWNLOADS) { - updated.shift() - } + const updated = [ + ...completedDownloads.value, + { taskId: data.task_id, modelType, timestamp: Date.now() } + ] + if (updated.length > MAX_COMPLETED_DOWNLOADS) updated.shift() completedDownloads.value = updated - pendingModelTypes.delete(data.task_id) } - setTimeout( - () => processedTaskIds.delete(data.task_id), - PROCESSED_TASK_CLEANUP_MS - ) } else if (data.status === 'failed') { pendingModelTypes.delete(data.task_id) - setTimeout( - () => processedTaskIds.delete(data.task_id), - PROCESSED_TASK_CLEANUP_MS - ) } } + async function pollStaleDownloads() { + const now = Date.now() + const staleDownloads = activeDownloads.value.filter( + (d) => now - d.lastUpdate >= STALE_THRESHOLD_MS + ) + + if (staleDownloads.length === 0) return + + async function pollSingleDownload(download: AssetDownload) { + try { + const task = await taskService.getTask(download.taskId) + + if (task.status === 'completed' || task.status === 'failed') { + const result = task.result + handleAssetDownload( + new CustomEvent('asset_download', { + detail: { + task_id: download.taskId, + asset_id: result?.asset_id ?? download.assetId, + asset_name: result?.filename ?? download.assetName, + bytes_total: download.bytesTotal, + bytes_downloaded: + result?.bytes_downloaded ?? download.bytesTotal, + progress: task.status === 'completed' ? 100 : download.progress, + status: task.status, + error: task.error_message ?? result?.error + } + }) + ) + } + } catch { + // Task not ready or not found + } + } + + await Promise.all(staleDownloads.map(pollSingleDownload)) + } + + const { pause, resume } = useIntervalFn( + () => void pollStaleDownloads(), + POLL_INTERVAL_MS, + { immediate: false } + ) + + watch( + hasActiveDownloads, + (hasActive) => { + if (hasActive) resume() + else pause() + }, + { immediate: true } + ) + api.addEventListener('asset_download', handleAssetDownload) function clearFinishedDownloads() {