mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-04-30 03:01:54 +00:00
[backport cloud/1.37] feat: implement progressive pagination for Asset Browser model assets (#8240)
Backport of #8212 to cloud/1.37 ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-8240-backport-cloud-1-37-feat-implement-progressive-pagination-for-Asset-Browser-model-asse-2f06d73d365081b199a0dd6bcc242bba) by [Unito](https://www.unito.io) Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
@@ -6,6 +6,9 @@ import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vu
|
|||||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||||
import { useAssetsStore } from '@/stores/assetsStore'
|
import { useAssetsStore } from '@/stores/assetsStore'
|
||||||
|
|
||||||
|
const mockAssetsByKey = vi.hoisted(() => new Map<string, AssetItem[]>())
|
||||||
|
const mockLoadingByKey = vi.hoisted(() => new Map<string, boolean>())
|
||||||
|
|
||||||
vi.mock('@/i18n', () => ({
|
vi.mock('@/i18n', () => ({
|
||||||
t: (key: string, params?: Record<string, string>) =>
|
t: (key: string, params?: Record<string, string>) =>
|
||||||
params ? `${key}:${JSON.stringify(params)}` : key,
|
params ? `${key}:${JSON.stringify(params)}` : key,
|
||||||
@@ -13,13 +16,20 @@ vi.mock('@/i18n', () => ({
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@/stores/assetsStore', () => {
|
vi.mock('@/stores/assetsStore', () => {
|
||||||
const store = {
|
const getAssets = vi.fn((key: string) => mockAssetsByKey.get(key) ?? [])
|
||||||
modelAssetsByNodeType: new Map<string, AssetItem[]>(),
|
const isModelLoading = vi.fn(
|
||||||
modelLoadingByNodeType: new Map<string, boolean>(),
|
(key: string) => mockLoadingByKey.get(key) ?? false
|
||||||
updateModelsForNodeType: vi.fn(),
|
)
|
||||||
updateModelsForTag: vi.fn()
|
const updateModelsForNodeType = vi.fn()
|
||||||
|
const updateModelsForTag = vi.fn()
|
||||||
|
return {
|
||||||
|
useAssetsStore: () => ({
|
||||||
|
getAssets,
|
||||||
|
isModelLoading,
|
||||||
|
updateModelsForNodeType,
|
||||||
|
updateModelsForTag
|
||||||
|
})
|
||||||
}
|
}
|
||||||
return { useAssetsStore: () => store }
|
|
||||||
})
|
})
|
||||||
|
|
||||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||||
@@ -183,12 +193,10 @@ describe('AssetBrowserModal', () => {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
const mockStore = useAssetsStore()
|
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.resetAllMocks()
|
vi.resetAllMocks()
|
||||||
mockStore.modelAssetsByNodeType.clear()
|
mockAssetsByKey.clear()
|
||||||
mockStore.modelLoadingByNodeType.clear()
|
mockLoadingByKey.clear()
|
||||||
})
|
})
|
||||||
|
|
||||||
describe('Integration with useAssetBrowser', () => {
|
describe('Integration with useAssetBrowser', () => {
|
||||||
@@ -197,7 +205,7 @@ describe('AssetBrowserModal', () => {
|
|||||||
createTestAsset('asset1', 'Model A', 'checkpoints'),
|
createTestAsset('asset1', 'Model A', 'checkpoints'),
|
||||||
createTestAsset('asset2', 'Model B', 'loras')
|
createTestAsset('asset2', 'Model B', 'loras')
|
||||||
]
|
]
|
||||||
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
|
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
|
||||||
|
|
||||||
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
||||||
await flushPromises()
|
await flushPromises()
|
||||||
@@ -214,7 +222,7 @@ describe('AssetBrowserModal', () => {
|
|||||||
createTestAsset('c1', 'model.safetensors', 'checkpoints'),
|
createTestAsset('c1', 'model.safetensors', 'checkpoints'),
|
||||||
createTestAsset('l1', 'lora.pt', 'loras')
|
createTestAsset('l1', 'lora.pt', 'loras')
|
||||||
]
|
]
|
||||||
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
|
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
|
||||||
|
|
||||||
const wrapper = createWrapper({
|
const wrapper = createWrapper({
|
||||||
nodeType: 'CheckpointLoaderSimple',
|
nodeType: 'CheckpointLoaderSimple',
|
||||||
@@ -231,17 +239,18 @@ describe('AssetBrowserModal', () => {
|
|||||||
|
|
||||||
describe('Data fetching', () => {
|
describe('Data fetching', () => {
|
||||||
it('triggers store refresh for node type on mount', async () => {
|
it('triggers store refresh for node type on mount', async () => {
|
||||||
|
const store = useAssetsStore()
|
||||||
createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
||||||
await flushPromises()
|
await flushPromises()
|
||||||
|
|
||||||
expect(mockStore.updateModelsForNodeType).toHaveBeenCalledWith(
|
expect(store.updateModelsForNodeType).toHaveBeenCalledWith(
|
||||||
'CheckpointLoaderSimple'
|
'CheckpointLoaderSimple'
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('displays cached assets immediately from store', async () => {
|
it('displays cached assets immediately from store', async () => {
|
||||||
const assets = [createTestAsset('asset1', 'Cached Model', 'checkpoints')]
|
const assets = [createTestAsset('asset1', 'Cached Model', 'checkpoints')]
|
||||||
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
|
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
|
||||||
|
|
||||||
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
||||||
|
|
||||||
@@ -253,15 +262,16 @@ describe('AssetBrowserModal', () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
it('triggers store refresh for asset type (tag) on mount', async () => {
|
it('triggers store refresh for asset type (tag) on mount', async () => {
|
||||||
|
const store = useAssetsStore()
|
||||||
createWrapper({ assetType: 'models' })
|
createWrapper({ assetType: 'models' })
|
||||||
await flushPromises()
|
await flushPromises()
|
||||||
|
|
||||||
expect(mockStore.updateModelsForTag).toHaveBeenCalledWith('models')
|
expect(store.updateModelsForTag).toHaveBeenCalledWith('models')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('uses tag: prefix for cache key when assetType is provided', async () => {
|
it('uses tag: prefix for cache key when assetType is provided', async () => {
|
||||||
const assets = [createTestAsset('asset1', 'Tagged Model', 'models')]
|
const assets = [createTestAsset('asset1', 'Tagged Model', 'models')]
|
||||||
mockStore.modelAssetsByNodeType.set('tag:models', assets)
|
mockAssetsByKey.set('tag:models', assets)
|
||||||
|
|
||||||
const wrapper = createWrapper({ assetType: 'models' })
|
const wrapper = createWrapper({ assetType: 'models' })
|
||||||
await flushPromises()
|
await flushPromises()
|
||||||
@@ -277,7 +287,7 @@ describe('AssetBrowserModal', () => {
|
|||||||
describe('Asset Selection', () => {
|
describe('Asset Selection', () => {
|
||||||
it('emits asset-select event when asset is selected', async () => {
|
it('emits asset-select event when asset is selected', async () => {
|
||||||
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
|
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
|
||||||
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
|
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
|
||||||
|
|
||||||
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
||||||
await flushPromises()
|
await flushPromises()
|
||||||
@@ -290,7 +300,7 @@ describe('AssetBrowserModal', () => {
|
|||||||
|
|
||||||
it('executes onSelect callback when provided', async () => {
|
it('executes onSelect callback when provided', async () => {
|
||||||
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
|
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
|
||||||
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
|
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
|
||||||
|
|
||||||
const onSelect = vi.fn()
|
const onSelect = vi.fn()
|
||||||
const wrapper = createWrapper({
|
const wrapper = createWrapper({
|
||||||
@@ -333,7 +343,7 @@ describe('AssetBrowserModal', () => {
|
|||||||
createTestAsset('asset1', 'Model A', 'checkpoints'),
|
createTestAsset('asset1', 'Model A', 'checkpoints'),
|
||||||
createTestAsset('asset2', 'Model B', 'loras')
|
createTestAsset('asset2', 'Model B', 'loras')
|
||||||
]
|
]
|
||||||
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
|
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
|
||||||
|
|
||||||
const wrapper = createWrapper({
|
const wrapper = createWrapper({
|
||||||
nodeType: 'CheckpointLoaderSimple',
|
nodeType: 'CheckpointLoaderSimple',
|
||||||
@@ -366,7 +376,7 @@ describe('AssetBrowserModal', () => {
|
|||||||
|
|
||||||
it('passes computed contentTitle to BaseModalLayout when no title prop', async () => {
|
it('passes computed contentTitle to BaseModalLayout when no title prop', async () => {
|
||||||
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
|
const assets = [createTestAsset('asset1', 'Model A', 'checkpoints')]
|
||||||
mockStore.modelAssetsByNodeType.set('CheckpointLoaderSimple', assets)
|
mockAssetsByKey.set('CheckpointLoaderSimple', assets)
|
||||||
|
|
||||||
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
const wrapper = createWrapper({ nodeType: 'CheckpointLoaderSimple' })
|
||||||
await flushPromises()
|
await flushPromises()
|
||||||
|
|||||||
@@ -112,27 +112,21 @@ const cacheKey = computed(() => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Read directly from store cache - reactive to any store updates
|
// Read directly from store cache - reactive to any store updates
|
||||||
const fetchedAssets = computed(
|
const fetchedAssets = computed(() => assetStore.getAssets(cacheKey.value))
|
||||||
() => assetStore.modelAssetsByNodeType.get(cacheKey.value) ?? []
|
|
||||||
)
|
|
||||||
|
|
||||||
const isStoreLoading = computed(
|
const isStoreLoading = computed(() => assetStore.isModelLoading(cacheKey.value))
|
||||||
() => assetStore.modelLoadingByNodeType.get(cacheKey.value) ?? false
|
|
||||||
)
|
|
||||||
|
|
||||||
// Only show loading spinner when loading AND no cached data
|
// Only show loading spinner when loading AND no cached data
|
||||||
const isLoading = computed(
|
const isLoading = computed(
|
||||||
() => isStoreLoading.value && fetchedAssets.value.length === 0
|
() => isStoreLoading.value && fetchedAssets.value.length === 0
|
||||||
)
|
)
|
||||||
|
|
||||||
async function refreshAssets(): Promise<AssetItem[]> {
|
async function refreshAssets(): Promise<void> {
|
||||||
if (props.nodeType) {
|
if (props.nodeType) {
|
||||||
return await assetStore.updateModelsForNodeType(props.nodeType)
|
await assetStore.updateModelsForNodeType(props.nodeType)
|
||||||
|
} else if (props.assetType) {
|
||||||
|
await assetStore.updateModelsForTag(props.assetType)
|
||||||
}
|
}
|
||||||
if (props.assetType) {
|
|
||||||
return await assetStore.updateModelsForTag(props.assetType)
|
|
||||||
}
|
|
||||||
return []
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Trigger background refresh on mount
|
// Trigger background refresh on mount
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ describe('assetService', () => {
|
|||||||
const result = await assetService.getAssetModels('checkpoints')
|
const result = await assetService.getAssetModels('checkpoints')
|
||||||
|
|
||||||
expect(api.fetchApi).toHaveBeenCalledWith(
|
expect(api.fetchApi).toHaveBeenCalledWith(
|
||||||
'/assets?include_tags=models,checkpoints&limit=500'
|
'/assets?include_tags=models%2Ccheckpoints&limit=500'
|
||||||
)
|
)
|
||||||
expect(result).toEqual([
|
expect(result).toEqual([
|
||||||
expect.objectContaining({ name: 'valid.safetensors', pathIndex: 0 })
|
expect.objectContaining({ name: 'valid.safetensors', pathIndex: 0 })
|
||||||
@@ -231,9 +231,9 @@ describe('assetService', () => {
|
|||||||
)
|
)
|
||||||
expect(result).toEqual(testAssets)
|
expect(result).toEqual(testAssets)
|
||||||
|
|
||||||
// Verify API call includes correct category
|
// Verify API call includes correct category (comma is URL-encoded by URLSearchParams)
|
||||||
expect(api.fetchApi).toHaveBeenCalledWith(
|
expect(api.fetchApi).toHaveBeenCalledWith(
|
||||||
'/assets?include_tags=models,checkpoints&limit=500'
|
'/assets?include_tags=models%2Ccheckpoints&limit=500'
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -400,7 +400,7 @@ describe('assetService', () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
expect(api.fetchApi).toHaveBeenCalledWith(
|
expect(api.fetchApi).toHaveBeenCalledWith(
|
||||||
'/assets?include_tags=models&limit=500&include_public=true&offset=50'
|
'/assets?include_tags=models&limit=500&offset=50&include_public=true'
|
||||||
)
|
)
|
||||||
expect(result).toEqual(testAssets)
|
expect(result).toEqual(testAssets)
|
||||||
})
|
})
|
||||||
@@ -415,7 +415,7 @@ describe('assetService', () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
expect(api.fetchApi).toHaveBeenCalledWith(
|
expect(api.fetchApi).toHaveBeenCalledWith(
|
||||||
'/assets?include_tags=input&limit=100&include_public=false&offset=25'
|
'/assets?include_tags=input&limit=100&offset=25&include_public=false'
|
||||||
)
|
)
|
||||||
expect(result).toEqual(testAssets)
|
expect(result).toEqual(testAssets)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import { fromZodError } from 'zod-validation-error'
|
import { fromZodError } from 'zod-validation-error'
|
||||||
|
|
||||||
import { st } from '@/i18n'
|
import { st } from '@/i18n'
|
||||||
|
|
||||||
import {
|
import {
|
||||||
assetItemSchema,
|
assetItemSchema,
|
||||||
assetResponseSchema,
|
assetResponseSchema,
|
||||||
@@ -17,6 +18,16 @@ import type {
|
|||||||
import { api } from '@/scripts/api'
|
import { api } from '@/scripts/api'
|
||||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||||
|
|
||||||
|
export interface PaginationOptions {
|
||||||
|
limit?: number
|
||||||
|
offset?: number
|
||||||
|
}
|
||||||
|
|
||||||
|
interface AssetRequestOptions extends PaginationOptions {
|
||||||
|
includeTags: string[]
|
||||||
|
includePublic?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Maps CivitAI validation error codes to localized error messages
|
* Maps CivitAI validation error codes to localized error messages
|
||||||
*/
|
*/
|
||||||
@@ -77,9 +88,27 @@ function createAssetService() {
|
|||||||
* Handles API response with consistent error handling and Zod validation
|
* Handles API response with consistent error handling and Zod validation
|
||||||
*/
|
*/
|
||||||
async function handleAssetRequest(
|
async function handleAssetRequest(
|
||||||
url: string,
|
options: AssetRequestOptions,
|
||||||
context: string
|
context: string
|
||||||
): Promise<AssetResponse> {
|
): Promise<AssetResponse> {
|
||||||
|
const {
|
||||||
|
includeTags,
|
||||||
|
limit = DEFAULT_LIMIT,
|
||||||
|
offset,
|
||||||
|
includePublic
|
||||||
|
} = options
|
||||||
|
const queryParams = new URLSearchParams({
|
||||||
|
include_tags: includeTags.join(','),
|
||||||
|
limit: limit.toString()
|
||||||
|
})
|
||||||
|
if (offset !== undefined && offset > 0) {
|
||||||
|
queryParams.set('offset', offset.toString())
|
||||||
|
}
|
||||||
|
if (includePublic !== undefined) {
|
||||||
|
queryParams.set('include_public', includePublic ? 'true' : 'false')
|
||||||
|
}
|
||||||
|
|
||||||
|
const url = `${ASSETS_ENDPOINT}?${queryParams.toString()}`
|
||||||
const res = await api.fetchApi(url)
|
const res = await api.fetchApi(url)
|
||||||
if (!res.ok) {
|
if (!res.ok) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
@@ -101,7 +130,7 @@ function createAssetService() {
|
|||||||
*/
|
*/
|
||||||
async function getAssetModelFolders(): Promise<ModelFolder[]> {
|
async function getAssetModelFolders(): Promise<ModelFolder[]> {
|
||||||
const data = await handleAssetRequest(
|
const data = await handleAssetRequest(
|
||||||
`${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG}&limit=${DEFAULT_LIMIT}`,
|
{ includeTags: [MODELS_TAG] },
|
||||||
'model folders'
|
'model folders'
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -130,7 +159,7 @@ function createAssetService() {
|
|||||||
*/
|
*/
|
||||||
async function getAssetModels(folder: string): Promise<ModelFile[]> {
|
async function getAssetModels(folder: string): Promise<ModelFile[]> {
|
||||||
const data = await handleAssetRequest(
|
const data = await handleAssetRequest(
|
||||||
`${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${folder}&limit=${DEFAULT_LIMIT}`,
|
{ includeTags: [MODELS_TAG, folder] },
|
||||||
`models for ${folder}`
|
`models for ${folder}`
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -169,9 +198,15 @@ function createAssetService() {
|
|||||||
* and fetching all assets with that category tag
|
* and fetching all assets with that category tag
|
||||||
*
|
*
|
||||||
* @param nodeType - The ComfyUI node type (e.g., 'CheckpointLoaderSimple')
|
* @param nodeType - The ComfyUI node type (e.g., 'CheckpointLoaderSimple')
|
||||||
|
* @param options - Pagination options
|
||||||
|
* @param options.limit - Maximum number of assets to return (default: 500)
|
||||||
|
* @param options.offset - Number of assets to skip (default: 0)
|
||||||
* @returns Promise<AssetItem[]> - Full asset objects with preserved metadata
|
* @returns Promise<AssetItem[]> - Full asset objects with preserved metadata
|
||||||
*/
|
*/
|
||||||
async function getAssetsForNodeType(nodeType: string): Promise<AssetItem[]> {
|
async function getAssetsForNodeType(
|
||||||
|
nodeType: string,
|
||||||
|
{ limit = DEFAULT_LIMIT, offset = 0 }: PaginationOptions = {}
|
||||||
|
): Promise<AssetItem[]> {
|
||||||
if (!nodeType || typeof nodeType !== 'string') {
|
if (!nodeType || typeof nodeType !== 'string') {
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
@@ -186,7 +221,7 @@ function createAssetService() {
|
|||||||
|
|
||||||
// Fetch assets for this category using same API pattern as getAssetModels
|
// Fetch assets for this category using same API pattern as getAssetModels
|
||||||
const data = await handleAssetRequest(
|
const data = await handleAssetRequest(
|
||||||
`${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${category}&limit=${DEFAULT_LIMIT}`,
|
{ includeTags: [MODELS_TAG, category], limit, offset },
|
||||||
`assets for ${nodeType}`
|
`assets for ${nodeType}`
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -242,23 +277,10 @@ function createAssetService() {
|
|||||||
async function getAssetsByTag(
|
async function getAssetsByTag(
|
||||||
tag: string,
|
tag: string,
|
||||||
includePublic: boolean = true,
|
includePublic: boolean = true,
|
||||||
{
|
{ limit = DEFAULT_LIMIT, offset = 0 }: PaginationOptions = {}
|
||||||
limit = DEFAULT_LIMIT,
|
|
||||||
offset = 0
|
|
||||||
}: { limit?: number; offset?: number } = {}
|
|
||||||
): Promise<AssetItem[]> {
|
): Promise<AssetItem[]> {
|
||||||
const queryParams = new URLSearchParams({
|
|
||||||
include_tags: tag,
|
|
||||||
limit: limit.toString(),
|
|
||||||
include_public: includePublic ? 'true' : 'false'
|
|
||||||
})
|
|
||||||
|
|
||||||
if (offset > 0) {
|
|
||||||
queryParams.set('offset', offset.toString())
|
|
||||||
}
|
|
||||||
|
|
||||||
const data = await handleAssetRequest(
|
const data = await handleAssetRequest(
|
||||||
`${ASSETS_ENDPOINT}?${queryParams.toString()}`,
|
{ includeTags: [tag], limit, offset, includePublic },
|
||||||
`assets for tag ${tag}`
|
`assets for tag ${tag}`
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -12,9 +12,9 @@ const mockGetCategoryForNodeType = vi.fn()
|
|||||||
|
|
||||||
vi.mock('@/stores/assetsStore', () => ({
|
vi.mock('@/stores/assetsStore', () => ({
|
||||||
useAssetsStore: () => ({
|
useAssetsStore: () => ({
|
||||||
modelAssetsByNodeType: new Map(),
|
getAssets: () => [],
|
||||||
modelLoadingByNodeType: new Map(),
|
isModelLoading: () => false,
|
||||||
modelErrorByNodeType: new Map(),
|
getError: () => undefined,
|
||||||
updateModelsForNodeType: mockUpdateModelsForNodeType
|
updateModelsForNodeType: mockUpdateModelsForNodeType
|
||||||
})
|
})
|
||||||
}))
|
}))
|
||||||
|
|||||||
@@ -8,17 +8,17 @@ vi.mock('@/platform/distribution/types', () => ({
|
|||||||
isCloud: true
|
isCloud: true
|
||||||
}))
|
}))
|
||||||
|
|
||||||
const mockModelAssetsByNodeType = new Map<string, AssetItem[]>()
|
const mockAssetsByKey = new Map<string, AssetItem[]>()
|
||||||
const mockModelLoadingByNodeType = new Map<string, boolean>()
|
const mockLoadingByKey = new Map<string, boolean>()
|
||||||
const mockModelErrorByNodeType = new Map<string, Error | null>()
|
const mockErrorByKey = new Map<string, Error | undefined>()
|
||||||
const mockUpdateModelsForNodeType = vi.fn()
|
const mockUpdateModelsForNodeType = vi.fn()
|
||||||
const mockGetCategoryForNodeType = vi.fn()
|
const mockGetCategoryForNodeType = vi.fn()
|
||||||
|
|
||||||
vi.mock('@/stores/assetsStore', () => ({
|
vi.mock('@/stores/assetsStore', () => ({
|
||||||
useAssetsStore: () => ({
|
useAssetsStore: () => ({
|
||||||
modelAssetsByNodeType: mockModelAssetsByNodeType,
|
getAssets: (key: string) => mockAssetsByKey.get(key) ?? [],
|
||||||
modelLoadingByNodeType: mockModelLoadingByNodeType,
|
isModelLoading: (key: string) => mockLoadingByKey.get(key) ?? false,
|
||||||
modelErrorByNodeType: mockModelErrorByNodeType,
|
getError: (key: string) => mockErrorByKey.get(key),
|
||||||
updateModelsForNodeType: mockUpdateModelsForNodeType
|
updateModelsForNodeType: mockUpdateModelsForNodeType
|
||||||
})
|
})
|
||||||
}))
|
}))
|
||||||
@@ -32,9 +32,9 @@ vi.mock('@/stores/modelToNodeStore', () => ({
|
|||||||
describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
mockModelAssetsByNodeType.clear()
|
mockAssetsByKey.clear()
|
||||||
mockModelLoadingByNodeType.clear()
|
mockLoadingByKey.clear()
|
||||||
mockModelErrorByNodeType.clear()
|
mockErrorByKey.clear()
|
||||||
mockGetCategoryForNodeType.mockReturnValue(undefined)
|
mockGetCategoryForNodeType.mockReturnValue(undefined)
|
||||||
|
|
||||||
mockUpdateModelsForNodeType.mockImplementation(
|
mockUpdateModelsForNodeType.mockImplementation(
|
||||||
@@ -76,8 +76,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
|||||||
|
|
||||||
mockUpdateModelsForNodeType.mockImplementation(
|
mockUpdateModelsForNodeType.mockImplementation(
|
||||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||||
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
|
mockAssetsByKey.set(_nodeType, mockAssets)
|
||||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
mockLoadingByKey.set(_nodeType, false)
|
||||||
return mockAssets
|
return mockAssets
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -108,9 +108,9 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
|||||||
|
|
||||||
mockUpdateModelsForNodeType.mockImplementation(
|
mockUpdateModelsForNodeType.mockImplementation(
|
||||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||||
mockModelErrorByNodeType.set(_nodeType, mockError)
|
mockErrorByKey.set(_nodeType, mockError)
|
||||||
mockModelAssetsByNodeType.set(_nodeType, [])
|
mockAssetsByKey.set(_nodeType, [])
|
||||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
mockLoadingByKey.set(_nodeType, false)
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -130,8 +130,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
|||||||
|
|
||||||
mockUpdateModelsForNodeType.mockImplementation(
|
mockUpdateModelsForNodeType.mockImplementation(
|
||||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||||
mockModelAssetsByNodeType.set(_nodeType, [])
|
mockAssetsByKey.set(_nodeType, [])
|
||||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
mockLoadingByKey.set(_nodeType, false)
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -154,8 +154,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
|||||||
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
|
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
|
||||||
mockUpdateModelsForNodeType.mockImplementation(
|
mockUpdateModelsForNodeType.mockImplementation(
|
||||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||||
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
|
mockAssetsByKey.set(_nodeType, mockAssets)
|
||||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
mockLoadingByKey.set(_nodeType, false)
|
||||||
return mockAssets
|
return mockAssets
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -182,8 +182,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
|||||||
mockGetCategoryForNodeType.mockReturnValue('loras')
|
mockGetCategoryForNodeType.mockReturnValue('loras')
|
||||||
mockUpdateModelsForNodeType.mockImplementation(
|
mockUpdateModelsForNodeType.mockImplementation(
|
||||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||||
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
|
mockAssetsByKey.set(_nodeType, mockAssets)
|
||||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
mockLoadingByKey.set(_nodeType, false)
|
||||||
return mockAssets
|
return mockAssets
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -209,8 +209,8 @@ describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
|||||||
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
|
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
|
||||||
mockUpdateModelsForNodeType.mockImplementation(
|
mockUpdateModelsForNodeType.mockImplementation(
|
||||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||||
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
|
mockAssetsByKey.set(_nodeType, mockAssets)
|
||||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
mockLoadingByKey.set(_nodeType, false)
|
||||||
return mockAssets
|
return mockAssets
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -34,23 +34,17 @@ export function useAssetWidgetData(
|
|||||||
|
|
||||||
const assets = computed<AssetItem[]>(() => {
|
const assets = computed<AssetItem[]>(() => {
|
||||||
const resolvedType = toValue(nodeType)
|
const resolvedType = toValue(nodeType)
|
||||||
return resolvedType
|
return resolvedType ? (assetsStore.getAssets(resolvedType) ?? []) : []
|
||||||
? (assetsStore.modelAssetsByNodeType.get(resolvedType) ?? [])
|
|
||||||
: []
|
|
||||||
})
|
})
|
||||||
|
|
||||||
const isLoading = computed(() => {
|
const isLoading = computed(() => {
|
||||||
const resolvedType = toValue(nodeType)
|
const resolvedType = toValue(nodeType)
|
||||||
return resolvedType
|
return resolvedType ? assetsStore.isModelLoading(resolvedType) : false
|
||||||
? (assetsStore.modelLoadingByNodeType.get(resolvedType) ?? false)
|
|
||||||
: false
|
|
||||||
})
|
})
|
||||||
|
|
||||||
const error = computed<Error | null>(() => {
|
const error = computed<Error | null>(() => {
|
||||||
const resolvedType = toValue(nodeType)
|
const resolvedType = toValue(nodeType)
|
||||||
return resolvedType
|
return resolvedType ? (assetsStore.getError(resolvedType) ?? null) : null
|
||||||
? (assetsStore.modelErrorByNodeType.get(resolvedType) ?? null)
|
|
||||||
: null
|
|
||||||
})
|
})
|
||||||
|
|
||||||
const dropdownItems = computed<DropdownItem[]>(() => {
|
const dropdownItems = computed<DropdownItem[]>(() => {
|
||||||
@@ -71,7 +65,8 @@ export function useAssetWidgetData(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
const hasData = assetsStore.modelAssetsByNodeType.has(currentNodeType)
|
const existingAssets = assetsStore.getAssets(currentNodeType) ?? []
|
||||||
|
const hasData = existingAssets.length > 0
|
||||||
|
|
||||||
if (!hasData) {
|
if (!hasData) {
|
||||||
await assetsStore.updateModelsForNodeType(currentNodeType)
|
await assetsStore.updateModelsForNodeType(currentNodeType)
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import { createPinia, setActivePinia } from 'pinia'
|
import { createTestingPinia } from '@pinia/testing'
|
||||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
import { setActivePinia } from 'pinia'
|
||||||
|
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
import { nextTick, watch } from 'vue'
|
||||||
|
|
||||||
import { useAssetsStore } from '@/stores/assetsStore'
|
import { useAssetsStore } from '@/stores/assetsStore'
|
||||||
import { api } from '@/scripts/api'
|
import { api } from '@/scripts/api'
|
||||||
@@ -9,6 +11,7 @@ import type {
|
|||||||
TaskStatus,
|
TaskStatus,
|
||||||
TaskOutput
|
TaskOutput
|
||||||
} from '@/schemas/apiSchema'
|
} from '@/schemas/apiSchema'
|
||||||
|
import { assetService } from '@/platform/assets/services/assetService'
|
||||||
|
|
||||||
// Mock the api module
|
// Mock the api module
|
||||||
vi.mock('@/scripts/api', () => ({
|
vi.mock('@/scripts/api', () => ({
|
||||||
@@ -25,13 +28,17 @@ vi.mock('@/scripts/api', () => ({
|
|||||||
// Mock the asset service
|
// Mock the asset service
|
||||||
vi.mock('@/platform/assets/services/assetService', () => ({
|
vi.mock('@/platform/assets/services/assetService', () => ({
|
||||||
assetService: {
|
assetService: {
|
||||||
getAssetsByTag: vi.fn()
|
getAssetsByTag: vi.fn(),
|
||||||
|
getAssetsForNodeType: vi.fn()
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Mock distribution type
|
// Mock distribution type - hoisted so it can be changed per test
|
||||||
|
const mockIsCloud = vi.hoisted(() => ({ value: false }))
|
||||||
vi.mock('@/platform/distribution/types', () => ({
|
vi.mock('@/platform/distribution/types', () => ({
|
||||||
isCloud: false
|
get isCloud() {
|
||||||
|
return mockIsCloud.value
|
||||||
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Mock TaskItemImpl
|
// Mock TaskItemImpl
|
||||||
@@ -144,7 +151,7 @@ describe('assetsStore - Refactored (Option A)', () => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
setActivePinia(createPinia())
|
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||||
store = useAssetsStore()
|
store = useAssetsStore()
|
||||||
vi.clearAllMocks()
|
vi.clearAllMocks()
|
||||||
})
|
})
|
||||||
@@ -520,3 +527,158 @@ describe('assetsStore - Refactored (Option A)', () => {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('assetsStore - Model Assets Cache (Cloud)', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||||
|
mockIsCloud.value = true
|
||||||
|
vi.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
mockIsCloud.value = false
|
||||||
|
})
|
||||||
|
|
||||||
|
const createMockAsset = (id: string) => ({
|
||||||
|
id,
|
||||||
|
name: `asset-${id}`,
|
||||||
|
size: 100,
|
||||||
|
created_at: new Date().toISOString(),
|
||||||
|
tags: ['models'],
|
||||||
|
preview_url: `http://test.com/${id}`
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getAssets cache invalidation', () => {
|
||||||
|
it('should invalidate cache before mutating assets during batch loading', async () => {
|
||||||
|
const store = useAssetsStore()
|
||||||
|
const nodeType = 'CheckpointLoaderSimple'
|
||||||
|
|
||||||
|
const firstBatch = Array.from({ length: 500 }, (_, i) =>
|
||||||
|
createMockAsset(`asset-${i}`)
|
||||||
|
)
|
||||||
|
const secondBatch = Array.from({ length: 100 }, (_, i) =>
|
||||||
|
createMockAsset(`asset-${500 + i}`)
|
||||||
|
)
|
||||||
|
|
||||||
|
let callCount = 0
|
||||||
|
vi.mocked(assetService.getAssetsForNodeType).mockImplementation(
|
||||||
|
async () => {
|
||||||
|
callCount++
|
||||||
|
return callCount === 1 ? firstBatch : secondBatch
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await store.updateModelsForNodeType(nodeType)
|
||||||
|
|
||||||
|
// Wait for background batch loading to complete
|
||||||
|
await vi.waitFor(() => {
|
||||||
|
expect(
|
||||||
|
vi.mocked(assetService.getAssetsForNodeType)
|
||||||
|
).toHaveBeenCalledTimes(2)
|
||||||
|
})
|
||||||
|
|
||||||
|
const assets = store.getAssets(nodeType)
|
||||||
|
expect(assets).toHaveLength(600)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not return stale cached array after background batch completes', async () => {
|
||||||
|
const store = useAssetsStore()
|
||||||
|
const nodeType = 'LoraLoader'
|
||||||
|
|
||||||
|
// First batch must be exactly MODEL_BATCH_SIZE (500) to trigger hasMore
|
||||||
|
const firstBatch = Array.from({ length: 500 }, (_, i) =>
|
||||||
|
createMockAsset(`first-${i}`)
|
||||||
|
)
|
||||||
|
const secondBatch = [createMockAsset('new-asset')]
|
||||||
|
|
||||||
|
let callCount = 0
|
||||||
|
vi.mocked(assetService.getAssetsForNodeType).mockImplementation(
|
||||||
|
async () => {
|
||||||
|
callCount++
|
||||||
|
return callCount === 1 ? firstBatch : secondBatch
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await store.updateModelsForNodeType(nodeType)
|
||||||
|
|
||||||
|
// Wait for background batch loading to complete
|
||||||
|
await vi.waitFor(() => {
|
||||||
|
expect(
|
||||||
|
vi.mocked(assetService.getAssetsForNodeType)
|
||||||
|
).toHaveBeenCalledTimes(2)
|
||||||
|
})
|
||||||
|
|
||||||
|
const assets = store.getAssets(nodeType)
|
||||||
|
expect(assets).toHaveLength(501)
|
||||||
|
expect(assets.map((a) => a.id)).toContain('new-asset')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return cached array on subsequent getAssets calls', () => {
|
||||||
|
const store = useAssetsStore()
|
||||||
|
const nodeType = 'TestLoader'
|
||||||
|
|
||||||
|
const firstCall = store.getAssets(nodeType)
|
||||||
|
const secondCall = store.getAssets(nodeType)
|
||||||
|
|
||||||
|
expect(secondCall).toBe(firstCall)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('concurrent request handling', () => {
|
||||||
|
it('should discard stale request when newer request starts', async () => {
|
||||||
|
const store = useAssetsStore()
|
||||||
|
const nodeType = 'CheckpointLoaderSimple'
|
||||||
|
const firstBatch = Array.from({ length: 5 }, (_, i) =>
|
||||||
|
createMockAsset(`first-${i}`)
|
||||||
|
)
|
||||||
|
const secondBatch = Array.from({ length: 10 }, (_, i) =>
|
||||||
|
createMockAsset(`second-${i}`)
|
||||||
|
)
|
||||||
|
|
||||||
|
let resolveFirst: (value: ReturnType<typeof createMockAsset>[]) => void
|
||||||
|
const firstPromise = new Promise<ReturnType<typeof createMockAsset>[]>(
|
||||||
|
(resolve) => {
|
||||||
|
resolveFirst = resolve
|
||||||
|
}
|
||||||
|
)
|
||||||
|
let callCount = 0
|
||||||
|
vi.mocked(assetService.getAssetsForNodeType).mockImplementation(
|
||||||
|
async () => {
|
||||||
|
callCount++
|
||||||
|
return callCount === 1 ? firstPromise : secondBatch
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
const firstRequest = store.updateModelsForNodeType(nodeType)
|
||||||
|
const secondRequest = store.updateModelsForNodeType(nodeType)
|
||||||
|
resolveFirst!(firstBatch)
|
||||||
|
await Promise.all([firstRequest, secondRequest])
|
||||||
|
|
||||||
|
expect(store.getAssets(nodeType)).toHaveLength(10)
|
||||||
|
expect(
|
||||||
|
store.getAssets(nodeType).every((a) => a.id.startsWith('second-'))
|
||||||
|
).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('shallowReactive state reactivity', () => {
|
||||||
|
it('should trigger reactivity on isModelLoading change', async () => {
|
||||||
|
const store = useAssetsStore()
|
||||||
|
const nodeType = 'CheckpointLoaderSimple'
|
||||||
|
|
||||||
|
const loadingStates: boolean[] = []
|
||||||
|
watch(
|
||||||
|
() => store.isModelLoading(nodeType),
|
||||||
|
(val) => loadingStates.push(val),
|
||||||
|
{ immediate: true }
|
||||||
|
)
|
||||||
|
|
||||||
|
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue([])
|
||||||
|
await store.updateModelsForNodeType(nodeType)
|
||||||
|
await nextTick()
|
||||||
|
|
||||||
|
expect(loadingStates).toContain(true)
|
||||||
|
expect(loadingStates).toContain(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
import { useAsyncState, whenever } from '@vueuse/core'
|
import { useAsyncState, whenever } from '@vueuse/core'
|
||||||
import { isEqual } from 'es-toolkit'
|
|
||||||
import { defineStore } from 'pinia'
|
import { defineStore } from 'pinia'
|
||||||
import { computed, shallowReactive, ref } from 'vue'
|
import { computed, reactive, ref, shallowReactive } from 'vue'
|
||||||
import {
|
import {
|
||||||
mapInputFileToAssetItem,
|
mapInputFileToAssetItem,
|
||||||
mapTaskOutputToAssetItem
|
mapTaskOutputToAssetItem
|
||||||
} from '@/platform/assets/composables/media/assetMappers'
|
} from '@/platform/assets/composables/media/assetMappers'
|
||||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||||
import { assetService } from '@/platform/assets/services/assetService'
|
import { assetService } from '@/platform/assets/services/assetService'
|
||||||
|
import type { PaginationOptions } from '@/platform/assets/services/assetService'
|
||||||
import { isCloud } from '@/platform/distribution/types'
|
import { isCloud } from '@/platform/distribution/types'
|
||||||
import type { TaskItem } from '@/schemas/apiSchema'
|
import type { TaskItem } from '@/schemas/apiSchema'
|
||||||
import { api } from '@/scripts/api'
|
import { api } from '@/scripts/api'
|
||||||
@@ -261,6 +261,16 @@ export const useAssetsStore = defineStore('assets', () => {
|
|||||||
return inputAssetsByFilename.value.get(filename)?.name ?? filename
|
return inputAssetsByFilename.value.get(filename)?.name ?? filename
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const MODEL_BATCH_SIZE = 500
|
||||||
|
|
||||||
|
interface ModelPaginationState {
|
||||||
|
assets: Map<string, AssetItem>
|
||||||
|
offset: number
|
||||||
|
hasMore: boolean
|
||||||
|
isLoading: boolean
|
||||||
|
error?: Error
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Model assets cached by node type (e.g., 'CheckpointLoaderSimple', 'LoraLoader')
|
* Model assets cached by node type (e.g., 'CheckpointLoaderSimple', 'LoraLoader')
|
||||||
* Used by multiple loader nodes to avoid duplicate fetches
|
* Used by multiple loader nodes to avoid duplicate fetches
|
||||||
@@ -268,109 +278,183 @@ export const useAssetsStore = defineStore('assets', () => {
|
|||||||
*/
|
*/
|
||||||
const getModelState = () => {
|
const getModelState = () => {
|
||||||
if (isCloud) {
|
if (isCloud) {
|
||||||
const modelAssetsByNodeType = shallowReactive(
|
const modelStateByKey = ref(new Map<string, ModelPaginationState>())
|
||||||
new Map<string, AssetItem[]>()
|
|
||||||
)
|
|
||||||
const modelLoadingByNodeType = shallowReactive(new Map<string, boolean>())
|
|
||||||
const modelErrorByNodeType = shallowReactive(
|
|
||||||
new Map<string, Error | null>()
|
|
||||||
)
|
|
||||||
|
|
||||||
const stateByNodeType = shallowReactive(
|
const assetsArrayCache = new Map<
|
||||||
new Map<string, ReturnType<typeof useAsyncState<AssetItem[]>>>()
|
string,
|
||||||
)
|
{ source: Map<string, AssetItem>; array: AssetItem[] }
|
||||||
|
>()
|
||||||
|
|
||||||
|
const pendingRequestByKey = new Map<string, ModelPaginationState>()
|
||||||
|
|
||||||
|
function createState(): ModelPaginationState {
|
||||||
|
return reactive({
|
||||||
|
assets: new Map(),
|
||||||
|
offset: 0,
|
||||||
|
hasMore: true,
|
||||||
|
isLoading: false
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
function isStale(key: string, state: ModelPaginationState): boolean {
|
||||||
|
const committed = modelStateByKey.value.get(key)
|
||||||
|
const pending = pendingRequestByKey.get(key)
|
||||||
|
return committed !== state && pending !== state
|
||||||
|
}
|
||||||
|
|
||||||
|
const EMPTY_ASSETS: AssetItem[] = []
|
||||||
|
|
||||||
|
function getAssets(key: string): AssetItem[] {
|
||||||
|
const state = modelStateByKey.value.get(key)
|
||||||
|
const assetsMap = state?.assets
|
||||||
|
if (!assetsMap) return EMPTY_ASSETS
|
||||||
|
|
||||||
|
const cached = assetsArrayCache.get(key)
|
||||||
|
if (cached && cached.source === assetsMap) {
|
||||||
|
return cached.array
|
||||||
|
}
|
||||||
|
|
||||||
|
const array = Array.from(assetsMap.values())
|
||||||
|
assetsArrayCache.set(key, { source: assetsMap, array })
|
||||||
|
return array
|
||||||
|
}
|
||||||
|
|
||||||
|
function isLoading(key: string): boolean {
|
||||||
|
return modelStateByKey.value.get(key)?.isLoading ?? false
|
||||||
|
}
|
||||||
|
|
||||||
|
function getError(key: string): Error | undefined {
|
||||||
|
return modelStateByKey.value.get(key)?.error
|
||||||
|
}
|
||||||
|
|
||||||
|
function hasMore(key: string): boolean {
|
||||||
|
return modelStateByKey.value.get(key)?.hasMore ?? false
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Internal helper to fetch and cache assets with a given key and fetcher
|
* Internal helper to fetch and cache assets with a given key and fetcher.
|
||||||
|
* Loads first batch immediately, then progressively loads remaining batches.
|
||||||
|
* Keeps existing data visible until new data is successfully fetched.
|
||||||
*/
|
*/
|
||||||
async function updateModelsForKey(
|
async function updateModelsForKey(
|
||||||
key: string,
|
key: string,
|
||||||
fetcher: () => Promise<AssetItem[]>
|
fetcher: (options: PaginationOptions) => Promise<AssetItem[]>
|
||||||
): Promise<AssetItem[]> {
|
): Promise<void> {
|
||||||
if (!stateByNodeType.has(key)) {
|
const state = createState()
|
||||||
stateByNodeType.set(
|
state.isLoading = true
|
||||||
key,
|
|
||||||
useAsyncState(fetcher, [], {
|
const hasExistingData = modelStateByKey.value.has(key)
|
||||||
immediate: false,
|
if (hasExistingData) {
|
||||||
resetOnExecute: false,
|
pendingRequestByKey.set(key, state)
|
||||||
onError: (err) => {
|
} else {
|
||||||
console.error(`Error fetching model assets for ${key}:`, err)
|
modelStateByKey.value.set(key, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
async function loadBatches(): Promise<void> {
|
||||||
|
while (state.hasMore) {
|
||||||
|
try {
|
||||||
|
const newAssets = await fetcher({
|
||||||
|
limit: MODEL_BATCH_SIZE,
|
||||||
|
offset: state.offset
|
||||||
|
})
|
||||||
|
|
||||||
|
if (isStale(key, state)) return
|
||||||
|
|
||||||
|
const isFirstBatch = state.offset === 0
|
||||||
|
if (isFirstBatch) {
|
||||||
|
assetsArrayCache.delete(key)
|
||||||
|
if (hasExistingData) {
|
||||||
|
pendingRequestByKey.delete(key)
|
||||||
|
modelStateByKey.value.set(key, state)
|
||||||
|
}
|
||||||
|
state.assets = new Map(newAssets.map((a) => [a.id, a]))
|
||||||
|
} else {
|
||||||
|
const assetsToAdd = newAssets.filter(
|
||||||
|
(a) => !state.assets.has(a.id)
|
||||||
|
)
|
||||||
|
if (assetsToAdd.length > 0) {
|
||||||
|
assetsArrayCache.delete(key)
|
||||||
|
for (const asset of assetsToAdd) {
|
||||||
|
state.assets.set(asset.id, asset)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
|
||||||
)
|
state.offset += newAssets.length
|
||||||
|
state.hasMore = newAssets.length === MODEL_BATCH_SIZE
|
||||||
|
|
||||||
|
if (isFirstBatch) {
|
||||||
|
state.isLoading = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if (state.hasMore) {
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 50))
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
if (isStale(key, state)) return
|
||||||
|
state.error = err instanceof Error ? err : new Error(String(err))
|
||||||
|
state.hasMore = false
|
||||||
|
console.error(`Error loading batch for ${key}:`, err)
|
||||||
|
if (state.offset === 0) {
|
||||||
|
state.isLoading = false
|
||||||
|
pendingRequestByKey.delete(key)
|
||||||
|
// TODO: Add toast indicator for first-batch load failures
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const state = stateByNodeType.get(key)!
|
await loadBatches()
|
||||||
|
|
||||||
modelLoadingByNodeType.set(key, true)
|
|
||||||
modelErrorByNodeType.set(key, null)
|
|
||||||
|
|
||||||
try {
|
|
||||||
await state.execute()
|
|
||||||
} finally {
|
|
||||||
modelLoadingByNodeType.set(key, state.isLoading.value)
|
|
||||||
}
|
|
||||||
|
|
||||||
const assets = state.state.value
|
|
||||||
const existingAssets = modelAssetsByNodeType.get(key)
|
|
||||||
|
|
||||||
if (!isEqual(existingAssets, assets)) {
|
|
||||||
modelAssetsByNodeType.set(key, assets)
|
|
||||||
}
|
|
||||||
|
|
||||||
modelErrorByNodeType.set(
|
|
||||||
key,
|
|
||||||
state.error.value instanceof Error ? state.error.value : null
|
|
||||||
)
|
|
||||||
|
|
||||||
return assets
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fetch and cache model assets for a specific node type
|
* Fetch and cache model assets for a specific node type
|
||||||
* @param nodeType The node type to fetch assets for (e.g., 'CheckpointLoaderSimple')
|
* @param nodeType The node type to fetch assets for (e.g., 'CheckpointLoaderSimple')
|
||||||
* @returns Promise resolving to the fetched assets
|
|
||||||
*/
|
*/
|
||||||
async function updateModelsForNodeType(
|
async function updateModelsForNodeType(nodeType: string): Promise<void> {
|
||||||
nodeType: string
|
await updateModelsForKey(nodeType, (opts) =>
|
||||||
): Promise<AssetItem[]> {
|
assetService.getAssetsForNodeType(nodeType, opts)
|
||||||
return updateModelsForKey(nodeType, () =>
|
|
||||||
assetService.getAssetsForNodeType(nodeType)
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fetch and cache model assets for a specific tag
|
* Fetch and cache model assets for a specific tag
|
||||||
* @param tag The tag to fetch assets for (e.g., 'models')
|
* @param tag The tag to fetch assets for (e.g., 'models')
|
||||||
* @returns Promise resolving to the fetched assets
|
|
||||||
*/
|
*/
|
||||||
async function updateModelsForTag(tag: string): Promise<AssetItem[]> {
|
async function updateModelsForTag(tag: string): Promise<void> {
|
||||||
const key = `tag:${tag}`
|
const key = `tag:${tag}`
|
||||||
return updateModelsForKey(key, () => assetService.getAssetsByTag(tag))
|
await updateModelsForKey(key, (opts) =>
|
||||||
|
assetService.getAssetsByTag(tag, true, opts)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
modelAssetsByNodeType,
|
getAssets,
|
||||||
modelLoadingByNodeType,
|
isLoading,
|
||||||
modelErrorByNodeType,
|
getError,
|
||||||
|
hasMore,
|
||||||
updateModelsForNodeType,
|
updateModelsForNodeType,
|
||||||
updateModelsForTag
|
updateModelsForTag
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const emptyAssets: AssetItem[] = []
|
||||||
return {
|
return {
|
||||||
modelAssetsByNodeType: shallowReactive(new Map<string, AssetItem[]>()),
|
getAssets: () => emptyAssets,
|
||||||
modelLoadingByNodeType: shallowReactive(new Map<string, boolean>()),
|
isLoading: () => false,
|
||||||
modelErrorByNodeType: shallowReactive(new Map<string, Error | null>()),
|
getError: () => undefined,
|
||||||
updateModelsForNodeType: async () => [],
|
hasMore: () => false,
|
||||||
updateModelsForTag: async () => []
|
updateModelsForNodeType: async () => {},
|
||||||
|
updateModelsForTag: async () => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const {
|
const {
|
||||||
modelAssetsByNodeType,
|
getAssets,
|
||||||
modelLoadingByNodeType,
|
isLoading: isModelLoading,
|
||||||
modelErrorByNodeType,
|
getError,
|
||||||
|
hasMore,
|
||||||
updateModelsForNodeType,
|
updateModelsForNodeType,
|
||||||
updateModelsForTag
|
updateModelsForTag
|
||||||
} = getModelState()
|
} = getModelState()
|
||||||
@@ -432,10 +516,13 @@ export const useAssetsStore = defineStore('assets', () => {
|
|||||||
inputAssetsByFilename,
|
inputAssetsByFilename,
|
||||||
getInputName,
|
getInputName,
|
||||||
|
|
||||||
// Model assets
|
// Model assets - accessors
|
||||||
modelAssetsByNodeType,
|
getAssets,
|
||||||
modelLoadingByNodeType,
|
isModelLoading,
|
||||||
modelErrorByNodeType,
|
getError,
|
||||||
|
hasMore,
|
||||||
|
|
||||||
|
// Model assets - actions
|
||||||
updateModelsForNodeType,
|
updateModelsForNodeType,
|
||||||
updateModelsForTag
|
updateModelsForTag
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user