feat: implement progressive pagination for Asset Browser model assets

This commit is contained in:
Alexander Brown
2026-01-17 18:42:55 -08:00
parent d9e0577df4
commit f5b422d493
7 changed files with 224 additions and 137 deletions

View File

@@ -6,6 +6,9 @@ import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vu
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import { useAssetsStore } from '@/stores/assetsStore'
const mockAssetsByKey = vi.hoisted(() => new Map<string, AssetItem[]>())
const mockLoadingByKey = vi.hoisted(() => new Map<string, boolean>())
vi.mock('@/i18n', () => ({
t: (key: string, params?: Record<string, string>) =>
params ? `${key}:${JSON.stringify(params)}` : key,
@@ -13,13 +16,20 @@ vi.mock('@/i18n', () => ({
}))
vi.mock('@/stores/assetsStore', () => {
const store = {
modelAssetsByNodeType: new Map<string, AssetItem[]>(),
modelLoadingByNodeType: new Map<string, boolean>(),
updateModelsForNodeType: vi.fn(),
updateModelsForTag: vi.fn()
const getAssets = vi.fn((key: string) => mockAssetsByKey.get(key) ?? [])
const isModelLoading = vi.fn(
(key: string) => mockLoadingByKey.get(key) ?? false
)
const updateModelsForNodeType = vi.fn()
const updateModelsForTag = vi.fn()
return {
useAssetsStore: () => ({
getAssets,
isModelLoading,
updateModelsForNodeType,
updateModelsForTag
})
}
return { useAssetsStore: () => store }
})
vi.mock('@/stores/modelToNodeStore', () => ({
@@ -183,12 +193,10 @@ describe('AssetBrowserModal', () => {
})
}
const mockStore = useAssetsStore()
beforeEach(() => {
vi.resetAllMocks()
mockStore.modelAssetsByNodeType.clear()
mockStore.modelLoadingByNodeType.clear()
mockAssetsByKey.clear()
mockLoadingByKey.clear()
})
describe('Integration with useAssetBrowser', () => {
@@ -197,7 +205,7 @@ describe('AssetBrowserModal', () => {
createTestAsset('asset1', 'Model A', 'checkpoints'),
createTestAsset('asset2', 'Model B', 'loras')
]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
await flushPromises()
@@ -214,7 +222,7 @@ describe('AssetBrowserModal', () => {
createTestAsset('c1', 'model.safetensors', 'checkpoints'),
createTestAsset('l1', 'lora.pt', 'loras')
]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
const wrapper = createWrapper({
nodeType: 'CheckpointLoaderSimple',
@@ -231,17 +239,18 @@ describe('AssetBrowserModal', () => {
describe('Data fetching', () => {
it('triggers store refresh for node type on mount', async () => {
const store = useAssetsStore()
createWrapper({ nodeType: 'CheckpointLoaderSimple' })
await flushPromises()
expect(mockStore.updateModelsForNodeType).toHaveBeenCalledWith(
expect(store.updateModelsForNodeType).toHaveBeenCalledWith(
'CheckpointLoaderSimple'
)
})
it('displays cached assets immediately from store', async () => {
const assets = [createTestAsset('asset1', 'Cached Model', 'checkpoints')]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
@@ -253,15 +262,16 @@ describe('AssetBrowserModal', () => {
})
it('triggers store refresh for asset type (tag) on mount', async () => {
const store = useAssetsStore()
createWrapper({ assetType: 'models' })
await flushPromises()
expect(mockStore.updateModelsForTag).toHaveBeenCalledWith('models')
expect(store.updateModelsForTag).toHaveBeenCalledWith('models')
})
it('uses tag: prefix for cache key when assetType is provided', async () => {
const assets = [createTestAsset('asset1', 'Tagged Model', 'models')]
mockStore.modelAssetsByNodeType.set('tag:models', assets)
mockAssetsByKey.set('tag:models', assets)
const wrapper = createWrapper({ assetType: 'models' })
await flushPromises()
@@ -277,7 +287,7 @@ describe('AssetBrowserModal', () => {
describe('Asset Selection', () => {
it('emits asset-select event when asset is selected', async () => {
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
await flushPromises()
@@ -290,7 +300,7 @@ describe('AssetBrowserModal', () => {
it('executes onSelect callback when provided', async () => {
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
const onSelect = vi.fn()
const wrapper = createWrapper({
@@ -333,7 +343,7 @@ describe('AssetBrowserModal', () => {
createTestAsset('asset1', 'Model A', 'checkpoints'),
createTestAsset('asset2', 'Model B', 'loras')
]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
const wrapper = createWrapper({
nodeType: 'CheckpointLoaderSimple',
@@ -366,7 +376,7 @@ describe('AssetBrowserModal', () => {
it('passes computed contentTitle to BaseModalLayout when no title prop', async () => {
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
await flushPromises()

View File

@@ -129,13 +129,9 @@ const cacheKey = computed(() => {
})
// Read directly from store cache - reactive to any store updates
const fetchedAssets = computed(
() => assetStore.modelAssetsByNodeType.get(cacheKey.value) ?? []
)
const fetchedAssets = computed(() => assetStore.getAssets(cacheKey.value))
const isStoreLoading = computed(
() => assetStore.modelLoadingByNodeType.get(cacheKey.value) ?? false
)
const isStoreLoading = computed(() => assetStore.isModelLoading(cacheKey.value))
// Only show loading spinner when loading AND no cached data
const isLoading = computed(

View File

@@ -1,6 +1,11 @@
import { fromZodError } from 'zod-validation-error'
import { st } from '@/i18n'
export interface PaginationOptions {
limit?: number
offset?: number
}
import {
assetItemSchema,
assetResponseSchema,
@@ -170,9 +175,15 @@ function createAssetService() {
* and fetching all assets with that category tag
*
* @param nodeType - The ComfyUI node type (e.g., 'CheckpointLoaderSimple')
* @param options - Pagination options
* @param options.limit - Maximum number of assets to return (default: 500)
* @param options.offset - Number of assets to skip (default: 0)
* @returns Promise<AssetItem[]> - Full asset objects with preserved metadata
*/
async function getAssetsForNodeType(nodeType: string): Promise<AssetItem[]> {
async function getAssetsForNodeType(
nodeType: string,
{ limit = DEFAULT_LIMIT, offset = 0 }: PaginationOptions = {}
): Promise<AssetItem[]> {
if (!nodeType || typeof nodeType !== 'string') {
return []
}
@@ -185,9 +196,18 @@ function createAssetService() {
return []
}
const queryParams = new URLSearchParams({
include_tags: `${MODELS_TAG},${category}`,
limit: limit.toString()
})
if (offset > 0) {
queryParams.set('offset', offset.toString())
}
// Fetch assets for this category using same API pattern as getAssetModels
const data = await handleAssetRequest(
`${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${category}&limit=${DEFAULT_LIMIT}`,
`${ASSETS_ENDPOINT}?${queryParams.toString()}`,
`assets for ${nodeType}`
)
@@ -243,10 +263,7 @@ function createAssetService() {
async function getAssetsByTag(
tag: string,
includePublic: boolean = true,
{
limit = DEFAULT_LIMIT,
offset = 0
}: { limit?: number; offset?: number } = {}
{ limit = DEFAULT_LIMIT, offset = 0 }: PaginationOptions = {}
): Promise<AssetItem[]> {
const queryParams = new URLSearchParams({
include_tags: tag,

View File

@@ -12,9 +12,9 @@ const mockGetCategoryForNodeType = vi.fn()
vi.mock('@/stores/assetsStore', () => ({
useAssetsStore: () => ({
modelAssetsByNodeType: new Map(),
modelLoadingByNodeType: new Map(),
modelErrorByNodeType: new Map(),
getAssets: () => [],
isModelLoading: () => false,
getError: () => undefined,
updateModelsForNodeType: mockUpdateModelsForNodeType
})
}))

View File

@@ -8,17 +8,17 @@ vi.mock('@/platform/distribution/types', () => ({
isCloud: true
}))
const mockModelAssetsByNodeType = new Map<string, AssetItem[]>()
const mockModelLoadingByNodeType = new Map<string, boolean>()
const mockModelErrorByNodeType = new Map<string, Error | null>()
const mockAssetsByKey = new Map<string, AssetItem[]>()
const mockLoadingByKey = new Map<string, boolean>()
const mockErrorByKey = new Map<string, Error | undefined>()
const mockUpdateModelsForNodeType = vi.fn()
const mockGetCategoryForNodeType = vi.fn()
vi.mock('@/stores/assetsStore', () => ({
useAssetsStore: () => ({
modelAssetsByNodeType: mockModelAssetsByNodeType,
modelLoadingByNodeType: mockModelLoadingByNodeType,
modelErrorByNodeType: mockModelErrorByNodeType,
getAssets: (key: string) => mockAssetsByKey.get(key) ?? [],
isModelLoading: (key: string) => mockLoadingByKey.get(key) ?? false,
getError: (key: string) => mockErrorByKey.get(key),
updateModelsForNodeType: mockUpdateModelsForNodeType
})
}))
@@ -32,9 +32,9 @@ vi.mock('@/stores/modelToNodeStore', () => ({
describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
beforeEach(() => {
vi.clearAllMocks()
mockModelAssetsByNodeType.clear()
mockModelLoadingByNodeType.clear()
mockModelErrorByNodeType.clear()
mockAssetsByKey.clear()
mockLoadingByKey.clear()
mockErrorByKey.clear()
mockGetCategoryForNodeType.mockReturnValue(undefined)
mockUpdateModelsForNodeType.mockImplementation(
@@ -76,8 +76,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
mockModelLoadingByNodeType.set(_nodeType, false)
mockAssetsByKey.set(_nodeType, mockAssets)
mockLoadingByKey.set(_nodeType, false)
return mockAssets
}
)
@@ -108,9 +108,9 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelErrorByNodeType.set(_nodeType, mockError)
mockModelAssetsByNodeType.set(_nodeType, [])
mockModelLoadingByNodeType.set(_nodeType, false)
mockErrorByKey.set(_nodeType, mockError)
mockAssetsByKey.set(_nodeType, [])
mockLoadingByKey.set(_nodeType, false)
return []
}
)
@@ -130,8 +130,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, [])
mockModelLoadingByNodeType.set(_nodeType, false)
mockAssetsByKey.set(_nodeType, [])
mockLoadingByKey.set(_nodeType, false)
return []
}
)
@@ -154,8 +154,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
mockModelLoadingByNodeType.set(_nodeType, false)
mockAssetsByKey.set(_nodeType, mockAssets)
mockLoadingByKey.set(_nodeType, false)
return mockAssets
}
)
@@ -182,8 +182,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
mockGetCategoryForNodeType.mockReturnValue('loras')
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
mockModelLoadingByNodeType.set(_nodeType, false)
mockAssetsByKey.set(_nodeType, mockAssets)
mockLoadingByKey.set(_nodeType, false)
return mockAssets
}
)
@@ -209,8 +209,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
mockModelLoadingByNodeType.set(_nodeType, false)
mockAssetsByKey.set(_nodeType, mockAssets)
mockLoadingByKey.set(_nodeType, false)
return mockAssets
}
)

View File

@@ -34,23 +34,17 @@ export function useAssetWidgetData(
const assets = computed<AssetItem[]>(() => {
const resolvedType = toValue(nodeType)
return resolvedType
? (assetsStore.modelAssetsByNodeType.get(resolvedType) ?? [])
: []
return resolvedType ? assetsStore.getAssets(resolvedType) : []
})
const isLoading = computed(() => {
const resolvedType = toValue(nodeType)
return resolvedType
? (assetsStore.modelLoadingByNodeType.get(resolvedType) ?? false)
: false
return resolvedType ? assetsStore.isModelLoading(resolvedType) : false
})
const error = computed<Error | null>(() => {
const resolvedType = toValue(nodeType)
return resolvedType
? (assetsStore.modelErrorByNodeType.get(resolvedType) ?? null)
: null
return resolvedType ? (assetsStore.getError(resolvedType) ?? null) : null
})
const dropdownItems = computed<DropdownItem[]>(() => {
@@ -71,7 +65,7 @@ export function useAssetWidgetData(
return
}
const hasData = assetsStore.modelAssetsByNodeType.has(currentNodeType)
const hasData = assetsStore.getAssets(currentNodeType).length > 0
if (!hasData) {
await assetsStore.updateModelsForNodeType(currentNodeType)

View File

@@ -1,13 +1,13 @@
import { useAsyncState, whenever } from '@vueuse/core'
import { isEqual } from 'es-toolkit'
import { defineStore } from 'pinia'
import { computed, shallowReactive, ref } from 'vue'
import { computed, reactive, ref, shallowReactive } from 'vue'
import {
mapInputFileToAssetItem,
mapTaskOutputToAssetItem
} from '@/platform/assets/composables/media/assetMappers'
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import { assetService } from '@/platform/assets/services/assetService'
import type { PaginationOptions } from '@/platform/assets/services/assetService'
import { isCloud } from '@/platform/distribution/types'
import type { JobListItem } from '@/platform/remote/comfyui/jobs/jobTypes'
import { api } from '@/scripts/api'
@@ -252,6 +252,16 @@ export const useAssetsStore = defineStore('assets', () => {
return inputAssetsByFilename.value.get(filename)?.name ?? filename
}
const MODEL_BATCH_SIZE = 500
interface ModelPaginationState {
assets: Map<string, AssetItem>
offset: number
hasMore: boolean
isLoading: boolean
error?: Error
}
/**
* Model assets cached by node type (e.g., 'CheckpointLoaderSimple', 'LoraLoader')
* Used by multiple loader nodes to avoid duplicate fetches
@@ -259,62 +269,116 @@ export const useAssetsStore = defineStore('assets', () => {
*/
const getModelState = () => {
if (isCloud) {
const modelAssetsByNodeType = shallowReactive(
new Map<string, AssetItem[]>()
)
const modelLoadingByNodeType = shallowReactive(new Map<string, boolean>())
const modelErrorByNodeType = shallowReactive(
new Map<string, Error | null>()
)
const modelStateByKey = ref(new Map<string, ModelPaginationState>())
const stateByNodeType = shallowReactive(
new Map<string, ReturnType<typeof useAsyncState<AssetItem[]>>>()
)
function createInitialState(): ModelPaginationState {
const state: ModelPaginationState = {
assets: new Map(),
offset: 0,
hasMore: true,
isLoading: false
}
return reactive(state)
}
function getOrCreateState(key: string): ModelPaginationState {
if (!modelStateByKey.value.has(key)) {
modelStateByKey.value.set(key, createInitialState())
}
return modelStateByKey.value.get(key)!
}
function resetPaginationForKey(key: string) {
const state = getOrCreateState(key)
state.assets = new Map()
state.offset = 0
state.hasMore = true
delete state.error
}
function getAssets(key: string): AssetItem[] {
return Array.from(modelStateByKey.value.get(key)?.assets.values() ?? [])
}
function isLoading(key: string): boolean {
return modelStateByKey.value.get(key)?.isLoading ?? false
}
function getError(key: string): Error | undefined {
return modelStateByKey.value.get(key)?.error
}
function hasMore(key: string): boolean {
return modelStateByKey.value.get(key)?.hasMore ?? false
}
/**
* Internal helper to fetch and cache assets with a given key and fetcher
* Internal helper to fetch and cache assets with a given key and fetcher.
* Loads first batch immediately, then progressively loads remaining batches.
*/
async function updateModelsForKey(
key: string,
fetcher: () => Promise<AssetItem[]>
fetcher: (options: PaginationOptions) => Promise<AssetItem[]>
): Promise<AssetItem[]> {
if (!stateByNodeType.has(key)) {
stateByNodeType.set(
key,
useAsyncState(fetcher, [], {
immediate: false,
resetOnExecute: false,
onError: (err) => {
console.error(`Error fetching model assets for ${key}:`, err)
}
})
)
}
const state = getOrCreateState(key)
const state = stateByNodeType.get(key)!
modelLoadingByNodeType.set(key, true)
modelErrorByNodeType.set(key, null)
resetPaginationForKey(key)
state.isLoading = true
try {
await state.execute()
const assets = await fetcher({
limit: MODEL_BATCH_SIZE,
offset: 0
})
state.assets = new Map(assets.map((a) => [a.id, a]))
state.offset = assets.length
state.hasMore = assets.length === MODEL_BATCH_SIZE
if (state.hasMore) {
void loadRemainingBatches(key, fetcher)
}
return assets
} catch (err) {
state.error = err instanceof Error ? err : new Error(String(err))
console.error(`Error fetching model assets for ${key}:`, err)
return []
} finally {
modelLoadingByNodeType.set(key, state.isLoading.value)
state.isLoading = false
}
}
const assets = state.state.value
const existingAssets = modelAssetsByNodeType.get(key)
/**
* Progressively load remaining batches until complete
*/
async function loadRemainingBatches(
key: string,
fetcher: (options: PaginationOptions) => Promise<AssetItem[]>
): Promise<void> {
const state = modelStateByKey.value.get(key)
if (!state) return
if (!isEqual(existingAssets, assets)) {
modelAssetsByNodeType.set(key, assets)
while (state.hasMore) {
try {
const newAssets = await fetcher({
limit: MODEL_BATCH_SIZE,
offset: state.offset
})
for (const asset of newAssets) {
if (!state.assets.has(asset.id)) {
state.assets.set(asset.id, asset)
}
}
state.offset += newAssets.length
state.hasMore = newAssets.length === MODEL_BATCH_SIZE
} catch (err) {
console.error(`Error loading batch for ${key}:`, err)
break
}
}
modelErrorByNodeType.set(
key,
state.error.value instanceof Error ? state.error.value : null
)
return assets
}
/**
@@ -325,8 +389,8 @@ export const useAssetsStore = defineStore('assets', () => {
async function updateModelsForNodeType(
nodeType: string
): Promise<AssetItem[]> {
return updateModelsForKey(nodeType, () =>
assetService.getAssetsForNodeType(nodeType)
return updateModelsForKey(nodeType, (opts) =>
assetService.getAssetsForNodeType(nodeType, opts)
)
}
@@ -337,7 +401,9 @@ export const useAssetsStore = defineStore('assets', () => {
*/
async function updateModelsForTag(tag: string): Promise<AssetItem[]> {
const key = `tag:${tag}`
return updateModelsForKey(key, () => assetService.getAssetsByTag(tag))
return updateModelsForKey(key, (opts) =>
assetService.getAssetsByTag(tag, true, opts)
)
}
/**
@@ -353,18 +419,15 @@ export const useAssetsStore = defineStore('assets', () => {
) {
const keysToCheck = cacheKey
? [cacheKey]
: Array.from(modelAssetsByNodeType.keys())
: Array.from(modelStateByKey.value.keys())
for (const key of keysToCheck) {
const assets = modelAssetsByNodeType.get(key)
if (!assets) continue
const state = modelStateByKey.value.get(key)
if (!state) continue
const index = assets.findIndex((a) => a.id === assetId)
if (index !== -1) {
const updatedAsset = { ...assets[index], ...updates }
const newAssets = [...assets]
newAssets[index] = updatedAsset
modelAssetsByNodeType.set(key, newAssets)
const existingAsset = state.assets.get(assetId)
if (existingAsset) {
state.assets.set(assetId, { ...existingAsset, ...updates })
if (cacheKey) return
}
}
@@ -401,9 +464,10 @@ export const useAssetsStore = defineStore('assets', () => {
}
return {
modelAssetsByNodeType,
modelLoadingByNodeType,
modelErrorByNodeType,
getAssets,
isLoading,
getError,
hasMore,
updateModelsForNodeType,
updateModelsForTag,
updateAssetMetadata,
@@ -411,21 +475,24 @@ export const useAssetsStore = defineStore('assets', () => {
}
}
const emptyAssets: AssetItem[] = []
return {
modelAssetsByNodeType: shallowReactive(new Map<string, AssetItem[]>()),
modelLoadingByNodeType: shallowReactive(new Map<string, boolean>()),
modelErrorByNodeType: shallowReactive(new Map<string, Error | null>()),
updateModelsForNodeType: async () => [],
updateModelsForTag: async () => [],
getAssets: () => emptyAssets,
isLoading: () => false,
getError: () => undefined,
hasMore: () => false,
updateModelsForNodeType: async () => emptyAssets,
updateModelsForTag: async () => emptyAssets,
updateAssetMetadata: async () => {},
updateAssetTags: async () => {}
}
}
const {
modelAssetsByNodeType,
modelLoadingByNodeType,
modelErrorByNodeType,
getAssets,
isLoading: isModelLoading,
getError,
hasMore,
updateModelsForNodeType,
updateModelsForTag,
updateAssetMetadata,
@@ -489,10 +556,13 @@ export const useAssetsStore = defineStore('assets', () => {
inputAssetsByFilename,
getInputName,
// Model assets
modelAssetsByNodeType,
modelLoadingByNodeType,
modelErrorByNodeType,
// Model assets - accessors
getAssets,
isModelLoading,
getError,
hasMore,
// Model assets - actions
updateModelsForNodeType,
updateModelsForTag,
updateAssetMetadata,