diff --git a/src/platform/assets/components/AssetBrowserModal.test.ts b/src/platform/assets/components/AssetBrowserModal.test.ts index 87135b468..6fb027519 100644 --- a/src/platform/assets/components/AssetBrowserModal.test.ts +++ b/src/platform/assets/components/AssetBrowserModal.test.ts @@ -6,6 +6,9 @@ import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vu import type { AssetItem } from '@/platform/assets/schemas/assetSchema' import { useAssetsStore } from '@/stores/assetsStore' +const mockAssetsByKey = vi.hoisted(() => new Map()) +const mockLoadingByKey = vi.hoisted(() => new Map()) + vi.mock('@/i18n', () => ({ t: (key: string, params?: Record) => params ? `${key}:${JSON.stringify(params)}` : key, @@ -13,13 +16,20 @@ vi.mock('@/i18n', () => ({ })) vi.mock('@/stores/assetsStore', () => { - const store = { - modelAssetsByNodeType: new Map(), - modelLoadingByNodeType: new Map(), - updateModelsForNodeType: vi.fn(), - updateModelsForTag: vi.fn() + const getAssets = vi.fn((key: string) => mockAssetsByKey.get(key) ?? []) + const isModelLoading = vi.fn( + (key: string) => mockLoadingByKey.get(key) ?? false + ) + const updateModelsForNodeType = vi.fn() + const updateModelsForTag = vi.fn() + return { + useAssetsStore: () => ({ + getAssets, + isModelLoading, + updateModelsForNodeType, + updateModelsForTag + }) } - return { useAssetsStore: () => store } }) vi.mock('@/stores/modelToNodeStore', () => ({ @@ -183,12 +193,10 @@ describe('AssetBrowserModal', () => { }) } - const mockStore = useAssetsStore() - beforeEach(() => { vi.resetAllMocks() - mockStore.modelAssetsByNodeType.clear() - mockStore.modelLoadingByNodeType.clear() + mockAssetsByKey.clear() + mockLoadingByKey.clear() }) describe('Integration with useAssetBrowser', () => { @@ -197,7 +205,7 @@ describe('AssetBrowserModal', () => { createTestAsset('asset1', 'Model A', 'checkpoints'), createTestAsset('asset2', 'Model B', 'loras') ] - mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets) + mockAssetsByKey.set('CheckpointLoaderSimple', assets) const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' }) await flushPromises() @@ -214,7 +222,7 @@ describe('AssetBrowserModal', () => { createTestAsset('c1', 'model.safetensors', 'checkpoints'), createTestAsset('l1', 'lora.pt', 'loras') ] - mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets) + mockAssetsByKey.set('CheckpointLoaderSimple', assets) const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple', @@ -231,17 +239,18 @@ describe('AssetBrowserModal', () => { describe('Data fetching', () => { it('triggers store refresh for node type on mount', async () => { + const store = useAssetsStore() createWrapper({ nodeType: 'CheckpointLoaderSimple' }) await flushPromises() - expect(mockStore.updateModelsForNodeType).toHaveBeenCalledWith( + expect(store.updateModelsForNodeType).toHaveBeenCalledWith( 'CheckpointLoaderSimple' ) }) it('displays cached assets immediately from store', async () => { const assets = [createTestAsset('asset1', 'Cached Model', 'checkpoints')] - mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets) + mockAssetsByKey.set('CheckpointLoaderSimple', assets) const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' }) @@ -253,15 +262,16 @@ describe('AssetBrowserModal', () => { }) it('triggers store refresh for asset type (tag) on mount', async () => { + const store = useAssetsStore() createWrapper({ assetType: 'models' }) await flushPromises() - expect(mockStore.updateModelsForTag).toHaveBeenCalledWith('models') + expect(store.updateModelsForTag).toHaveBeenCalledWith('models') }) it('uses tag: prefix for cache key when assetType is provided', async () => { const assets = [createTestAsset('asset1', 'Tagged Model', 'models')] - mockStore.modelAssetsByNodeType.set('tag:models', assets) + mockAssetsByKey.set('tag:models', assets) const wrapper = createWrapper({ assetType: 'models' }) await flushPromises() @@ -277,7 +287,7 @@ describe('AssetBrowserModal', () => { describe('Asset Selection', () => { it('emits asset-select event when asset is selected', async () => { const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')] - mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets) + mockAssetsByKey.set('CheckpointLoaderSimple', assets) const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' }) await flushPromises() @@ -290,7 +300,7 @@ describe('AssetBrowserModal', () => { it('executes onSelect callback when provided', async () => { const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')] - mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets) + mockAssetsByKey.set('CheckpointLoaderSimple', assets) const onSelect = vi.fn() const wrapper = createWrapper({ @@ -333,7 +343,7 @@ describe('AssetBrowserModal', () => { createTestAsset('asset1', 'Model A', 'checkpoints'), createTestAsset('asset2', 'Model B', 'loras') ] - mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets) + mockAssetsByKey.set('CheckpointLoaderSimple', assets) const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple', @@ -366,7 +376,7 @@ describe('AssetBrowserModal', () => { it('passes computed contentTitle to BaseModalLayout when no title prop', async () => { const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')] - mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets) + mockAssetsByKey.set('CheckpointLoaderSimple', assets) const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' }) await flushPromises() diff --git a/src/platform/assets/components/AssetBrowserModal.vue b/src/platform/assets/components/AssetBrowserModal.vue index 8344a807f..c30dd02b6 100644 --- a/src/platform/assets/components/AssetBrowserModal.vue +++ b/src/platform/assets/components/AssetBrowserModal.vue @@ -112,27 +112,21 @@ const cacheKey = computed(() => { }) // Read directly from store cache - reactive to any store updates -const fetchedAssets = computed( - () => assetStore.modelAssetsByNodeType.get(cacheKey.value) ?? [] -) +const fetchedAssets = computed(() => assetStore.getAssets(cacheKey.value)) -const isStoreLoading = computed( - () => assetStore.modelLoadingByNodeType.get(cacheKey.value) ?? false -) +const isStoreLoading = computed(() => assetStore.isModelLoading(cacheKey.value)) // Only show loading spinner when loading AND no cached data const isLoading = computed( () => isStoreLoading.value && fetchedAssets.value.length === 0 ) -async function refreshAssets(): Promise { +async function refreshAssets(): Promise { if (props.nodeType) { - return await assetStore.updateModelsForNodeType(props.nodeType) + await assetStore.updateModelsForNodeType(props.nodeType) + } else if (props.assetType) { + await assetStore.updateModelsForTag(props.assetType) } - if (props.assetType) { - return await assetStore.updateModelsForTag(props.assetType) - } - return [] } // Trigger background refresh on mount diff --git a/src/platform/assets/services/assetService.test.ts b/src/platform/assets/services/assetService.test.ts index ead7902ff..e126fd66c 100644 --- a/src/platform/assets/services/assetService.test.ts +++ b/src/platform/assets/services/assetService.test.ts @@ -160,7 +160,7 @@ describe('assetService', () => { const result = await assetService.getAssetModels('checkpoints') expect(api.fetchApi).toHaveBeenCalledWith( - '/assets?include_tags=models,checkpoints&limit=500' + '/assets?include_tags=models%2Ccheckpoints&limit=500' ) expect(result).toEqual([ expect.objectContaining({ name: 'valid.safetensors', pathIndex: 0 }) @@ -231,9 +231,9 @@ describe('assetService', () => { ) expect(result).toEqual(testAssets) - // Verify API call includes correct category + // Verify API call includes correct category (comma is URL-encoded by URLSearchParams) expect(api.fetchApi).toHaveBeenCalledWith( - '/assets?include_tags=models,checkpoints&limit=500' + '/assets?include_tags=models%2Ccheckpoints&limit=500' ) }) @@ -400,7 +400,7 @@ describe('assetService', () => { }) expect(api.fetchApi).toHaveBeenCalledWith( - '/assets?include_tags=models&limit=500&include_public=true&offset=50' + '/assets?include_tags=models&limit=500&offset=50&include_public=true' ) expect(result).toEqual(testAssets) }) @@ -415,7 +415,7 @@ describe('assetService', () => { }) expect(api.fetchApi).toHaveBeenCalledWith( - '/assets?include_tags=input&limit=100&include_public=false&offset=25' + '/assets?include_tags=input&limit=100&offset=25&include_public=false' ) expect(result).toEqual(testAssets) }) diff --git a/src/platform/assets/services/assetService.ts b/src/platform/assets/services/assetService.ts index da87681bb..fcf0367e8 100644 --- a/src/platform/assets/services/assetService.ts +++ b/src/platform/assets/services/assetService.ts @@ -1,6 +1,7 @@ import { fromZodError } from 'zod-validation-error' import { st } from '@/i18n' + import { assetItemSchema, assetResponseSchema, @@ -17,6 +18,16 @@ import type { import { api } from '@/scripts/api' import { useModelToNodeStore } from '@/stores/modelToNodeStore' +export interface PaginationOptions { + limit?: number + offset?: number +} + +interface AssetRequestOptions extends PaginationOptions { + includeTags: string[] + includePublic?: boolean +} + /** * Maps CivitAI validation error codes to localized error messages */ @@ -77,9 +88,27 @@ function createAssetService() { * Handles API response with consistent error handling and Zod validation */ async function handleAssetRequest( - url: string, + options: AssetRequestOptions, context: string ): Promise { + const { + includeTags, + limit = DEFAULT_LIMIT, + offset, + includePublic + } = options + const queryParams = new URLSearchParams({ + include_tags: includeTags.join(','), + limit: limit.toString() + }) + if (offset !== undefined && offset > 0) { + queryParams.set('offset', offset.toString()) + } + if (includePublic !== undefined) { + queryParams.set('include_public', includePublic ? 'true' : 'false') + } + + const url = `${ASSETS_ENDPOINT}?${queryParams.toString()}` const res = await api.fetchApi(url) if (!res.ok) { throw new Error( @@ -101,7 +130,7 @@ function createAssetService() { */ async function getAssetModelFolders(): Promise { const data = await handleAssetRequest( - `${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG}&limit=${DEFAULT_LIMIT}`, + { includeTags: [MODELS_TAG] }, 'model folders' ) @@ -130,7 +159,7 @@ function createAssetService() { */ async function getAssetModels(folder: string): Promise { const data = await handleAssetRequest( - `${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${folder}&limit=${DEFAULT_LIMIT}`, + { includeTags: [MODELS_TAG, folder] }, `models for ${folder}` ) @@ -169,9 +198,15 @@ function createAssetService() { * and fetching all assets with that category tag * * @param nodeType - The ComfyUI node type (e.g., 'CheckpointLoaderSimple') + * @param options - Pagination options + * @param options.limit - Maximum number of assets to return (default: 500) + * @param options.offset - Number of assets to skip (default: 0) * @returns Promise - Full asset objects with preserved metadata */ - async function getAssetsForNodeType(nodeType: string): Promise { + async function getAssetsForNodeType( + nodeType: string, + { limit = DEFAULT_LIMIT, offset = 0 }: PaginationOptions = {} + ): Promise { if (!nodeType || typeof nodeType !== 'string') { return [] } @@ -186,7 +221,7 @@ function createAssetService() { // Fetch assets for this category using same API pattern as getAssetModels const data = await handleAssetRequest( - `${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${category}&limit=${DEFAULT_LIMIT}`, + { includeTags: [MODELS_TAG, category], limit, offset }, `assets for ${nodeType}` ) @@ -242,23 +277,10 @@ function createAssetService() { async function getAssetsByTag( tag: string, includePublic: boolean = true, - { - limit = DEFAULT_LIMIT, - offset = 0 - }: { limit?: number; offset?: number } = {} + { limit = DEFAULT_LIMIT, offset = 0 }: PaginationOptions = {} ): Promise { - const queryParams = new URLSearchParams({ - include_tags: tag, - limit: limit.toString(), - include_public: includePublic ? 'true' : 'false' - }) - - if (offset > 0) { - queryParams.set('offset', offset.toString()) - } - const data = await handleAssetRequest( - `${ASSETS_ENDPOINT}?${queryParams.toString()}`, + { includeTags: [tag], limit, offset, includePublic }, `assets for tag ${tag}` ) diff --git a/src/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData.desktop.test.ts b/src/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData.desktop.test.ts index 2410f0964..1c51b9d30 100644 --- a/src/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData.desktop.test.ts +++ b/src/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData.desktop.test.ts @@ -12,9 +12,9 @@ const mockGetCategoryForNodeType = vi.fn() vi.mock('@/stores/assetsStore', () => ({ useAssetsStore: () => ({ - modelAssetsByNodeType: new Map(), - modelLoadingByNodeType: new Map(), - modelErrorByNodeType: new Map(), + getAssets: () => [], + isModelLoading: () => false, + getError: () => undefined, updateModelsForNodeType: mockUpdateModelsForNodeType }) })) diff --git a/src/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData.test.ts b/src/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData.test.ts index 9c54564f1..131d3245d 100644 --- a/src/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData.test.ts +++ b/src/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData.test.ts @@ -8,17 +8,17 @@ vi.mock('@/platform/distribution/types', () => ({ isCloud: true })) -const mockModelAssetsByNodeType = new Map() -const mockModelLoadingByNodeType = new Map() -const mockModelErrorByNodeType = new Map() +const mockAssetsByKey = new Map() +const mockLoadingByKey = new Map() +const mockErrorByKey = new Map() const mockUpdateModelsForNodeType = vi.fn() const mockGetCategoryForNodeType = vi.fn() vi.mock('@/stores/assetsStore', () => ({ useAssetsStore: () => ({ - modelAssetsByNodeType: mockModelAssetsByNodeType, - modelLoadingByNodeType: mockModelLoadingByNodeType, - modelErrorByNodeType: mockModelErrorByNodeType, + getAssets: (key: string) => mockAssetsByKey.get(key) ?? [], + isModelLoading: (key: string) => mockLoadingByKey.get(key) ?? false, + getError: (key: string) => mockErrorByKey.get(key), updateModelsForNodeType: mockUpdateModelsForNodeType }) })) @@ -32,9 +32,9 @@ vi.mock('@/stores/modelToNodeStore', () => ({ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => { beforeEach(() => { vi.clearAllMocks() - mockModelAssetsByNodeType.clear() - mockModelLoadingByNodeType.clear() - mockModelErrorByNodeType.clear() + mockAssetsByKey.clear() + mockLoadingByKey.clear() + mockErrorByKey.clear() mockGetCategoryForNodeType.mockReturnValue(undefined) mockUpdateModelsForNodeType.mockImplementation( @@ -76,8 +76,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => { mockUpdateModelsForNodeType.mockImplementation( async (_nodeType: string): Promise => { - mockModelAssetsByNodeType.set(_nodeType, mockAssets) - mockModelLoadingByNodeType.set(_nodeType, false) + mockAssetsByKey.set(_nodeType, mockAssets) + mockLoadingByKey.set(_nodeType, false) return mockAssets } ) @@ -108,9 +108,9 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => { mockUpdateModelsForNodeType.mockImplementation( async (_nodeType: string): Promise => { - mockModelErrorByNodeType.set(_nodeType, mockError) - mockModelAssetsByNodeType.set(_nodeType, []) - mockModelLoadingByNodeType.set(_nodeType, false) + mockErrorByKey.set(_nodeType, mockError) + mockAssetsByKey.set(_nodeType, []) + mockLoadingByKey.set(_nodeType, false) return [] } ) @@ -130,8 +130,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => { mockUpdateModelsForNodeType.mockImplementation( async (_nodeType: string): Promise => { - mockModelAssetsByNodeType.set(_nodeType, []) - mockModelLoadingByNodeType.set(_nodeType, false) + mockAssetsByKey.set(_nodeType, []) + mockLoadingByKey.set(_nodeType, false) return [] } ) @@ -154,8 +154,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => { mockGetCategoryForNodeType.mockReturnValue('checkpoints') mockUpdateModelsForNodeType.mockImplementation( async (_nodeType: string): Promise => { - mockModelAssetsByNodeType.set(_nodeType, mockAssets) - mockModelLoadingByNodeType.set(_nodeType, false) + mockAssetsByKey.set(_nodeType, mockAssets) + mockLoadingByKey.set(_nodeType, false) return mockAssets } ) @@ -182,8 +182,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => { mockGetCategoryForNodeType.mockReturnValue('loras') mockUpdateModelsForNodeType.mockImplementation( async (_nodeType: string): Promise => { - mockModelAssetsByNodeType.set(_nodeType, mockAssets) - mockModelLoadingByNodeType.set(_nodeType, false) + mockAssetsByKey.set(_nodeType, mockAssets) + mockLoadingByKey.set(_nodeType, false) return mockAssets } ) @@ -209,8 +209,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => { mockGetCategoryForNodeType.mockReturnValue('checkpoints') mockUpdateModelsForNodeType.mockImplementation( async (_nodeType: string): Promise => { - mockModelAssetsByNodeType.set(_nodeType, mockAssets) - mockModelLoadingByNodeType.set(_nodeType, false) + mockAssetsByKey.set(_nodeType, mockAssets) + mockLoadingByKey.set(_nodeType, false) return mockAssets } ) diff --git a/src/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData.ts b/src/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData.ts index b03f1dac9..06179edb7 100644 --- a/src/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData.ts +++ b/src/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData.ts @@ -34,23 +34,17 @@ export function useAssetWidgetData( const assets = computed(() => { const resolvedType = toValue(nodeType) - return resolvedType - ? (assetsStore.modelAssetsByNodeType.get(resolvedType) ?? []) - : [] + return resolvedType ? (assetsStore.getAssets(resolvedType) ?? []) : [] }) const isLoading = computed(() => { const resolvedType = toValue(nodeType) - return resolvedType - ? (assetsStore.modelLoadingByNodeType.get(resolvedType) ?? false) - : false + return resolvedType ? assetsStore.isModelLoading(resolvedType) : false }) const error = computed(() => { const resolvedType = toValue(nodeType) - return resolvedType - ? (assetsStore.modelErrorByNodeType.get(resolvedType) ?? null) - : null + return resolvedType ? (assetsStore.getError(resolvedType) ?? null) : null }) const dropdownItems = computed(() => { @@ -71,7 +65,8 @@ export function useAssetWidgetData( return } - const hasData = assetsStore.modelAssetsByNodeType.has(currentNodeType) + const existingAssets = assetsStore.getAssets(currentNodeType) ?? [] + const hasData = existingAssets.length > 0 if (!hasData) { await assetsStore.updateModelsForNodeType(currentNodeType) diff --git a/src/stores/assetsStore.test.ts b/src/stores/assetsStore.test.ts index 1905a0177..ed7705f15 100644 --- a/src/stores/assetsStore.test.ts +++ b/src/stores/assetsStore.test.ts @@ -1,5 +1,7 @@ -import { createPinia, setActivePinia } from 'pinia' -import { beforeEach, describe, expect, it, vi } from 'vitest' +import { createTestingPinia } from '@pinia/testing' +import { setActivePinia } from 'pinia' +import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest' +import { nextTick, watch } from 'vue' import { useAssetsStore } from '@/stores/assetsStore' import { api } from '@/scripts/api' @@ -9,6 +11,7 @@ import type { TaskStatus, TaskOutput } from '@/schemas/apiSchema' +import { assetService } from '@/platform/assets/services/assetService' // Mock the api module vi.mock('@/scripts/api', () => ({ @@ -25,13 +28,17 @@ vi.mock('@/scripts/api', () => ({ // Mock the asset service vi.mock('@/platform/assets/services/assetService', () => ({ assetService: { - getAssetsByTag: vi.fn() + getAssetsByTag: vi.fn(), + getAssetsForNodeType: vi.fn() } })) -// Mock distribution type +// Mock distribution type - hoisted so it can be changed per test +const mockIsCloud = vi.hoisted(() => ({ value: false })) vi.mock('@/platform/distribution/types', () => ({ - isCloud: false + get isCloud() { + return mockIsCloud.value + } })) // Mock TaskItemImpl @@ -144,7 +151,7 @@ describe('assetsStore - Refactored (Option A)', () => { }) beforeEach(() => { - setActivePinia(createPinia()) + setActivePinia(createTestingPinia({ stubActions: false })) store = useAssetsStore() vi.clearAllMocks() }) @@ -520,3 +527,158 @@ describe('assetsStore - Refactored (Option A)', () => { }) }) }) + +describe('assetsStore - Model Assets Cache (Cloud)', () => { + beforeEach(() => { + setActivePinia(createTestingPinia({ stubActions: false })) + mockIsCloud.value = true + vi.clearAllMocks() + }) + + afterEach(() => { + mockIsCloud.value = false + }) + + const createMockAsset = (id: string) => ({ + id, + name: `asset-${id}`, + size: 100, + created_at: new Date().toISOString(), + tags: ['models'], + preview_url: `http://test.com/${id}` + }) + + describe('getAssets cache invalidation', () => { + it('should invalidate cache before mutating assets during batch loading', async () => { + const store = useAssetsStore() + const nodeType = 'CheckpointLoaderSimple' + + const firstBatch = Array.from({ length: 500 }, (_, i) => + createMockAsset(`asset-${i}`) + ) + const secondBatch = Array.from({ length: 100 }, (_, i) => + createMockAsset(`asset-${500 + i}`) + ) + + let callCount = 0 + vi.mocked(assetService.getAssetsForNodeType).mockImplementation( + async () => { + callCount++ + return callCount === 1 ? firstBatch : secondBatch + } + ) + + await store.updateModelsForNodeType(nodeType) + + // Wait for background batch loading to complete + await vi.waitFor(() => { + expect( + vi.mocked(assetService.getAssetsForNodeType) + ).toHaveBeenCalledTimes(2) + }) + + const assets = store.getAssets(nodeType) + expect(assets).toHaveLength(600) + }) + + it('should not return stale cached array after background batch completes', async () => { + const store = useAssetsStore() + const nodeType = 'LoraLoader' + + // First batch must be exactly MODEL_BATCH_SIZE (500) to trigger hasMore + const firstBatch = Array.from({ length: 500 }, (_, i) => + createMockAsset(`first-${i}`) + ) + const secondBatch = [createMockAsset('new-asset')] + + let callCount = 0 + vi.mocked(assetService.getAssetsForNodeType).mockImplementation( + async () => { + callCount++ + return callCount === 1 ? firstBatch : secondBatch + } + ) + + await store.updateModelsForNodeType(nodeType) + + // Wait for background batch loading to complete + await vi.waitFor(() => { + expect( + vi.mocked(assetService.getAssetsForNodeType) + ).toHaveBeenCalledTimes(2) + }) + + const assets = store.getAssets(nodeType) + expect(assets).toHaveLength(501) + expect(assets.map((a) => a.id)).toContain('new-asset') + }) + + it('should return cached array on subsequent getAssets calls', () => { + const store = useAssetsStore() + const nodeType = 'TestLoader' + + const firstCall = store.getAssets(nodeType) + const secondCall = store.getAssets(nodeType) + + expect(secondCall).toBe(firstCall) + }) + }) + + describe('concurrent request handling', () => { + it('should discard stale request when newer request starts', async () => { + const store = useAssetsStore() + const nodeType = 'CheckpointLoaderSimple' + const firstBatch = Array.from({ length: 5 }, (_, i) => + createMockAsset(`first-${i}`) + ) + const secondBatch = Array.from({ length: 10 }, (_, i) => + createMockAsset(`second-${i}`) + ) + + let resolveFirst: (value: ReturnType[]) => void + const firstPromise = new Promise[]>( + (resolve) => { + resolveFirst = resolve + } + ) + let callCount = 0 + vi.mocked(assetService.getAssetsForNodeType).mockImplementation( + async () => { + callCount++ + return callCount === 1 ? firstPromise : secondBatch + } + ) + + const firstRequest = store.updateModelsForNodeType(nodeType) + const secondRequest = store.updateModelsForNodeType(nodeType) + resolveFirst!(firstBatch) + await Promise.all([firstRequest, secondRequest]) + + expect(store.getAssets(nodeType)).toHaveLength(10) + expect( + store.getAssets(nodeType).every((a) => a.id.startsWith('second-')) + ).toBe(true) + }) + }) + + describe('shallowReactive state reactivity', () => { + it('should trigger reactivity on isModelLoading change', async () => { + const store = useAssetsStore() + const nodeType = 'CheckpointLoaderSimple' + + const loadingStates: boolean[] = [] + watch( + () => store.isModelLoading(nodeType), + (val) => loadingStates.push(val), + { immediate: true } + ) + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue([]) + await store.updateModelsForNodeType(nodeType) + await nextTick() + + expect(loadingStates).toContain(true) + expect(loadingStates).toContain(false) + }) + }) +}) diff --git a/src/stores/assetsStore.ts b/src/stores/assetsStore.ts index 56efe3daa..f32a4665f 100644 --- a/src/stores/assetsStore.ts +++ b/src/stores/assetsStore.ts @@ -1,13 +1,13 @@ import { useAsyncState, whenever } from '@vueuse/core' -import { isEqual } from 'es-toolkit' import { defineStore } from 'pinia' -import { computed, shallowReactive, ref } from 'vue' +import { computed, reactive, ref, shallowReactive } from 'vue' import { mapInputFileToAssetItem, mapTaskOutputToAssetItem } from '@/platform/assets/composables/media/assetMappers' import type { AssetItem } from '@/platform/assets/schemas/assetSchema' import { assetService } from '@/platform/assets/services/assetService' +import type { PaginationOptions } from '@/platform/assets/services/assetService' import { isCloud } from '@/platform/distribution/types' import type { TaskItem } from '@/schemas/apiSchema' import { api } from '@/scripts/api' @@ -261,6 +261,16 @@ export const useAssetsStore = defineStore('assets', () => { return inputAssetsByFilename.value.get(filename)?.name ?? filename } + const MODEL_BATCH_SIZE = 500 + + interface ModelPaginationState { + assets: Map + offset: number + hasMore: boolean + isLoading: boolean + error?: Error + } + /** * Model assets cached by node type (e.g., 'CheckpointLoaderSimple', 'LoraLoader') * Used by multiple loader nodes to avoid duplicate fetches @@ -268,109 +278,183 @@ export const useAssetsStore = defineStore('assets', () => { */ const getModelState = () => { if (isCloud) { - const modelAssetsByNodeType = shallowReactive( - new Map() - ) - const modelLoadingByNodeType = shallowReactive(new Map()) - const modelErrorByNodeType = shallowReactive( - new Map() - ) + const modelStateByKey = ref(new Map()) - const stateByNodeType = shallowReactive( - new Map>>() - ) + const assetsArrayCache = new Map< + string, + { source: Map; array: AssetItem[] } + >() + + const pendingRequestByKey = new Map() + + function createState(): ModelPaginationState { + return reactive({ + assets: new Map(), + offset: 0, + hasMore: true, + isLoading: false + }) + } + + function isStale(key: string, state: ModelPaginationState): boolean { + const committed = modelStateByKey.value.get(key) + const pending = pendingRequestByKey.get(key) + return committed !== state && pending !== state + } + + const EMPTY_ASSETS: AssetItem[] = [] + + function getAssets(key: string): AssetItem[] { + const state = modelStateByKey.value.get(key) + const assetsMap = state?.assets + if (!assetsMap) return EMPTY_ASSETS + + const cached = assetsArrayCache.get(key) + if (cached && cached.source === assetsMap) { + return cached.array + } + + const array = Array.from(assetsMap.values()) + assetsArrayCache.set(key, { source: assetsMap, array }) + return array + } + + function isLoading(key: string): boolean { + return modelStateByKey.value.get(key)?.isLoading ?? false + } + + function getError(key: string): Error | undefined { + return modelStateByKey.value.get(key)?.error + } + + function hasMore(key: string): boolean { + return modelStateByKey.value.get(key)?.hasMore ?? false + } /** - * Internal helper to fetch and cache assets with a given key and fetcher + * Internal helper to fetch and cache assets with a given key and fetcher. + * Loads first batch immediately, then progressively loads remaining batches. + * Keeps existing data visible until new data is successfully fetched. */ async function updateModelsForKey( key: string, - fetcher: () => Promise - ): Promise { - if (!stateByNodeType.has(key)) { - stateByNodeType.set( - key, - useAsyncState(fetcher, [], { - immediate: false, - resetOnExecute: false, - onError: (err) => { - console.error(`Error fetching model assets for ${key}:`, err) + fetcher: (options: PaginationOptions) => Promise + ): Promise { + const state = createState() + state.isLoading = true + + const hasExistingData = modelStateByKey.value.has(key) + if (hasExistingData) { + pendingRequestByKey.set(key, state) + } else { + modelStateByKey.value.set(key, state) + } + + async function loadBatches(): Promise { + while (state.hasMore) { + try { + const newAssets = await fetcher({ + limit: MODEL_BATCH_SIZE, + offset: state.offset + }) + + if (isStale(key, state)) return + + const isFirstBatch = state.offset === 0 + if (isFirstBatch) { + assetsArrayCache.delete(key) + if (hasExistingData) { + pendingRequestByKey.delete(key) + modelStateByKey.value.set(key, state) + } + state.assets = new Map(newAssets.map((a) => [a.id, a])) + } else { + const assetsToAdd = newAssets.filter( + (a) => !state.assets.has(a.id) + ) + if (assetsToAdd.length > 0) { + assetsArrayCache.delete(key) + for (const asset of assetsToAdd) { + state.assets.set(asset.id, asset) + } + } } - }) - ) + + state.offset += newAssets.length + state.hasMore = newAssets.length === MODEL_BATCH_SIZE + + if (isFirstBatch) { + state.isLoading = false + } + + if (state.hasMore) { + await new Promise((resolve) => setTimeout(resolve, 50)) + } + } catch (err) { + if (isStale(key, state)) return + state.error = err instanceof Error ? err : new Error(String(err)) + state.hasMore = false + console.error(`Error loading batch for ${key}:`, err) + if (state.offset === 0) { + state.isLoading = false + pendingRequestByKey.delete(key) + // TODO: Add toast indicator for first-batch load failures + } + return + } + } } - const state = stateByNodeType.get(key)! - - modelLoadingByNodeType.set(key, true) - modelErrorByNodeType.set(key, null) - - try { - await state.execute() - } finally { - modelLoadingByNodeType.set(key, state.isLoading.value) - } - - const assets = state.state.value - const existingAssets = modelAssetsByNodeType.get(key) - - if (!isEqual(existingAssets, assets)) { - modelAssetsByNodeType.set(key, assets) - } - - modelErrorByNodeType.set( - key, - state.error.value instanceof Error ? state.error.value : null - ) - - return assets + await loadBatches() } /** * Fetch and cache model assets for a specific node type * @param nodeType The node type to fetch assets for (e.g., 'CheckpointLoaderSimple') - * @returns Promise resolving to the fetched assets */ - async function updateModelsForNodeType( - nodeType: string - ): Promise { - return updateModelsForKey(nodeType, () => - assetService.getAssetsForNodeType(nodeType) + async function updateModelsForNodeType(nodeType: string): Promise { + await updateModelsForKey(nodeType, (opts) => + assetService.getAssetsForNodeType(nodeType, opts) ) } /** * Fetch and cache model assets for a specific tag * @param tag The tag to fetch assets for (e.g., 'models') - * @returns Promise resolving to the fetched assets */ - async function updateModelsForTag(tag: string): Promise { + async function updateModelsForTag(tag: string): Promise { const key = `tag:${tag}` - return updateModelsForKey(key, () => assetService.getAssetsByTag(tag)) + await updateModelsForKey(key, (opts) => + assetService.getAssetsByTag(tag, true, opts) + ) } return { - modelAssetsByNodeType, - modelLoadingByNodeType, - modelErrorByNodeType, + getAssets, + isLoading, + getError, + hasMore, updateModelsForNodeType, updateModelsForTag } } + const emptyAssets: AssetItem[] = [] return { - modelAssetsByNodeType: shallowReactive(new Map()), - modelLoadingByNodeType: shallowReactive(new Map()), - modelErrorByNodeType: shallowReactive(new Map()), - updateModelsForNodeType: async () => [], - updateModelsForTag: async () => [] + getAssets: () => emptyAssets, + isLoading: () => false, + getError: () => undefined, + hasMore: () => false, + updateModelsForNodeType: async () => {}, + updateModelsForTag: async () => {} } } const { - modelAssetsByNodeType, - modelLoadingByNodeType, - modelErrorByNodeType, + getAssets, + isLoading: isModelLoading, + getError, + hasMore, updateModelsForNodeType, updateModelsForTag } = getModelState() @@ -432,10 +516,13 @@ export const useAssetsStore = defineStore('assets', () => { inputAssetsByFilename, getInputName, - // Model assets - modelAssetsByNodeType, - modelLoadingByNodeType, - modelErrorByNodeType, + // Model assets - accessors + getAssets, + isModelLoading, + getError, + hasMore, + + // Model assets - actions updateModelsForNodeType, updateModelsForTag }