feat: implement progressive pagination for Asset Browser model assets (#8212)

## Summary

Implements progressive pagination for model assets - returns the first
batch immediately while loading remaining batches in the background.

## Changes

### Store (`assetsStore.ts`)
- Adds `ModelPaginationState` tracking (assets Map, offset, hasMore,
loading, error)
- `updateModelsForKey()` returns first batch, then calls
`loadRemainingBatches()` to fetch the rest
- Accessor functions `getAssets(key)`, `isModelLoading(key)` replace
direct Map access

### API (`assetService.ts`)
- Adds `PaginationOptions` interface (`{ limit?, offset? }`)

### Components
- `AssetBrowserModal.vue` uses new accessor API

### Tests
- Updated mocks for new accessor pattern

┆Issue is synchronized with this [Notion
page](https://www.notion.so/PR-8212-feat-implement-progressive-pagination-for-Asset-Browser-model-assets-2ef6d73d36508157af04d1264780997e)
by [Unito](https://www.unito.io)

---------

Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Alexander Brown
2026-01-21 17:59:08 -08:00
committed by GitHub
parent f08b0f44ef
commit 482159957e
9 changed files with 441 additions and 171 deletions

View File

@@ -6,6 +6,9 @@ import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vu
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import { useAssetsStore } from '@/stores/assetsStore'
const mockAssetsByKey = vi.hoisted(() => new Map<string, AssetItem[]>())
const mockLoadingByKey = vi.hoisted(() => new Map<string, boolean>())
vi.mock('@/i18n', () => ({
t: (key: string, params?: Record<string, string>) =>
params ? `${key}:${JSON.stringify(params)}` : key,
@@ -13,13 +16,20 @@ vi.mock('@/i18n', () => ({
}))
vi.mock('@/stores/assetsStore', () => {
const store = {
modelAssetsByNodeType: new Map<string, AssetItem[]>(),
modelLoadingByNodeType: new Map<string, boolean>(),
updateModelsForNodeType: vi.fn(),
updateModelsForTag: vi.fn()
const getAssets = vi.fn((key: string) => mockAssetsByKey.get(key) ?? [])
const isModelLoading = vi.fn(
(key: string) => mockLoadingByKey.get(key) ?? false
)
const updateModelsForNodeType = vi.fn()
const updateModelsForTag = vi.fn()
return {
useAssetsStore: () => ({
getAssets,
isModelLoading,
updateModelsForNodeType,
updateModelsForTag
})
}
return { useAssetsStore: () => store }
})
vi.mock('@/stores/modelToNodeStore', () => ({
@@ -183,12 +193,10 @@ describe('AssetBrowserModal', () => {
})
}
const mockStore = useAssetsStore()
beforeEach(() => {
vi.resetAllMocks()
mockStore.modelAssetsByNodeType.clear()
mockStore.modelLoadingByNodeType.clear()
mockAssetsByKey.clear()
mockLoadingByKey.clear()
})
describe('Integration with useAssetBrowser', () => {
@@ -197,7 +205,7 @@ describe('AssetBrowserModal', () => {
createTestAsset('asset1', 'Model A', 'checkpoints'),
createTestAsset('asset2', 'Model B', 'loras')
]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
await flushPromises()
@@ -214,7 +222,7 @@ describe('AssetBrowserModal', () => {
createTestAsset('c1', 'model.safetensors', 'checkpoints'),
createTestAsset('l1', 'lora.pt', 'loras')
]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
const wrapper = createWrapper({
nodeType: 'CheckpointLoaderSimple',
@@ -231,17 +239,18 @@ describe('AssetBrowserModal', () => {
describe('Data fetching', () => {
it('triggers store refresh for node type on mount', async () => {
const store = useAssetsStore()
createWrapper({ nodeType: 'CheckpointLoaderSimple' })
await flushPromises()
expect(mockStore.updateModelsForNodeType).toHaveBeenCalledWith(
expect(store.updateModelsForNodeType).toHaveBeenCalledWith(
'CheckpointLoaderSimple'
)
})
it('displays cached assets immediately from store', async () => {
const assets = [createTestAsset('asset1', 'Cached Model', 'checkpoints')]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
@@ -253,15 +262,16 @@ describe('AssetBrowserModal', () => {
})
it('triggers store refresh for asset type (tag) on mount', async () => {
const store = useAssetsStore()
createWrapper({ assetType: 'models' })
await flushPromises()
expect(mockStore.updateModelsForTag).toHaveBeenCalledWith('models')
expect(store.updateModelsForTag).toHaveBeenCalledWith('models')
})
it('uses tag: prefix for cache key when assetType is provided', async () => {
const assets = [createTestAsset('asset1', 'Tagged Model', 'models')]
mockStore.modelAssetsByNodeType.set('tag:models', assets)
mockAssetsByKey.set('tag:models', assets)
const wrapper = createWrapper({ assetType: 'models' })
await flushPromises()
@@ -277,7 +287,7 @@ describe('AssetBrowserModal', () => {
describe('Asset Selection', () => {
it('emits asset-select event when asset is selected', async () => {
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
await flushPromises()
@@ -290,7 +300,7 @@ describe('AssetBrowserModal', () => {
it('executes onSelect callback when provided', async () => {
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
const onSelect = vi.fn()
const wrapper = createWrapper({
@@ -333,7 +343,7 @@ describe('AssetBrowserModal', () => {
createTestAsset('asset1', 'Model A', 'checkpoints'),
createTestAsset('asset2', 'Model B', 'loras')
]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
const wrapper = createWrapper({
nodeType: 'CheckpointLoaderSimple',
@@ -366,7 +376,7 @@ describe('AssetBrowserModal', () => {
it('passes computed contentTitle to BaseModalLayout when no title prop', async () => {
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
await flushPromises()

View File

@@ -112,27 +112,21 @@ const cacheKey = computed(() => {
})
// Read directly from store cache - reactive to any store updates
const fetchedAssets = computed(
() => assetStore.modelAssetsByNodeType.get(cacheKey.value) ?? []
)
const fetchedAssets = computed(() => assetStore.getAssets(cacheKey.value))
const isStoreLoading = computed(
() => assetStore.modelLoadingByNodeType.get(cacheKey.value) ?? false
)
const isStoreLoading = computed(() => assetStore.isModelLoading(cacheKey.value))
// Only show loading spinner when loading AND no cached data
const isLoading = computed(
() => isStoreLoading.value && fetchedAssets.value.length === 0
)
async function refreshAssets(): Promise<AssetItem[]> {
async function refreshAssets(): Promise<void> {
if (props.nodeType) {
return await assetStore.updateModelsForNodeType(props.nodeType)
await assetStore.updateModelsForNodeType(props.nodeType)
} else if (props.assetType) {
await assetStore.updateModelsForTag(props.assetType)
}
if (props.assetType) {
return await assetStore.updateModelsForTag(props.assetType)
}
return []
}
// Trigger background refresh on mount

View File

@@ -160,7 +160,7 @@ describe('assetService', () => {
const result = await assetService.getAssetModels('checkpoints')
expect(api.fetchApi).toHaveBeenCalledWith(
'/assets?include_tags=models,checkpoints&limit=500'
'/assets?include_tags=models%2Ccheckpoints&limit=500'
)
expect(result).toEqual([
expect.objectContaining({ name: 'valid.safetensors', pathIndex: 0 })
@@ -231,9 +231,9 @@ describe('assetService', () => {
)
expect(result).toEqual(testAssets)
// Verify API call includes correct category
// Verify API call includes correct category (comma is URL-encoded by URLSearchParams)
expect(api.fetchApi).toHaveBeenCalledWith(
'/assets?include_tags=models,checkpoints&limit=500'
'/assets?include_tags=models%2Ccheckpoints&limit=500'
)
})
@@ -400,7 +400,7 @@ describe('assetService', () => {
})
expect(api.fetchApi).toHaveBeenCalledWith(
'/assets?include_tags=models&limit=500&include_public=true&offset=50'
'/assets?include_tags=models&limit=500&offset=50&include_public=true'
)
expect(result).toEqual(testAssets)
})
@@ -415,7 +415,7 @@ describe('assetService', () => {
})
expect(api.fetchApi).toHaveBeenCalledWith(
'/assets?include_tags=input&limit=100&include_public=false&offset=25'
'/assets?include_tags=input&limit=100&offset=25&include_public=false'
)
expect(result).toEqual(testAssets)
})

View File

@@ -1,6 +1,7 @@
import { fromZodError } from 'zod-validation-error'
import { st } from '@/i18n'
import {
assetItemSchema,
assetResponseSchema,
@@ -17,6 +18,16 @@ import type {
import { api } from '@/scripts/api'
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
export interface PaginationOptions {
limit?: number
offset?: number
}
interface AssetRequestOptions extends PaginationOptions {
includeTags: string[]
includePublic?: boolean
}
/**
* Maps CivitAI validation error codes to localized error messages
*/
@@ -77,9 +88,27 @@ function createAssetService() {
* Handles API response with consistent error handling and Zod validation
*/
async function handleAssetRequest(
url: string,
options: AssetRequestOptions,
context: string
): Promise<AssetResponse> {
const {
includeTags,
limit = DEFAULT_LIMIT,
offset,
includePublic
} = options
const queryParams = new URLSearchParams({
include_tags: includeTags.join(','),
limit: limit.toString()
})
if (offset !== undefined && offset > 0) {
queryParams.set('offset', offset.toString())
}
if (includePublic !== undefined) {
queryParams.set('include_public', includePublic ? 'true' : 'false')
}
const url = `${ASSETS_ENDPOINT}?${queryParams.toString()}`
const res = await api.fetchApi(url)
if (!res.ok) {
throw new Error(
@@ -101,7 +130,7 @@ function createAssetService() {
*/
async function getAssetModelFolders(): Promise<ModelFolder[]> {
const data = await handleAssetRequest(
`${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG}&limit=${DEFAULT_LIMIT}`,
{ includeTags: [MODELS_TAG] },
'model folders'
)
@@ -130,7 +159,7 @@ function createAssetService() {
*/
async function getAssetModels(folder: string): Promise<ModelFile[]> {
const data = await handleAssetRequest(
`${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${folder}&limit=${DEFAULT_LIMIT}`,
{ includeTags: [MODELS_TAG, folder] },
`models for ${folder}`
)
@@ -169,9 +198,15 @@ function createAssetService() {
* and fetching all assets with that category tag
*
* @param nodeType - The ComfyUI node type (e.g., 'CheckpointLoaderSimple')
* @param options - Pagination options
* @param options.limit - Maximum number of assets to return (default: 500)
* @param options.offset - Number of assets to skip (default: 0)
* @returns Promise<AssetItem[]> - Full asset objects with preserved metadata
*/
async function getAssetsForNodeType(nodeType: string): Promise<AssetItem[]> {
async function getAssetsForNodeType(
nodeType: string,
{ limit = DEFAULT_LIMIT, offset = 0 }: PaginationOptions = {}
): Promise<AssetItem[]> {
if (!nodeType || typeof nodeType !== 'string') {
return []
}
@@ -186,7 +221,7 @@ function createAssetService() {
// Fetch assets for this category using same API pattern as getAssetModels
const data = await handleAssetRequest(
`${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${category}&limit=${DEFAULT_LIMIT}`,
{ includeTags: [MODELS_TAG, category], limit, offset },
`assets for ${nodeType}`
)
@@ -242,23 +277,10 @@ function createAssetService() {
async function getAssetsByTag(
tag: string,
includePublic: boolean = true,
{
limit = DEFAULT_LIMIT,
offset = 0
}: { limit?: number; offset?: number } = {}
{ limit = DEFAULT_LIMIT, offset = 0 }: PaginationOptions = {}
): Promise<AssetItem[]> {
const queryParams = new URLSearchParams({
include_tags: tag,
limit: limit.toString(),
include_public: includePublic ? 'true' : 'false'
})
if (offset > 0) {
queryParams.set('offset', offset.toString())
}
const data = await handleAssetRequest(
`${ASSETS_ENDPOINT}?${queryParams.toString()}`,
{ includeTags: [tag], limit, offset, includePublic },
`assets for tag ${tag}`
)

View File

@@ -12,9 +12,9 @@ const mockGetCategoryForNodeType = vi.fn()
vi.mock('@/stores/assetsStore', () => ({
useAssetsStore: () => ({
modelAssetsByNodeType: new Map(),
modelLoadingByNodeType: new Map(),
modelErrorByNodeType: new Map(),
getAssets: () => [],
isModelLoading: () => false,
getError: () => undefined,
updateModelsForNodeType: mockUpdateModelsForNodeType
})
}))

View File

@@ -8,17 +8,17 @@ vi.mock('@/platform/distribution/types', () => ({
isCloud: true
}))
const mockModelAssetsByNodeType = new Map<string, AssetItem[]>()
const mockModelLoadingByNodeType = new Map<string, boolean>()
const mockModelErrorByNodeType = new Map<string, Error | null>()
const mockAssetsByKey = new Map<string, AssetItem[]>()
const mockLoadingByKey = new Map<string, boolean>()
const mockErrorByKey = new Map<string, Error | undefined>()
const mockUpdateModelsForNodeType = vi.fn()
const mockGetCategoryForNodeType = vi.fn()
vi.mock('@/stores/assetsStore', () => ({
useAssetsStore: () => ({
modelAssetsByNodeType: mockModelAssetsByNodeType,
modelLoadingByNodeType: mockModelLoadingByNodeType,
modelErrorByNodeType: mockModelErrorByNodeType,
getAssets: (key: string) => mockAssetsByKey.get(key) ?? [],
isModelLoading: (key: string) => mockLoadingByKey.get(key) ?? false,
getError: (key: string) => mockErrorByKey.get(key),
updateModelsForNodeType: mockUpdateModelsForNodeType
})
}))
@@ -32,9 +32,9 @@ vi.mock('@/stores/modelToNodeStore', () => ({
describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
beforeEach(() => {
vi.clearAllMocks()
mockModelAssetsByNodeType.clear()
mockModelLoadingByNodeType.clear()
mockModelErrorByNodeType.clear()
mockAssetsByKey.clear()
mockLoadingByKey.clear()
mockErrorByKey.clear()
mockGetCategoryForNodeType.mockReturnValue(undefined)
mockUpdateModelsForNodeType.mockImplementation(
@@ -76,8 +76,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
mockModelLoadingByNodeType.set(_nodeType, false)
mockAssetsByKey.set(_nodeType, mockAssets)
mockLoadingByKey.set(_nodeType, false)
return mockAssets
}
)
@@ -108,9 +108,9 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelErrorByNodeType.set(_nodeType, mockError)
mockModelAssetsByNodeType.set(_nodeType, [])
mockModelLoadingByNodeType.set(_nodeType, false)
mockErrorByKey.set(_nodeType, mockError)
mockAssetsByKey.set(_nodeType, [])
mockLoadingByKey.set(_nodeType, false)
return []
}
)
@@ -130,8 +130,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, [])
mockModelLoadingByNodeType.set(_nodeType, false)
mockAssetsByKey.set(_nodeType, [])
mockLoadingByKey.set(_nodeType, false)
return []
}
)
@@ -154,8 +154,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
mockModelLoadingByNodeType.set(_nodeType, false)
mockAssetsByKey.set(_nodeType, mockAssets)
mockLoadingByKey.set(_nodeType, false)
return mockAssets
}
)
@@ -182,8 +182,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
mockGetCategoryForNodeType.mockReturnValue('loras')
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
mockModelLoadingByNodeType.set(_nodeType, false)
mockAssetsByKey.set(_nodeType, mockAssets)
mockLoadingByKey.set(_nodeType, false)
return mockAssets
}
)
@@ -209,8 +209,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
mockModelLoadingByNodeType.set(_nodeType, false)
mockAssetsByKey.set(_nodeType, mockAssets)
mockLoadingByKey.set(_nodeType, false)
return mockAssets
}
)

View File

@@ -34,23 +34,17 @@ export function useAssetWidgetData(
const assets = computed<AssetItem[]>(() => {
const resolvedType = toValue(nodeType)
return resolvedType
? (assetsStore.modelAssetsByNodeType.get(resolvedType) ?? [])
: []
return resolvedType ? (assetsStore.getAssets(resolvedType) ?? []) : []
})
const isLoading = computed(() => {
const resolvedType = toValue(nodeType)
return resolvedType
? (assetsStore.modelLoadingByNodeType.get(resolvedType) ?? false)
: false
return resolvedType ? assetsStore.isModelLoading(resolvedType) : false
})
const error = computed<Error | null>(() => {
const resolvedType = toValue(nodeType)
return resolvedType
? (assetsStore.modelErrorByNodeType.get(resolvedType) ?? null)
: null
return resolvedType ? (assetsStore.getError(resolvedType) ?? null) : null
})
const dropdownItems = computed<DropdownItem[]>(() => {
@@ -71,7 +65,8 @@ export function useAssetWidgetData(
return
}
const hasData = assetsStore.modelAssetsByNodeType.has(currentNodeType)
const existingAssets = assetsStore.getAssets(currentNodeType) ?? []
const hasData = existingAssets.length > 0
if (!hasData) {
await assetsStore.updateModelsForNodeType(currentNodeType)

View File

@@ -1,9 +1,12 @@
import { createPinia, setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createTestingPinia } from '@pinia/testing'
import { setActivePinia } from 'pinia'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { nextTick, watch } from 'vue'
import { useAssetsStore } from '@/stores/assetsStore'
import { api } from '@/scripts/api'
import type { JobListItem } from '@/platform/remote/comfyui/jobs/jobTypes'
import { assetService } from '@/platform/assets/services/assetService'
// Mock the api module
vi.mock('@/scripts/api', () => ({
@@ -20,13 +23,17 @@ vi.mock('@/scripts/api', () => ({
// Mock the asset service
vi.mock('@/platform/assets/services/assetService', () => ({
assetService: {
getAssetsByTag: vi.fn()
getAssetsByTag: vi.fn(),
getAssetsForNodeType: vi.fn()
}
}))
// Mock distribution type
// Mock distribution type - hoisted so it can be changed per test
const mockIsCloud = vi.hoisted(() => ({ value: false }))
vi.mock('@/platform/distribution/types', () => ({
isCloud: false
get isCloud() {
return mockIsCloud.value
}
}))
// Mock TaskItemImpl
@@ -115,7 +122,7 @@ describe('assetsStore - Refactored (Option A)', () => {
})
beforeEach(() => {
setActivePinia(createPinia())
setActivePinia(createTestingPinia({ stubActions: false }))
store = useAssetsStore()
vi.clearAllMocks()
})
@@ -453,3 +460,158 @@ describe('assetsStore - Refactored (Option A)', () => {
})
})
})
describe('assetsStore - Model Assets Cache (Cloud)', () => {
beforeEach(() => {
setActivePinia(createTestingPinia({ stubActions: false }))
mockIsCloud.value = true
vi.clearAllMocks()
})
afterEach(() => {
mockIsCloud.value = false
})
const createMockAsset = (id: string) => ({
id,
name: `asset-${id}`,
size: 100,
created_at: new Date().toISOString(),
tags: ['models'],
preview_url: `http://test.com/${id}`
})
describe('getAssets cache invalidation', () => {
it('should invalidate cache before mutating assets during batch loading', async () => {
const store = useAssetsStore()
const nodeType = 'CheckpointLoaderSimple'
const firstBatch = Array.from({ length: 500 }, (_, i) =>
createMockAsset(`asset-${i}`)
)
const secondBatch = Array.from({ length: 100 }, (_, i) =>
createMockAsset(`asset-${500 + i}`)
)
let callCount = 0
vi.mocked(assetService.getAssetsForNodeType).mockImplementation(
async () => {
callCount++
return callCount === 1 ? firstBatch : secondBatch
}
)
await store.updateModelsForNodeType(nodeType)
// Wait for background batch loading to complete
await vi.waitFor(() => {
expect(
vi.mocked(assetService.getAssetsForNodeType)
).toHaveBeenCalledTimes(2)
})
const assets = store.getAssets(nodeType)
expect(assets).toHaveLength(600)
})
it('should not return stale cached array after background batch completes', async () => {
const store = useAssetsStore()
const nodeType = 'LoraLoader'
// First batch must be exactly MODEL_BATCH_SIZE (500) to trigger hasMore
const firstBatch = Array.from({ length: 500 }, (_, i) =>
createMockAsset(`first-${i}`)
)
const secondBatch = [createMockAsset('new-asset')]
let callCount = 0
vi.mocked(assetService.getAssetsForNodeType).mockImplementation(
async () => {
callCount++
return callCount === 1 ? firstBatch : secondBatch
}
)
await store.updateModelsForNodeType(nodeType)
// Wait for background batch loading to complete
await vi.waitFor(() => {
expect(
vi.mocked(assetService.getAssetsForNodeType)
).toHaveBeenCalledTimes(2)
})
const assets = store.getAssets(nodeType)
expect(assets).toHaveLength(501)
expect(assets.map((a) => a.id)).toContain('new-asset')
})
it('should return cached array on subsequent getAssets calls', () => {
const store = useAssetsStore()
const nodeType = 'TestLoader'
const firstCall = store.getAssets(nodeType)
const secondCall = store.getAssets(nodeType)
expect(secondCall).toBe(firstCall)
})
})
describe('concurrent request handling', () => {
it('should discard stale request when newer request starts', async () => {
const store = useAssetsStore()
const nodeType = 'CheckpointLoaderSimple'
const firstBatch = Array.from({ length: 5 }, (_, i) =>
createMockAsset(`first-${i}`)
)
const secondBatch = Array.from({ length: 10 }, (_, i) =>
createMockAsset(`second-${i}`)
)
let resolveFirst: (value: ReturnType<typeof createMockAsset>[]) => void
const firstPromise = new Promise<ReturnType<typeof createMockAsset>[]>(
(resolve) => {
resolveFirst = resolve
}
)
let callCount = 0
vi.mocked(assetService.getAssetsForNodeType).mockImplementation(
async () => {
callCount++
return callCount === 1 ? firstPromise : secondBatch
}
)
const firstRequest = store.updateModelsForNodeType(nodeType)
const secondRequest = store.updateModelsForNodeType(nodeType)
resolveFirst!(firstBatch)
await Promise.all([firstRequest, secondRequest])
expect(store.getAssets(nodeType)).toHaveLength(10)
expect(
store.getAssets(nodeType).every((a) => a.id.startsWith('second-'))
).toBe(true)
})
})
describe('shallowReactive state reactivity', () => {
it('should trigger reactivity on isModelLoading change', async () => {
const store = useAssetsStore()
const nodeType = 'CheckpointLoaderSimple'
const loadingStates: boolean[] = []
watch(
() => store.isModelLoading(nodeType),
(val) => loadingStates.push(val),
{ immediate: true }
)
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue([])
await store.updateModelsForNodeType(nodeType)
await nextTick()
expect(loadingStates).toContain(true)
expect(loadingStates).toContain(false)
})
})
})

View File

@@ -1,13 +1,13 @@
import { useAsyncState, whenever } from '@vueuse/core'
import { isEqual } from 'es-toolkit'
import { defineStore } from 'pinia'
import { computed, shallowReactive, ref } from 'vue'
import { computed, reactive, ref, shallowReactive } from 'vue'
import {
mapInputFileToAssetItem,
mapTaskOutputToAssetItem
} from '@/platform/assets/composables/media/assetMappers'
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import { assetService } from '@/platform/assets/services/assetService'
import type { PaginationOptions } from '@/platform/assets/services/assetService'
import { isCloud } from '@/platform/distribution/types'
import type { JobListItem } from '@/platform/remote/comfyui/jobs/jobTypes'
import { api } from '@/scripts/api'
@@ -251,6 +251,16 @@ export const useAssetsStore = defineStore('assets', () => {
return inputAssetsByFilename.value.get(filename)?.name ?? filename
}
const MODEL_BATCH_SIZE = 500
interface ModelPaginationState {
assets: Map<string, AssetItem>
offset: number
hasMore: boolean
isLoading: boolean
error?: Error
}
/**
* Model assets cached by node type (e.g., 'CheckpointLoaderSimple', 'LoraLoader')
* Used by multiple loader nodes to avoid duplicate fetches
@@ -258,109 +268,183 @@ export const useAssetsStore = defineStore('assets', () => {
*/
const getModelState = () => {
if (isCloud) {
const modelAssetsByNodeType = shallowReactive(
new Map<string, AssetItem[]>()
)
const modelLoadingByNodeType = shallowReactive(new Map<string, boolean>())
const modelErrorByNodeType = shallowReactive(
new Map<string, Error | null>()
)
const modelStateByKey = ref(new Map<string, ModelPaginationState>())
const stateByNodeType = shallowReactive(
new Map<string, ReturnType<typeof useAsyncState<AssetItem[]>>>()
)
const assetsArrayCache = new Map<
string,
{ source: Map<string, AssetItem>; array: AssetItem[] }
>()
const pendingRequestByKey = new Map<string, ModelPaginationState>()
function createState(): ModelPaginationState {
return reactive({
assets: new Map(),
offset: 0,
hasMore: true,
isLoading: false
})
}
function isStale(key: string, state: ModelPaginationState): boolean {
const committed = modelStateByKey.value.get(key)
const pending = pendingRequestByKey.get(key)
return committed !== state && pending !== state
}
const EMPTY_ASSETS: AssetItem[] = []
function getAssets(key: string): AssetItem[] {
const state = modelStateByKey.value.get(key)
const assetsMap = state?.assets
if (!assetsMap) return EMPTY_ASSETS
const cached = assetsArrayCache.get(key)
if (cached && cached.source === assetsMap) {
return cached.array
}
const array = Array.from(assetsMap.values())
assetsArrayCache.set(key, { source: assetsMap, array })
return array
}
function isLoading(key: string): boolean {
return modelStateByKey.value.get(key)?.isLoading ?? false
}
function getError(key: string): Error | undefined {
return modelStateByKey.value.get(key)?.error
}
function hasMore(key: string): boolean {
return modelStateByKey.value.get(key)?.hasMore ?? false
}
/**
* Internal helper to fetch and cache assets with a given key and fetcher
* Internal helper to fetch and cache assets with a given key and fetcher.
* Loads first batch immediately, then progressively loads remaining batches.
* Keeps existing data visible until new data is successfully fetched.
*/
async function updateModelsForKey(
key: string,
fetcher: () => Promise<AssetItem[]>
): Promise<AssetItem[]> {
if (!stateByNodeType.has(key)) {
stateByNodeType.set(
key,
useAsyncState(fetcher, [], {
immediate: false,
resetOnExecute: false,
onError: (err) => {
console.error(`Error fetching model assets for ${key}:`, err)
}
})
)
fetcher: (options: PaginationOptions) => Promise<AssetItem[]>
): Promise<void> {
const state = createState()
state.isLoading = true
const hasExistingData = modelStateByKey.value.has(key)
if (hasExistingData) {
pendingRequestByKey.set(key, state)
} else {
modelStateByKey.value.set(key, state)
}
const state = stateByNodeType.get(key)!
modelLoadingByNodeType.set(key, true)
modelErrorByNodeType.set(key, null)
async function loadBatches(): Promise<void> {
while (state.hasMore) {
try {
await state.execute()
} finally {
modelLoadingByNodeType.set(key, state.isLoading.value)
const newAssets = await fetcher({
limit: MODEL_BATCH_SIZE,
offset: state.offset
})
if (isStale(key, state)) return
const isFirstBatch = state.offset === 0
if (isFirstBatch) {
assetsArrayCache.delete(key)
if (hasExistingData) {
pendingRequestByKey.delete(key)
modelStateByKey.value.set(key, state)
}
const assets = state.state.value
const existingAssets = modelAssetsByNodeType.get(key)
if (!isEqual(existingAssets, assets)) {
modelAssetsByNodeType.set(key, assets)
}
modelErrorByNodeType.set(
key,
state.error.value instanceof Error ? state.error.value : null
state.assets = new Map(newAssets.map((a) => [a.id, a]))
} else {
const assetsToAdd = newAssets.filter(
(a) => !state.assets.has(a.id)
)
if (assetsToAdd.length > 0) {
assetsArrayCache.delete(key)
for (const asset of assetsToAdd) {
state.assets.set(asset.id, asset)
}
}
}
return assets
state.offset += newAssets.length
state.hasMore = newAssets.length === MODEL_BATCH_SIZE
if (isFirstBatch) {
state.isLoading = false
}
if (state.hasMore) {
await new Promise((resolve) => setTimeout(resolve, 50))
}
} catch (err) {
if (isStale(key, state)) return
state.error = err instanceof Error ? err : new Error(String(err))
state.hasMore = false
console.error(`Error loading batch for ${key}:`, err)
if (state.offset === 0) {
state.isLoading = false
pendingRequestByKey.delete(key)
// TODO: Add toast indicator for first-batch load failures
}
return
}
}
}
await loadBatches()
}
/**
* Fetch and cache model assets for a specific node type
* @param nodeType The node type to fetch assets for (e.g., 'CheckpointLoaderSimple')
* @returns Promise resolving to the fetched assets
*/
async function updateModelsForNodeType(
nodeType: string
): Promise<AssetItem[]> {
return updateModelsForKey(nodeType, () =>
assetService.getAssetsForNodeType(nodeType)
async function updateModelsForNodeType(nodeType: string): Promise<void> {
await updateModelsForKey(nodeType, (opts) =>
assetService.getAssetsForNodeType(nodeType, opts)
)
}
/**
* Fetch and cache model assets for a specific tag
* @param tag The tag to fetch assets for (e.g., 'models')
* @returns Promise resolving to the fetched assets
*/
async function updateModelsForTag(tag: string): Promise<AssetItem[]> {
async function updateModelsForTag(tag: string): Promise<void> {
const key = `tag:${tag}`
return updateModelsForKey(key, () => assetService.getAssetsByTag(tag))
await updateModelsForKey(key, (opts) =>
assetService.getAssetsByTag(tag, true, opts)
)
}
return {
modelAssetsByNodeType,
modelLoadingByNodeType,
modelErrorByNodeType,
getAssets,
isLoading,
getError,
hasMore,
updateModelsForNodeType,
updateModelsForTag
}
}
const emptyAssets: AssetItem[] = []
return {
modelAssetsByNodeType: shallowReactive(new Map<string, AssetItem[]>()),
modelLoadingByNodeType: shallowReactive(new Map<string, boolean>()),
modelErrorByNodeType: shallowReactive(new Map<string, Error | null>()),
updateModelsForNodeType: async () => [],
updateModelsForTag: async () => []
getAssets: () => emptyAssets,
isLoading: () => false,
getError: () => undefined,
hasMore: () => false,
updateModelsForNodeType: async () => {},
updateModelsForTag: async () => {}
}
}
const {
modelAssetsByNodeType,
modelLoadingByNodeType,
modelErrorByNodeType,
getAssets,
isLoading: isModelLoading,
getError,
hasMore,
updateModelsForNodeType,
updateModelsForTag
} = getModelState()
@@ -422,10 +506,13 @@ export const useAssetsStore = defineStore('assets', () => {
inputAssetsByFilename,
getInputName,
// Model assets
modelAssetsByNodeType,
modelLoadingByNodeType,
modelErrorByNodeType,
// Model assets - accessors
getAssets,
isModelLoading,
getError,
hasMore,
// Model assets - actions
updateModelsForNodeType,
updateModelsForTag
}