mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-01-26 19:09:52 +00:00
feat: implement progressive pagination for Asset Browser model assets (#8212)
## Summary
Implements progressive pagination for model assets - returns the first
batch immediately while loading remaining batches in the background.
## Changes
### Store (`assetsStore.ts`)
- Adds `ModelPaginationState` tracking (assets Map, offset, hasMore,
loading, error)
- `updateModelsForKey()` returns first batch, then calls
`loadRemainingBatches()` to fetch the rest
- Accessor functions `getAssets(key)`, `isModelLoading(key)` replace
direct Map access
### API (`assetService.ts`)
- Adds `PaginationOptions` interface (`{ limit?, offset? }`)
### Components
- `AssetBrowserModal.vue` uses new accessor API
### Tests
- Updated mocks for new accessor pattern
┆Issue is synchronized with this [Notion
page](https://www.notion.so/PR-8212-feat-implement-progressive-pagination-for-Asset-Browser-model-assets-2ef6d73d36508157af04d1264780997e)
by [Unito](https://www.unito.io)
---------
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
@@ -6,6 +6,9 @@ import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vu
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { useAssetsStore } from '@/stores/assetsStore'
|
||||
|
||||
const mockAssetsByKey = vi.hoisted(() => new Map<string, AssetItem[]>())
|
||||
const mockLoadingByKey = vi.hoisted(() => new Map<string, boolean>())
|
||||
|
||||
vi.mock('@/i18n', () => ({
|
||||
t: (key: string, params?: Record<string, string>) =>
|
||||
params ? `${key}:${JSON.stringify(params)}` : key,
|
||||
@@ -13,13 +16,20 @@ vi.mock('@/i18n', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/assetsStore', () => {
|
||||
const store = {
|
||||
modelAssetsByNodeType: new Map<string, AssetItem[]>(),
|
||||
modelLoadingByNodeType: new Map<string, boolean>(),
|
||||
updateModelsForNodeType: vi.fn(),
|
||||
updateModelsForTag: vi.fn()
|
||||
const getAssets = vi.fn((key: string) => mockAssetsByKey.get(key) ?? [])
|
||||
const isModelLoading = vi.fn(
|
||||
(key: string) => mockLoadingByKey.get(key) ?? false
|
||||
)
|
||||
const updateModelsForNodeType = vi.fn()
|
||||
const updateModelsForTag = vi.fn()
|
||||
return {
|
||||
useAssetsStore: () => ({
|
||||
getAssets,
|
||||
isModelLoading,
|
||||
updateModelsForNodeType,
|
||||
updateModelsForTag
|
||||
})
|
||||
}
|
||||
return { useAssetsStore: () => store }
|
||||
})
|
||||
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
@@ -183,12 +193,10 @@ describe('AssetBrowserModal', () => {
|
||||
})
|
||||
}
|
||||
|
||||
const mockStore = useAssetsStore()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks()
|
||||
mockStore.modelAssetsByNodeType.clear()
|
||||
mockStore.modelLoadingByNodeType.clear()
|
||||
mockAssetsByKey.clear()
|
||||
mockLoadingByKey.clear()
|
||||
})
|
||||
|
||||
describe('Integration with useAssetBrowser', () => {
|
||||
@@ -197,7 +205,7 @@ describe('AssetBrowserModal', () => {
|
||||
createTestAsset('asset1', 'Model A', 'checkpoints'),
|
||||
createTestAsset('asset2', 'Model B', 'loras')
|
||||
]
|
||||
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
|
||||
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
|
||||
|
||||
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
||||
await flushPromises()
|
||||
@@ -214,7 +222,7 @@ describe('AssetBrowserModal', () => {
|
||||
createTestAsset('c1', 'model.safetensors', 'checkpoints'),
|
||||
createTestAsset('l1', 'lora.pt', 'loras')
|
||||
]
|
||||
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
|
||||
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
|
||||
|
||||
const wrapper = createWrapper({
|
||||
nodeType: 'CheckpointLoaderSimple',
|
||||
@@ -231,17 +239,18 @@ describe('AssetBrowserModal', () => {
|
||||
|
||||
describe('Data fetching', () => {
|
||||
it('triggers store refresh for node type on mount', async () => {
|
||||
const store = useAssetsStore()
|
||||
createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
||||
await flushPromises()
|
||||
|
||||
expect(mockStore.updateModelsForNodeType).toHaveBeenCalledWith(
|
||||
expect(store.updateModelsForNodeType).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
})
|
||||
|
||||
it('displays cached assets immediately from store', async () => {
|
||||
const assets = [createTestAsset('asset1', 'Cached Model', 'checkpoints')]
|
||||
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
|
||||
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
|
||||
|
||||
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
||||
|
||||
@@ -253,15 +262,16 @@ describe('AssetBrowserModal', () => {
|
||||
})
|
||||
|
||||
it('triggers store refresh for asset type (tag) on mount', async () => {
|
||||
const store = useAssetsStore()
|
||||
createWrapper({ assetType: 'models' })
|
||||
await flushPromises()
|
||||
|
||||
expect(mockStore.updateModelsForTag).toHaveBeenCalledWith('models')
|
||||
expect(store.updateModelsForTag).toHaveBeenCalledWith('models')
|
||||
})
|
||||
|
||||
it('uses tag: prefix for cache key when assetType is provided', async () => {
|
||||
const assets = [createTestAsset('asset1', 'Tagged Model', 'models')]
|
||||
mockStore.modelAssetsByNodeType.set('tag:models', assets)
|
||||
mockAssetsByKey.set('tag:models', assets)
|
||||
|
||||
const wrapper = createWrapper({ assetType: 'models' })
|
||||
await flushPromises()
|
||||
@@ -277,7 +287,7 @@ describe('AssetBrowserModal', () => {
|
||||
describe('Asset Selection', () => {
|
||||
it('emits asset-select event when asset is selected', async () => {
|
||||
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
|
||||
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
|
||||
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
|
||||
|
||||
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
||||
await flushPromises()
|
||||
@@ -290,7 +300,7 @@ describe('AssetBrowserModal', () => {
|
||||
|
||||
it('executes onSelect callback when provided', async () => {
|
||||
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
|
||||
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
|
||||
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
|
||||
|
||||
const onSelect = vi.fn()
|
||||
const wrapper = createWrapper({
|
||||
@@ -333,7 +343,7 @@ describe('AssetBrowserModal', () => {
|
||||
createTestAsset('asset1', 'Model A', 'checkpoints'),
|
||||
createTestAsset('asset2', 'Model B', 'loras')
|
||||
]
|
||||
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
|
||||
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
|
||||
|
||||
const wrapper = createWrapper({
|
||||
nodeType: 'CheckpointLoaderSimple',
|
||||
@@ -366,7 +376,7 @@ describe('AssetBrowserModal', () => {
|
||||
|
||||
it('passes computed contentTitle to BaseModalLayout when no title prop', async () => {
|
||||
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
|
||||
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
|
||||
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
|
||||
|
||||
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
||||
await flushPromises()
|
||||
|
||||
@@ -112,27 +112,21 @@ const cacheKey = computed(() => {
|
||||
})
|
||||
|
||||
// Read directly from store cache - reactive to any store updates
|
||||
const fetchedAssets = computed(
|
||||
() => assetStore.modelAssetsByNodeType.get(cacheKey.value) ?? []
|
||||
)
|
||||
const fetchedAssets = computed(() => assetStore.getAssets(cacheKey.value))
|
||||
|
||||
const isStoreLoading = computed(
|
||||
() => assetStore.modelLoadingByNodeType.get(cacheKey.value) ?? false
|
||||
)
|
||||
const isStoreLoading = computed(() => assetStore.isModelLoading(cacheKey.value))
|
||||
|
||||
// Only show loading spinner when loading AND no cached data
|
||||
const isLoading = computed(
|
||||
() => isStoreLoading.value && fetchedAssets.value.length === 0
|
||||
)
|
||||
|
||||
async function refreshAssets(): Promise<AssetItem[]> {
|
||||
async function refreshAssets(): Promise<void> {
|
||||
if (props.nodeType) {
|
||||
return await assetStore.updateModelsForNodeType(props.nodeType)
|
||||
await assetStore.updateModelsForNodeType(props.nodeType)
|
||||
} else if (props.assetType) {
|
||||
await assetStore.updateModelsForTag(props.assetType)
|
||||
}
|
||||
if (props.assetType) {
|
||||
return await assetStore.updateModelsForTag(props.assetType)
|
||||
}
|
||||
return []
|
||||
}
|
||||
|
||||
// Trigger background refresh on mount
|
||||
|
||||
@@ -160,7 +160,7 @@ describe('assetService', () => {
|
||||
const result = await assetService.getAssetModels('checkpoints')
|
||||
|
||||
expect(api.fetchApi).toHaveBeenCalledWith(
|
||||
'/assets?include_tags=models,checkpoints&limit=500'
|
||||
'/assets?include_tags=models%2Ccheckpoints&limit=500'
|
||||
)
|
||||
expect(result).toEqual([
|
||||
expect.objectContaining({ name: 'valid.safetensors', pathIndex: 0 })
|
||||
@@ -231,9 +231,9 @@ describe('assetService', () => {
|
||||
)
|
||||
expect(result).toEqual(testAssets)
|
||||
|
||||
// Verify API call includes correct category
|
||||
// Verify API call includes correct category (comma is URL-encoded by URLSearchParams)
|
||||
expect(api.fetchApi).toHaveBeenCalledWith(
|
||||
'/assets?include_tags=models,checkpoints&limit=500'
|
||||
'/assets?include_tags=models%2Ccheckpoints&limit=500'
|
||||
)
|
||||
})
|
||||
|
||||
@@ -400,7 +400,7 @@ describe('assetService', () => {
|
||||
})
|
||||
|
||||
expect(api.fetchApi).toHaveBeenCalledWith(
|
||||
'/assets?include_tags=models&limit=500&include_public=true&offset=50'
|
||||
'/assets?include_tags=models&limit=500&offset=50&include_public=true'
|
||||
)
|
||||
expect(result).toEqual(testAssets)
|
||||
})
|
||||
@@ -415,7 +415,7 @@ describe('assetService', () => {
|
||||
})
|
||||
|
||||
expect(api.fetchApi).toHaveBeenCalledWith(
|
||||
'/assets?include_tags=input&limit=100&include_public=false&offset=25'
|
||||
'/assets?include_tags=input&limit=100&offset=25&include_public=false'
|
||||
)
|
||||
expect(result).toEqual(testAssets)
|
||||
})
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { fromZodError } from 'zod-validation-error'
|
||||
|
||||
import { st } from '@/i18n'
|
||||
|
||||
import {
|
||||
assetItemSchema,
|
||||
assetResponseSchema,
|
||||
@@ -17,6 +18,16 @@ import type {
|
||||
import { api } from '@/scripts/api'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
|
||||
export interface PaginationOptions {
|
||||
limit?: number
|
||||
offset?: number
|
||||
}
|
||||
|
||||
interface AssetRequestOptions extends PaginationOptions {
|
||||
includeTags: string[]
|
||||
includePublic?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps CivitAI validation error codes to localized error messages
|
||||
*/
|
||||
@@ -77,9 +88,27 @@ function createAssetService() {
|
||||
* Handles API response with consistent error handling and Zod validation
|
||||
*/
|
||||
async function handleAssetRequest(
|
||||
url: string,
|
||||
options: AssetRequestOptions,
|
||||
context: string
|
||||
): Promise<AssetResponse> {
|
||||
const {
|
||||
includeTags,
|
||||
limit = DEFAULT_LIMIT,
|
||||
offset,
|
||||
includePublic
|
||||
} = options
|
||||
const queryParams = new URLSearchParams({
|
||||
include_tags: includeTags.join(','),
|
||||
limit: limit.toString()
|
||||
})
|
||||
if (offset !== undefined && offset > 0) {
|
||||
queryParams.set('offset', offset.toString())
|
||||
}
|
||||
if (includePublic !== undefined) {
|
||||
queryParams.set('include_public', includePublic ? 'true' : 'false')
|
||||
}
|
||||
|
||||
const url = `${ASSETS_ENDPOINT}?${queryParams.toString()}`
|
||||
const res = await api.fetchApi(url)
|
||||
if (!res.ok) {
|
||||
throw new Error(
|
||||
@@ -101,7 +130,7 @@ function createAssetService() {
|
||||
*/
|
||||
async function getAssetModelFolders(): Promise<ModelFolder[]> {
|
||||
const data = await handleAssetRequest(
|
||||
`${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG}&limit=${DEFAULT_LIMIT}`,
|
||||
{ includeTags: [MODELS_TAG] },
|
||||
'model folders'
|
||||
)
|
||||
|
||||
@@ -130,7 +159,7 @@ function createAssetService() {
|
||||
*/
|
||||
async function getAssetModels(folder: string): Promise<ModelFile[]> {
|
||||
const data = await handleAssetRequest(
|
||||
`${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${folder}&limit=${DEFAULT_LIMIT}`,
|
||||
{ includeTags: [MODELS_TAG, folder] },
|
||||
`models for ${folder}`
|
||||
)
|
||||
|
||||
@@ -169,9 +198,15 @@ function createAssetService() {
|
||||
* and fetching all assets with that category tag
|
||||
*
|
||||
* @param nodeType - The ComfyUI node type (e.g., 'CheckpointLoaderSimple')
|
||||
* @param options - Pagination options
|
||||
* @param options.limit - Maximum number of assets to return (default: 500)
|
||||
* @param options.offset - Number of assets to skip (default: 0)
|
||||
* @returns Promise<AssetItem[]> - Full asset objects with preserved metadata
|
||||
*/
|
||||
async function getAssetsForNodeType(nodeType: string): Promise<AssetItem[]> {
|
||||
async function getAssetsForNodeType(
|
||||
nodeType: string,
|
||||
{ limit = DEFAULT_LIMIT, offset = 0 }: PaginationOptions = {}
|
||||
): Promise<AssetItem[]> {
|
||||
if (!nodeType || typeof nodeType !== 'string') {
|
||||
return []
|
||||
}
|
||||
@@ -186,7 +221,7 @@ function createAssetService() {
|
||||
|
||||
// Fetch assets for this category using same API pattern as getAssetModels
|
||||
const data = await handleAssetRequest(
|
||||
`${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${category}&limit=${DEFAULT_LIMIT}`,
|
||||
{ includeTags: [MODELS_TAG, category], limit, offset },
|
||||
`assets for ${nodeType}`
|
||||
)
|
||||
|
||||
@@ -242,23 +277,10 @@ function createAssetService() {
|
||||
async function getAssetsByTag(
|
||||
tag: string,
|
||||
includePublic: boolean = true,
|
||||
{
|
||||
limit = DEFAULT_LIMIT,
|
||||
offset = 0
|
||||
}: { limit?: number; offset?: number } = {}
|
||||
{ limit = DEFAULT_LIMIT, offset = 0 }: PaginationOptions = {}
|
||||
): Promise<AssetItem[]> {
|
||||
const queryParams = new URLSearchParams({
|
||||
include_tags: tag,
|
||||
limit: limit.toString(),
|
||||
include_public: includePublic ? 'true' : 'false'
|
||||
})
|
||||
|
||||
if (offset > 0) {
|
||||
queryParams.set('offset', offset.toString())
|
||||
}
|
||||
|
||||
const data = await handleAssetRequest(
|
||||
`${ASSETS_ENDPOINT}?${queryParams.toString()}`,
|
||||
{ includeTags: [tag], limit, offset, includePublic },
|
||||
`assets for tag ${tag}`
|
||||
)
|
||||
|
||||
|
||||
@@ -12,9 +12,9 @@ const mockGetCategoryForNodeType = vi.fn()
|
||||
|
||||
vi.mock('@/stores/assetsStore', () => ({
|
||||
useAssetsStore: () => ({
|
||||
modelAssetsByNodeType: new Map(),
|
||||
modelLoadingByNodeType: new Map(),
|
||||
modelErrorByNodeType: new Map(),
|
||||
getAssets: () => [],
|
||||
isModelLoading: () => false,
|
||||
getError: () => undefined,
|
||||
updateModelsForNodeType: mockUpdateModelsForNodeType
|
||||
})
|
||||
}))
|
||||
|
||||
@@ -8,17 +8,17 @@ vi.mock('@/platform/distribution/types', () => ({
|
||||
isCloud: true
|
||||
}))
|
||||
|
||||
const mockModelAssetsByNodeType = new Map<string, AssetItem[]>()
|
||||
const mockModelLoadingByNodeType = new Map<string, boolean>()
|
||||
const mockModelErrorByNodeType = new Map<string, Error | null>()
|
||||
const mockAssetsByKey = new Map<string, AssetItem[]>()
|
||||
const mockLoadingByKey = new Map<string, boolean>()
|
||||
const mockErrorByKey = new Map<string, Error | undefined>()
|
||||
const mockUpdateModelsForNodeType = vi.fn()
|
||||
const mockGetCategoryForNodeType = vi.fn()
|
||||
|
||||
vi.mock('@/stores/assetsStore', () => ({
|
||||
useAssetsStore: () => ({
|
||||
modelAssetsByNodeType: mockModelAssetsByNodeType,
|
||||
modelLoadingByNodeType: mockModelLoadingByNodeType,
|
||||
modelErrorByNodeType: mockModelErrorByNodeType,
|
||||
getAssets: (key: string) => mockAssetsByKey.get(key) ?? [],
|
||||
isModelLoading: (key: string) => mockLoadingByKey.get(key) ?? false,
|
||||
getError: (key: string) => mockErrorByKey.get(key),
|
||||
updateModelsForNodeType: mockUpdateModelsForNodeType
|
||||
})
|
||||
}))
|
||||
@@ -32,9 +32,9 @@ vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockModelAssetsByNodeType.clear()
|
||||
mockModelLoadingByNodeType.clear()
|
||||
mockModelErrorByNodeType.clear()
|
||||
mockAssetsByKey.clear()
|
||||
mockLoadingByKey.clear()
|
||||
mockErrorByKey.clear()
|
||||
mockGetCategoryForNodeType.mockReturnValue(undefined)
|
||||
|
||||
mockUpdateModelsForNodeType.mockImplementation(
|
||||
@@ -76,8 +76,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
||||
|
||||
mockUpdateModelsForNodeType.mockImplementation(
|
||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
|
||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
||||
mockAssetsByKey.set(_nodeType, mockAssets)
|
||||
mockLoadingByKey.set(_nodeType, false)
|
||||
return mockAssets
|
||||
}
|
||||
)
|
||||
@@ -108,9 +108,9 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
||||
|
||||
mockUpdateModelsForNodeType.mockImplementation(
|
||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||
mockModelErrorByNodeType.set(_nodeType, mockError)
|
||||
mockModelAssetsByNodeType.set(_nodeType, [])
|
||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
||||
mockErrorByKey.set(_nodeType, mockError)
|
||||
mockAssetsByKey.set(_nodeType, [])
|
||||
mockLoadingByKey.set(_nodeType, false)
|
||||
return []
|
||||
}
|
||||
)
|
||||
@@ -130,8 +130,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
||||
|
||||
mockUpdateModelsForNodeType.mockImplementation(
|
||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||
mockModelAssetsByNodeType.set(_nodeType, [])
|
||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
||||
mockAssetsByKey.set(_nodeType, [])
|
||||
mockLoadingByKey.set(_nodeType, false)
|
||||
return []
|
||||
}
|
||||
)
|
||||
@@ -154,8 +154,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
||||
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
|
||||
mockUpdateModelsForNodeType.mockImplementation(
|
||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
|
||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
||||
mockAssetsByKey.set(_nodeType, mockAssets)
|
||||
mockLoadingByKey.set(_nodeType, false)
|
||||
return mockAssets
|
||||
}
|
||||
)
|
||||
@@ -182,8 +182,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
||||
mockGetCategoryForNodeType.mockReturnValue('loras')
|
||||
mockUpdateModelsForNodeType.mockImplementation(
|
||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
|
||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
||||
mockAssetsByKey.set(_nodeType, mockAssets)
|
||||
mockLoadingByKey.set(_nodeType, false)
|
||||
return mockAssets
|
||||
}
|
||||
)
|
||||
@@ -209,8 +209,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
||||
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
|
||||
mockUpdateModelsForNodeType.mockImplementation(
|
||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
|
||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
||||
mockAssetsByKey.set(_nodeType, mockAssets)
|
||||
mockLoadingByKey.set(_nodeType, false)
|
||||
return mockAssets
|
||||
}
|
||||
)
|
||||
|
||||
@@ -34,23 +34,17 @@ export function useAssetWidgetData(
|
||||
|
||||
const assets = computed<AssetItem[]>(() => {
|
||||
const resolvedType = toValue(nodeType)
|
||||
return resolvedType
|
||||
? (assetsStore.modelAssetsByNodeType.get(resolvedType) ?? [])
|
||||
: []
|
||||
return resolvedType ? (assetsStore.getAssets(resolvedType) ?? []) : []
|
||||
})
|
||||
|
||||
const isLoading = computed(() => {
|
||||
const resolvedType = toValue(nodeType)
|
||||
return resolvedType
|
||||
? (assetsStore.modelLoadingByNodeType.get(resolvedType) ?? false)
|
||||
: false
|
||||
return resolvedType ? assetsStore.isModelLoading(resolvedType) : false
|
||||
})
|
||||
|
||||
const error = computed<Error | null>(() => {
|
||||
const resolvedType = toValue(nodeType)
|
||||
return resolvedType
|
||||
? (assetsStore.modelErrorByNodeType.get(resolvedType) ?? null)
|
||||
: null
|
||||
return resolvedType ? (assetsStore.getError(resolvedType) ?? null) : null
|
||||
})
|
||||
|
||||
const dropdownItems = computed<DropdownItem[]>(() => {
|
||||
@@ -71,7 +65,8 @@ export function useAssetWidgetData(
|
||||
return
|
||||
}
|
||||
|
||||
const hasData = assetsStore.modelAssetsByNodeType.has(currentNodeType)
|
||||
const existingAssets = assetsStore.getAssets(currentNodeType) ?? []
|
||||
const hasData = existingAssets.length > 0
|
||||
|
||||
if (!hasData) {
|
||||
await assetsStore.updateModelsForNodeType(currentNodeType)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick, watch } from 'vue'
|
||||
|
||||
import { useAssetsStore } from '@/stores/assetsStore'
|
||||
import { api } from '@/scripts/api'
|
||||
import type { JobListItem } from '@/platform/remote/comfyui/jobs/jobTypes'
|
||||
import { assetService } from '@/platform/assets/services/assetService'
|
||||
|
||||
// Mock the api module
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
@@ -20,13 +23,17 @@ vi.mock('@/scripts/api', () => ({
|
||||
// Mock the asset service
|
||||
vi.mock('@/platform/assets/services/assetService', () => ({
|
||||
assetService: {
|
||||
getAssetsByTag: vi.fn()
|
||||
getAssetsByTag: vi.fn(),
|
||||
getAssetsForNodeType: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
// Mock distribution type
|
||||
// Mock distribution type - hoisted so it can be changed per test
|
||||
const mockIsCloud = vi.hoisted(() => ({ value: false }))
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
isCloud: false
|
||||
get isCloud() {
|
||||
return mockIsCloud.value
|
||||
}
|
||||
}))
|
||||
|
||||
// Mock TaskItemImpl
|
||||
@@ -115,7 +122,7 @@ describe('assetsStore - Refactored (Option A)', () => {
|
||||
})
|
||||
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
store = useAssetsStore()
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
@@ -453,3 +460,158 @@ describe('assetsStore - Refactored (Option A)', () => {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('assetsStore - Model Assets Cache (Cloud)', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
mockIsCloud.value = true
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
mockIsCloud.value = false
|
||||
})
|
||||
|
||||
const createMockAsset = (id: string) => ({
|
||||
id,
|
||||
name: `asset-${id}`,
|
||||
size: 100,
|
||||
created_at: new Date().toISOString(),
|
||||
tags: ['models'],
|
||||
preview_url: `http://test.com/${id}`
|
||||
})
|
||||
|
||||
describe('getAssets cache invalidation', () => {
|
||||
it('should invalidate cache before mutating assets during batch loading', async () => {
|
||||
const store = useAssetsStore()
|
||||
const nodeType = 'CheckpointLoaderSimple'
|
||||
|
||||
const firstBatch = Array.from({ length: 500 }, (_, i) =>
|
||||
createMockAsset(`asset-${i}`)
|
||||
)
|
||||
const secondBatch = Array.from({ length: 100 }, (_, i) =>
|
||||
createMockAsset(`asset-${500 + i}`)
|
||||
)
|
||||
|
||||
let callCount = 0
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockImplementation(
|
||||
async () => {
|
||||
callCount++
|
||||
return callCount === 1 ? firstBatch : secondBatch
|
||||
}
|
||||
)
|
||||
|
||||
await store.updateModelsForNodeType(nodeType)
|
||||
|
||||
// Wait for background batch loading to complete
|
||||
await vi.waitFor(() => {
|
||||
expect(
|
||||
vi.mocked(assetService.getAssetsForNodeType)
|
||||
).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
const assets = store.getAssets(nodeType)
|
||||
expect(assets).toHaveLength(600)
|
||||
})
|
||||
|
||||
it('should not return stale cached array after background batch completes', async () => {
|
||||
const store = useAssetsStore()
|
||||
const nodeType = 'LoraLoader'
|
||||
|
||||
// First batch must be exactly MODEL_BATCH_SIZE (500) to trigger hasMore
|
||||
const firstBatch = Array.from({ length: 500 }, (_, i) =>
|
||||
createMockAsset(`first-${i}`)
|
||||
)
|
||||
const secondBatch = [createMockAsset('new-asset')]
|
||||
|
||||
let callCount = 0
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockImplementation(
|
||||
async () => {
|
||||
callCount++
|
||||
return callCount === 1 ? firstBatch : secondBatch
|
||||
}
|
||||
)
|
||||
|
||||
await store.updateModelsForNodeType(nodeType)
|
||||
|
||||
// Wait for background batch loading to complete
|
||||
await vi.waitFor(() => {
|
||||
expect(
|
||||
vi.mocked(assetService.getAssetsForNodeType)
|
||||
).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
|
||||
const assets = store.getAssets(nodeType)
|
||||
expect(assets).toHaveLength(501)
|
||||
expect(assets.map((a) => a.id)).toContain('new-asset')
|
||||
})
|
||||
|
||||
it('should return cached array on subsequent getAssets calls', () => {
|
||||
const store = useAssetsStore()
|
||||
const nodeType = 'TestLoader'
|
||||
|
||||
const firstCall = store.getAssets(nodeType)
|
||||
const secondCall = store.getAssets(nodeType)
|
||||
|
||||
expect(secondCall).toBe(firstCall)
|
||||
})
|
||||
})
|
||||
|
||||
describe('concurrent request handling', () => {
|
||||
it('should discard stale request when newer request starts', async () => {
|
||||
const store = useAssetsStore()
|
||||
const nodeType = 'CheckpointLoaderSimple'
|
||||
const firstBatch = Array.from({ length: 5 }, (_, i) =>
|
||||
createMockAsset(`first-${i}`)
|
||||
)
|
||||
const secondBatch = Array.from({ length: 10 }, (_, i) =>
|
||||
createMockAsset(`second-${i}`)
|
||||
)
|
||||
|
||||
let resolveFirst: (value: ReturnType<typeof createMockAsset>[]) => void
|
||||
const firstPromise = new Promise<ReturnType<typeof createMockAsset>[]>(
|
||||
(resolve) => {
|
||||
resolveFirst = resolve
|
||||
}
|
||||
)
|
||||
let callCount = 0
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockImplementation(
|
||||
async () => {
|
||||
callCount++
|
||||
return callCount === 1 ? firstPromise : secondBatch
|
||||
}
|
||||
)
|
||||
|
||||
const firstRequest = store.updateModelsForNodeType(nodeType)
|
||||
const secondRequest = store.updateModelsForNodeType(nodeType)
|
||||
resolveFirst!(firstBatch)
|
||||
await Promise.all([firstRequest, secondRequest])
|
||||
|
||||
expect(store.getAssets(nodeType)).toHaveLength(10)
|
||||
expect(
|
||||
store.getAssets(nodeType).every((a) => a.id.startsWith('second-'))
|
||||
).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('shallowReactive state reactivity', () => {
|
||||
it('should trigger reactivity on isModelLoading change', async () => {
|
||||
const store = useAssetsStore()
|
||||
const nodeType = 'CheckpointLoaderSimple'
|
||||
|
||||
const loadingStates: boolean[] = []
|
||||
watch(
|
||||
() => store.isModelLoading(nodeType),
|
||||
(val) => loadingStates.push(val),
|
||||
{ immediate: true }
|
||||
)
|
||||
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue([])
|
||||
await store.updateModelsForNodeType(nodeType)
|
||||
await nextTick()
|
||||
|
||||
expect(loadingStates).toContain(true)
|
||||
expect(loadingStates).toContain(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import { useAsyncState, whenever } from '@vueuse/core'
|
||||
import { isEqual } from 'es-toolkit'
|
||||
import { defineStore } from 'pinia'
|
||||
import { computed, shallowReactive, ref } from 'vue'
|
||||
import { computed, reactive, ref, shallowReactive } from 'vue'
|
||||
import {
|
||||
mapInputFileToAssetItem,
|
||||
mapTaskOutputToAssetItem
|
||||
} from '@/platform/assets/composables/media/assetMappers'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { assetService } from '@/platform/assets/services/assetService'
|
||||
import type { PaginationOptions } from '@/platform/assets/services/assetService'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import type { JobListItem } from '@/platform/remote/comfyui/jobs/jobTypes'
|
||||
import { api } from '@/scripts/api'
|
||||
@@ -251,6 +251,16 @@ export const useAssetsStore = defineStore('assets', () => {
|
||||
return inputAssetsByFilename.value.get(filename)?.name ?? filename
|
||||
}
|
||||
|
||||
const MODEL_BATCH_SIZE = 500
|
||||
|
||||
interface ModelPaginationState {
|
||||
assets: Map<string, AssetItem>
|
||||
offset: number
|
||||
hasMore: boolean
|
||||
isLoading: boolean
|
||||
error?: Error
|
||||
}
|
||||
|
||||
/**
|
||||
* Model assets cached by node type (e.g., 'CheckpointLoaderSimple', 'LoraLoader')
|
||||
* Used by multiple loader nodes to avoid duplicate fetches
|
||||
@@ -258,109 +268,183 @@ export const useAssetsStore = defineStore('assets', () => {
|
||||
*/
|
||||
const getModelState = () => {
|
||||
if (isCloud) {
|
||||
const modelAssetsByNodeType = shallowReactive(
|
||||
new Map<string, AssetItem[]>()
|
||||
)
|
||||
const modelLoadingByNodeType = shallowReactive(new Map<string, boolean>())
|
||||
const modelErrorByNodeType = shallowReactive(
|
||||
new Map<string, Error | null>()
|
||||
)
|
||||
const modelStateByKey = ref(new Map<string, ModelPaginationState>())
|
||||
|
||||
const stateByNodeType = shallowReactive(
|
||||
new Map<string, ReturnType<typeof useAsyncState<AssetItem[]>>>()
|
||||
)
|
||||
const assetsArrayCache = new Map<
|
||||
string,
|
||||
{ source: Map<string, AssetItem>; array: AssetItem[] }
|
||||
>()
|
||||
|
||||
const pendingRequestByKey = new Map<string, ModelPaginationState>()
|
||||
|
||||
function createState(): ModelPaginationState {
|
||||
return reactive({
|
||||
assets: new Map(),
|
||||
offset: 0,
|
||||
hasMore: true,
|
||||
isLoading: false
|
||||
})
|
||||
}
|
||||
|
||||
function isStale(key: string, state: ModelPaginationState): boolean {
|
||||
const committed = modelStateByKey.value.get(key)
|
||||
const pending = pendingRequestByKey.get(key)
|
||||
return committed !== state && pending !== state
|
||||
}
|
||||
|
||||
const EMPTY_ASSETS: AssetItem[] = []
|
||||
|
||||
function getAssets(key: string): AssetItem[] {
|
||||
const state = modelStateByKey.value.get(key)
|
||||
const assetsMap = state?.assets
|
||||
if (!assetsMap) return EMPTY_ASSETS
|
||||
|
||||
const cached = assetsArrayCache.get(key)
|
||||
if (cached && cached.source === assetsMap) {
|
||||
return cached.array
|
||||
}
|
||||
|
||||
const array = Array.from(assetsMap.values())
|
||||
assetsArrayCache.set(key, { source: assetsMap, array })
|
||||
return array
|
||||
}
|
||||
|
||||
function isLoading(key: string): boolean {
|
||||
return modelStateByKey.value.get(key)?.isLoading ?? false
|
||||
}
|
||||
|
||||
function getError(key: string): Error | undefined {
|
||||
return modelStateByKey.value.get(key)?.error
|
||||
}
|
||||
|
||||
function hasMore(key: string): boolean {
|
||||
return modelStateByKey.value.get(key)?.hasMore ?? false
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal helper to fetch and cache assets with a given key and fetcher
|
||||
* Internal helper to fetch and cache assets with a given key and fetcher.
|
||||
* Loads first batch immediately, then progressively loads remaining batches.
|
||||
* Keeps existing data visible until new data is successfully fetched.
|
||||
*/
|
||||
async function updateModelsForKey(
|
||||
key: string,
|
||||
fetcher: () => Promise<AssetItem[]>
|
||||
): Promise<AssetItem[]> {
|
||||
if (!stateByNodeType.has(key)) {
|
||||
stateByNodeType.set(
|
||||
key,
|
||||
useAsyncState(fetcher, [], {
|
||||
immediate: false,
|
||||
resetOnExecute: false,
|
||||
onError: (err) => {
|
||||
console.error(`Error fetching model assets for ${key}:`, err)
|
||||
fetcher: (options: PaginationOptions) => Promise<AssetItem[]>
|
||||
): Promise<void> {
|
||||
const state = createState()
|
||||
state.isLoading = true
|
||||
|
||||
const hasExistingData = modelStateByKey.value.has(key)
|
||||
if (hasExistingData) {
|
||||
pendingRequestByKey.set(key, state)
|
||||
} else {
|
||||
modelStateByKey.value.set(key, state)
|
||||
}
|
||||
|
||||
async function loadBatches(): Promise<void> {
|
||||
while (state.hasMore) {
|
||||
try {
|
||||
const newAssets = await fetcher({
|
||||
limit: MODEL_BATCH_SIZE,
|
||||
offset: state.offset
|
||||
})
|
||||
|
||||
if (isStale(key, state)) return
|
||||
|
||||
const isFirstBatch = state.offset === 0
|
||||
if (isFirstBatch) {
|
||||
assetsArrayCache.delete(key)
|
||||
if (hasExistingData) {
|
||||
pendingRequestByKey.delete(key)
|
||||
modelStateByKey.value.set(key, state)
|
||||
}
|
||||
state.assets = new Map(newAssets.map((a) => [a.id, a]))
|
||||
} else {
|
||||
const assetsToAdd = newAssets.filter(
|
||||
(a) => !state.assets.has(a.id)
|
||||
)
|
||||
if (assetsToAdd.length > 0) {
|
||||
assetsArrayCache.delete(key)
|
||||
for (const asset of assetsToAdd) {
|
||||
state.assets.set(asset.id, asset)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
state.offset += newAssets.length
|
||||
state.hasMore = newAssets.length === MODEL_BATCH_SIZE
|
||||
|
||||
if (isFirstBatch) {
|
||||
state.isLoading = false
|
||||
}
|
||||
|
||||
if (state.hasMore) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 50))
|
||||
}
|
||||
} catch (err) {
|
||||
if (isStale(key, state)) return
|
||||
state.error = err instanceof Error ? err : new Error(String(err))
|
||||
state.hasMore = false
|
||||
console.error(`Error loading batch for ${key}:`, err)
|
||||
if (state.offset === 0) {
|
||||
state.isLoading = false
|
||||
pendingRequestByKey.delete(key)
|
||||
// TODO: Add toast indicator for first-batch load failures
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const state = stateByNodeType.get(key)!
|
||||
|
||||
modelLoadingByNodeType.set(key, true)
|
||||
modelErrorByNodeType.set(key, null)
|
||||
|
||||
try {
|
||||
await state.execute()
|
||||
} finally {
|
||||
modelLoadingByNodeType.set(key, state.isLoading.value)
|
||||
}
|
||||
|
||||
const assets = state.state.value
|
||||
const existingAssets = modelAssetsByNodeType.get(key)
|
||||
|
||||
if (!isEqual(existingAssets, assets)) {
|
||||
modelAssetsByNodeType.set(key, assets)
|
||||
}
|
||||
|
||||
modelErrorByNodeType.set(
|
||||
key,
|
||||
state.error.value instanceof Error ? state.error.value : null
|
||||
)
|
||||
|
||||
return assets
|
||||
await loadBatches()
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch and cache model assets for a specific node type
|
||||
* @param nodeType The node type to fetch assets for (e.g., 'CheckpointLoaderSimple')
|
||||
* @returns Promise resolving to the fetched assets
|
||||
*/
|
||||
async function updateModelsForNodeType(
|
||||
nodeType: string
|
||||
): Promise<AssetItem[]> {
|
||||
return updateModelsForKey(nodeType, () =>
|
||||
assetService.getAssetsForNodeType(nodeType)
|
||||
async function updateModelsForNodeType(nodeType: string): Promise<void> {
|
||||
await updateModelsForKey(nodeType, (opts) =>
|
||||
assetService.getAssetsForNodeType(nodeType, opts)
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetch and cache model assets for a specific tag
|
||||
* @param tag The tag to fetch assets for (e.g., 'models')
|
||||
* @returns Promise resolving to the fetched assets
|
||||
*/
|
||||
async function updateModelsForTag(tag: string): Promise<AssetItem[]> {
|
||||
async function updateModelsForTag(tag: string): Promise<void> {
|
||||
const key = `tag:${tag}`
|
||||
return updateModelsForKey(key, () => assetService.getAssetsByTag(tag))
|
||||
await updateModelsForKey(key, (opts) =>
|
||||
assetService.getAssetsByTag(tag, true, opts)
|
||||
)
|
||||
}
|
||||
|
||||
return {
|
||||
modelAssetsByNodeType,
|
||||
modelLoadingByNodeType,
|
||||
modelErrorByNodeType,
|
||||
getAssets,
|
||||
isLoading,
|
||||
getError,
|
||||
hasMore,
|
||||
updateModelsForNodeType,
|
||||
updateModelsForTag
|
||||
}
|
||||
}
|
||||
|
||||
const emptyAssets: AssetItem[] = []
|
||||
return {
|
||||
modelAssetsByNodeType: shallowReactive(new Map<string, AssetItem[]>()),
|
||||
modelLoadingByNodeType: shallowReactive(new Map<string, boolean>()),
|
||||
modelErrorByNodeType: shallowReactive(new Map<string, Error | null>()),
|
||||
updateModelsForNodeType: async () => [],
|
||||
updateModelsForTag: async () => []
|
||||
getAssets: () => emptyAssets,
|
||||
isLoading: () => false,
|
||||
getError: () => undefined,
|
||||
hasMore: () => false,
|
||||
updateModelsForNodeType: async () => {},
|
||||
updateModelsForTag: async () => {}
|
||||
}
|
||||
}
|
||||
|
||||
const {
|
||||
modelAssetsByNodeType,
|
||||
modelLoadingByNodeType,
|
||||
modelErrorByNodeType,
|
||||
getAssets,
|
||||
isLoading: isModelLoading,
|
||||
getError,
|
||||
hasMore,
|
||||
updateModelsForNodeType,
|
||||
updateModelsForTag
|
||||
} = getModelState()
|
||||
@@ -422,10 +506,13 @@ export const useAssetsStore = defineStore('assets', () => {
|
||||
inputAssetsByFilename,
|
||||
getInputName,
|
||||
|
||||
// Model assets
|
||||
modelAssetsByNodeType,
|
||||
modelLoadingByNodeType,
|
||||
modelErrorByNodeType,
|
||||
// Model assets - accessors
|
||||
getAssets,
|
||||
isModelLoading,
|
||||
getError,
|
||||
hasMore,
|
||||
|
||||
// Model assets - actions
|
||||
updateModelsForNodeType,
|
||||
updateModelsForTag
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user