diff --git a/AGENTS.md b/AGENTS.md index b68246d4f..743572be3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -161,6 +161,8 @@ The project uses **Nx** for build orchestration and task management ## Testing Guidelines +See @docs/testing/*.md for detailed patterns. + - Frameworks: - Vitest (unit/component, happy-dom) - Playwright (E2E) diff --git a/docs/testing/vitest-patterns.md b/docs/testing/vitest-patterns.md new file mode 100644 index 000000000..2eb7c8e09 --- /dev/null +++ b/docs/testing/vitest-patterns.md @@ -0,0 +1,138 @@ +--- +globs: + - '**/*.test.ts' + - '**/*.spec.ts' +--- + +# Vitest Patterns + +## Setup + +Use `createTestingPinia` from `@pinia/testing`, not `createPinia`: + +```typescript +import { createTestingPinia } from '@pinia/testing' +import { setActivePinia } from 'pinia' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +describe('MyStore', () => { + beforeEach(() => { + setActivePinia(createTestingPinia({ stubActions: false })) + vi.useFakeTimers() + vi.resetAllMocks() + }) + + afterEach(() => { + vi.useRealTimers() + }) +}) +``` + +**Why `stubActions: false`?** By default, testing pinia stubs all actions. Set to `false` when testing actual store behavior. + +## Mock Patterns + +### Reset all mocks at once + +```typescript +beforeEach(() => { + vi.resetAllMocks() // Not individual mock.mockReset() calls +}) +``` + +### Module mocks with vi.mock() + +```typescript +vi.mock('@/scripts/api', () => ({ + api: { + addEventListener: vi.fn(), + fetchData: vi.fn() + } +})) + +vi.mock('@/services/myService', () => ({ + myService: { + doThing: vi.fn() + } +})) +``` + +### Configure mocks in tests + +```typescript +import { api } from '@/scripts/api' +import { myService } from '@/services/myService' + +it('handles success', () => { + vi.mocked(myService.doThing).mockResolvedValue({ data: 'test' }) + // ... test code +}) +``` + +## Testing Event Listeners + +When a store registers event listeners at module load time: + +```typescript +function getEventHandler() { + const call = vi.mocked(api.addEventListener).mock.calls.find( + ([event]) => event === 'my_event' + ) + return call?.[1] as (e: CustomEvent) => void +} + +function dispatch(data: MyEventType) { + const handler = getEventHandler() + handler(new CustomEvent('my_event', { detail: data })) +} + +it('handles events', () => { + const store = useMyStore() + dispatch({ field: 'value' }) + expect(store.items).toHaveLength(1) +}) +``` + +## Testing with Fake Timers + +For stores with intervals, timeouts, or polling: + +```typescript +beforeEach(() => { + vi.useFakeTimers() +}) + +afterEach(() => { + vi.useRealTimers() +}) + +it('polls after delay', async () => { + const store = useMyStore() + store.startPolling() + + await vi.advanceTimersByTimeAsync(30000) + + expect(mockService.fetch).toHaveBeenCalled() +}) +``` + +## Assertion Style + +Prefer `.toHaveLength()` over `.length.toBe()`: + +```typescript +// Good +expect(store.items).toHaveLength(1) + +// Avoid +expect(store.items.length).toBe(1) +``` + +Use `.toMatchObject()` for partial matching: + +```typescript +expect(store.completedItems[0]).toMatchObject({ + id: 'task-123', + status: 'done' +}) +``` diff --git a/src/components/honeyToast/HoneyToast.stories.ts b/src/components/honeyToast/HoneyToast.stories.ts index 98ae59070..74331d49f 100644 --- a/src/components/honeyToast/HoneyToast.stories.ts +++ b/src/components/honeyToast/HoneyToast.stories.ts @@ -17,6 +17,7 @@ function createMockJob(overrides: Partial = {}): AssetDownload { bytesDownloaded: 0, progress: 0, status: 'created', + lastUpdate: Date.now(), ...overrides } } diff --git a/src/components/toast/ProgressToastItem.stories.ts b/src/components/toast/ProgressToastItem.stories.ts index 2ad376a72..cdfa8e28e 100644 --- a/src/components/toast/ProgressToastItem.stories.ts +++ b/src/components/toast/ProgressToastItem.stories.ts @@ -29,6 +29,7 @@ function createMockJob(overrides: Partial = {}): AssetDownload { bytesDownloaded: 0, progress: 0, status: 'created', + lastUpdate: Date.now(), ...overrides } } diff --git a/src/platform/tasks/services/taskService.ts b/src/platform/tasks/services/taskService.ts new file mode 100644 index 000000000..7dc1e62db --- /dev/null +++ b/src/platform/tasks/services/taskService.ts @@ -0,0 +1,70 @@ +/** + * Task Service for polling background task status. + * + * CAVEAT: The `payload` and `result` schemas below are specific to + * `task:download_file` tasks. Other task types may have different + * payload/result structures. We are not generalizing this until + * additional use cases arise. + */ +import { z } from 'zod' +import { fromZodError } from 'zod-validation-error' + +import { api } from '@/scripts/api' + +const TASKS_ENDPOINT = '/tasks' + +const zTaskStatus = z.enum(['created', 'running', 'completed', 'failed']) + +const zDownloadFileResult = z.object({ + success: z.boolean(), + file_path: z.string().optional(), + bytes_downloaded: z.number().optional(), + content_type: z.string().optional(), + hash: z.string().optional(), + filename: z.string().optional(), + asset_id: z.string().optional(), + metadata: z.record(z.unknown()).optional(), + error: z.string().optional() +}) + +const zTaskResponse = z.object({ + id: z.string().uuid(), + idempotency_key: z.string(), + task_name: z.string(), + payload: z.record(z.unknown()), + status: zTaskStatus, + result: zDownloadFileResult.optional(), + error_message: z.string().optional(), + create_time: z.string().datetime(), + update_time: z.string().datetime(), + started_at: z.string().datetime().optional(), + completed_at: z.string().datetime().optional() +}) + +export type TaskResponse = z.infer + +function createTaskService() { + async function getTask(taskId: string): Promise { + const res = await api.fetchApi(`${TASKS_ENDPOINT}/${taskId}`) + + if (!res.ok) { + if (res.status === 404) { + throw new Error(`Task not found: ${taskId}`) + } + throw new Error(`Failed to get task ${taskId}: ${res.status}`) + } + + const data = await res.json() + const result = zTaskResponse.safeParse(data) + + if (!result.success) { + throw new Error(fromZodError(result.error).message) + } + + return result.data + } + + return { getTask } +} + +export const taskService = createTaskService() diff --git a/src/schemas/apiSchema.ts b/src/schemas/apiSchema.ts index fc7597c34..6e01d6f4e 100644 --- a/src/schemas/apiSchema.ts +++ b/src/schemas/apiSchema.ts @@ -137,12 +137,12 @@ const zFeatureFlagsWsMessage = z.record(z.string(), z.any()) const zAssetDownloadWsMessage = z.object({ task_id: z.string(), - asset_id: z.string(), asset_name: z.string(), bytes_total: z.number(), bytes_downloaded: z.number(), progress: z.number(), status: z.enum(['created', 'running', 'completed', 'failed']), + asset_id: z.string().optional(), error: z.string().optional() }) diff --git a/src/stores/assetDownloadStore.test.ts b/src/stores/assetDownloadStore.test.ts new file mode 100644 index 000000000..28d58b9c1 --- /dev/null +++ b/src/stores/assetDownloadStore.test.ts @@ -0,0 +1,225 @@ +import { createTestingPinia } from '@pinia/testing' +import { setActivePinia } from 'pinia' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' + +import type { TaskResponse } from '@/platform/tasks/services/taskService' +import { taskService } from '@/platform/tasks/services/taskService' +import type { AssetDownloadWsMessage } from '@/schemas/apiSchema' +import { useAssetDownloadStore } from '@/stores/assetDownloadStore' + +type DownloadEventHandler = (e: CustomEvent) => void + +const eventHandler = vi.hoisted(() => { + const state: { current: DownloadEventHandler | null } = { current: null } + return state +}) + +vi.mock('@/scripts/api', () => ({ + api: { + addEventListener: vi.fn((_event: string, handler: DownloadEventHandler) => { + eventHandler.current = handler + }), + removeEventListener: vi.fn() + } +})) + +vi.mock('@/platform/tasks/services/taskService', () => ({ + taskService: { + getTask: vi.fn() + } +})) + +function createDownloadMessage( + overrides: Partial = {} +): AssetDownloadWsMessage { + return { + task_id: 'task-123', + asset_id: 'asset-456', + asset_name: 'model.safetensors', + bytes_total: 1000, + bytes_downloaded: 500, + progress: 50, + status: 'running', + ...overrides + } +} + +function dispatch(msg: AssetDownloadWsMessage) { + if (!eventHandler.current) { + throw new Error( + 'Event handler not registered. Call useAssetDownloadStore() first.' + ) + } + eventHandler.current(new CustomEvent('asset_download', { detail: msg })) +} + +describe('useAssetDownloadStore', () => { + beforeEach(() => { + setActivePinia(createTestingPinia({ stubActions: false })) + vi.useFakeTimers({ shouldAdvanceTime: false }) + vi.resetAllMocks() + eventHandler.current = null + }) + + afterEach(() => { + vi.useRealTimers() + }) + + describe('handleAssetDownload', () => { + it('tracks running downloads', () => { + const store = useAssetDownloadStore() + + dispatch(createDownloadMessage()) + + expect(store.activeDownloads).toHaveLength(1) + expect(store.activeDownloads[0].taskId).toBe('task-123') + expect(store.activeDownloads[0].progress).toBe(50) + }) + + it('moves download to finished when completed', () => { + const store = useAssetDownloadStore() + + dispatch(createDownloadMessage({ status: 'running' })) + expect(store.activeDownloads).toHaveLength(1) + + dispatch(createDownloadMessage({ status: 'completed', progress: 100 })) + + expect(store.activeDownloads).toHaveLength(0) + expect(store.finishedDownloads).toHaveLength(1) + expect(store.finishedDownloads[0].status).toBe('completed') + }) + + it('moves download to finished when failed', () => { + const store = useAssetDownloadStore() + + dispatch(createDownloadMessage({ status: 'running' })) + dispatch( + createDownloadMessage({ status: 'failed', error: 'Network error' }) + ) + + expect(store.activeDownloads).toHaveLength(0) + expect(store.finishedDownloads).toHaveLength(1) + expect(store.finishedDownloads[0].status).toBe('failed') + expect(store.finishedDownloads[0].error).toBe('Network error') + }) + + it('ignores duplicate terminal state messages', () => { + const store = useAssetDownloadStore() + + dispatch(createDownloadMessage({ status: 'completed', progress: 100 })) + dispatch(createDownloadMessage({ status: 'completed', progress: 100 })) + + expect(store.finishedDownloads).toHaveLength(1) + }) + }) + + describe('trackDownload', () => { + it('associates task with model type for completion tracking', () => { + const store = useAssetDownloadStore() + + store.trackDownload('task-123', 'checkpoints') + dispatch(createDownloadMessage({ status: 'completed', progress: 100 })) + + expect(store.completedDownloads).toHaveLength(1) + expect(store.completedDownloads[0]).toMatchObject({ + taskId: 'task-123', + modelType: 'checkpoints' + }) + }) + }) + + describe('stale download polling', () => { + function createTaskResponse( + overrides: Partial = {} + ): TaskResponse { + return { + id: 'task-123', + idempotency_key: 'key-123', + task_name: 'task:download_file', + payload: {}, + status: 'completed', + create_time: new Date().toISOString(), + update_time: new Date().toISOString(), + result: { + success: true, + asset_id: 'asset-456', + filename: 'model.safetensors', + bytes_downloaded: 1000 + }, + ...overrides + } + } + + it('polls and completes stale downloads', async () => { + const store = useAssetDownloadStore() + + vi.mocked(taskService.getTask).mockResolvedValue(createTaskResponse()) + + dispatch(createDownloadMessage({ status: 'running' })) + expect(store.activeDownloads).toHaveLength(1) + + await vi.advanceTimersByTimeAsync(45_000) + + expect(taskService.getTask).toHaveBeenCalledWith('task-123') + expect(store.activeDownloads).toHaveLength(0) + expect(store.finishedDownloads[0].status).toBe('completed') + }) + + it('polls and marks failed downloads', async () => { + const store = useAssetDownloadStore() + + vi.mocked(taskService.getTask).mockResolvedValue( + createTaskResponse({ + status: 'failed', + error_message: 'Download failed', + result: { success: false, error: 'Network error' } + }) + ) + + dispatch(createDownloadMessage({ status: 'running' })) + await vi.advanceTimersByTimeAsync(45_000) + + expect(store.activeDownloads).toHaveLength(0) + expect(store.finishedDownloads[0].status).toBe('failed') + expect(store.finishedDownloads[0].error).toBe('Download failed') + }) + + it('does not complete if task still running', async () => { + const store = useAssetDownloadStore() + + vi.mocked(taskService.getTask).mockResolvedValue( + createTaskResponse({ status: 'running', result: undefined }) + ) + + dispatch(createDownloadMessage({ status: 'running' })) + await vi.advanceTimersByTimeAsync(45_000) + + expect(taskService.getTask).toHaveBeenCalled() + expect(store.activeDownloads).toHaveLength(1) + }) + + it('continues tracking on polling error', async () => { + const store = useAssetDownloadStore() + + vi.mocked(taskService.getTask).mockRejectedValue(new Error('Not found')) + dispatch(createDownloadMessage({ status: 'running' })) + + await vi.advanceTimersByTimeAsync(45_000) + + expect(store.activeDownloads).toHaveLength(1) + }) + }) + + describe('clearFinishedDownloads', () => { + it('removes all finished downloads', () => { + const store = useAssetDownloadStore() + + dispatch(createDownloadMessage({ status: 'completed', progress: 100 })) + expect(store.finishedDownloads).toHaveLength(1) + + store.clearFinishedDownloads() + + expect(store.finishedDownloads).toHaveLength(0) + }) + }) +}) diff --git a/src/stores/assetDownloadStore.ts b/src/stores/assetDownloadStore.ts index 91bf69015..8aa5788bd 100644 --- a/src/stores/assetDownloadStore.ts +++ b/src/stores/assetDownloadStore.ts @@ -1,17 +1,20 @@ +import { useIntervalFn } from '@vueuse/core' import { defineStore } from 'pinia' -import { computed, ref } from 'vue' +import { computed, ref, watch } from 'vue' +import { taskService } from '@/platform/tasks/services/taskService' import type { AssetDownloadWsMessage } from '@/schemas/apiSchema' import { api } from '@/scripts/api' export interface AssetDownload { taskId: string - assetId: string assetName: string bytesTotal: number bytesDownloaded: number progress: number status: 'created' | 'running' | 'completed' | 'failed' + lastUpdate: number + assetId?: string error?: string } @@ -21,20 +24,13 @@ interface CompletedDownload { timestamp: number } -const PROCESSED_TASK_CLEANUP_MS = 60000 const MAX_COMPLETED_DOWNLOADS = 10 +const STALE_THRESHOLD_MS = 10_000 +const POLL_INTERVAL_MS = 10_000 export const useAssetDownloadStore = defineStore('assetDownload', () => { - /** Map of task IDs to their download progress data */ const downloads = ref>(new Map()) - - /** Map of task IDs to model types, used to track which model type to refresh after download completes */ const pendingModelTypes = new Map() - - /** Set of task IDs that have reached a terminal state (completed/failed), prevents duplicate processing */ - const processedTaskIds = new Set() - - /** Reactive signal for completed downloads */ const completedDownloads = ref([]) const downloadList = computed(() => Array.from(downloads.value.values())) @@ -51,24 +47,17 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => { const hasActiveDownloads = computed(() => activeDownloads.value.length > 0) const hasDownloads = computed(() => downloads.value.size > 0) - /** - * Associates a download task with its model type for later use when the download completes. - * Intended for external callers (e.g., useUploadModelWizard) to register async downloads. - */ function trackDownload(taskId: string, modelType: string) { pendingModelTypes.set(taskId, modelType) } - /** - * Handles asset download WebSocket events. Updates download progress, manages toast notifications, - * and tracks completed downloads. Prevents duplicate processing of terminal states (completed/failed). - */ function handleAssetDownload(e: CustomEvent) { const data = e.detail + const existing = downloads.value.get(data.task_id) - if (data.status === 'completed' || data.status === 'failed') { - if (processedTaskIds.has(data.task_id)) return - processedTaskIds.add(data.task_id) + // Skip if already in terminal state + if (existing?.status === 'completed' || existing?.status === 'failed') { + return } const download: AssetDownload = { @@ -79,7 +68,8 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => { bytesDownloaded: data.bytes_downloaded, progress: data.progress, status: data.status, - error: data.error + error: data.error, + lastUpdate: Date.now() } downloads.value.set(data.task_id, download) @@ -87,33 +77,72 @@ export const useAssetDownloadStore = defineStore('assetDownload', () => { if (data.status === 'completed') { const modelType = pendingModelTypes.get(data.task_id) if (modelType) { - const newDownload: CompletedDownload = { - taskId: data.task_id, - modelType, - timestamp: Date.now() - } - - const updated = [...completedDownloads.value, newDownload] - if (updated.length > MAX_COMPLETED_DOWNLOADS) { - updated.shift() - } + const updated = [ + ...completedDownloads.value, + { taskId: data.task_id, modelType, timestamp: Date.now() } + ] + if (updated.length > MAX_COMPLETED_DOWNLOADS) updated.shift() completedDownloads.value = updated - pendingModelTypes.delete(data.task_id) } - setTimeout( - () => processedTaskIds.delete(data.task_id), - PROCESSED_TASK_CLEANUP_MS - ) } else if (data.status === 'failed') { pendingModelTypes.delete(data.task_id) - setTimeout( - () => processedTaskIds.delete(data.task_id), - PROCESSED_TASK_CLEANUP_MS - ) } } + async function pollStaleDownloads() { + const now = Date.now() + const staleDownloads = activeDownloads.value.filter( + (d) => now - d.lastUpdate >= STALE_THRESHOLD_MS + ) + + if (staleDownloads.length === 0) return + + async function pollSingleDownload(download: AssetDownload) { + try { + const task = await taskService.getTask(download.taskId) + + if (task.status === 'completed' || task.status === 'failed') { + const result = task.result + handleAssetDownload( + new CustomEvent('asset_download', { + detail: { + task_id: download.taskId, + asset_id: result?.asset_id ?? download.assetId, + asset_name: result?.filename ?? download.assetName, + bytes_total: download.bytesTotal, + bytes_downloaded: + result?.bytes_downloaded ?? download.bytesTotal, + progress: task.status === 'completed' ? 100 : download.progress, + status: task.status, + error: task.error_message ?? result?.error + } + }) + ) + } + } catch { + // Task not ready or not found + } + } + + await Promise.all(staleDownloads.map(pollSingleDownload)) + } + + const { pause, resume } = useIntervalFn( + () => void pollStaleDownloads(), + POLL_INTERVAL_MS, + { immediate: false } + ) + + watch( + hasActiveDownloads, + (hasActive) => { + if (hasActive) resume() + else pause() + }, + { immediate: true } + ) + api.addEventListener('asset_download', handleAssetDownload) function clearFinishedDownloads() {