From 90d681ad312ea04e6c3a5dd506995cc06035149c Mon Sep 17 00:00:00 2001 From: Subagent 5 Date: Wed, 28 Jan 2026 21:47:18 -0800 Subject: [PATCH] fix: invalidate loader node dropdown cache after model asset deletion Amp-Thread-ID: https://ampcode.com/threads/T-019c0830-033a-7224-85ec-766b1765426d Co-authored-by: Amp --- .../composables/useMediaAssetActions.test.ts | 126 +++++++++++++++++- .../composables/useMediaAssetActions.ts | 22 +++ src/stores/assetsStore.test.ts | 107 ++++++++++++++- src/stores/assetsStore.ts | 64 ++++++++- 4 files changed, 311 insertions(+), 8 deletions(-) diff --git a/src/platform/assets/composables/useMediaAssetActions.test.ts b/src/platform/assets/composables/useMediaAssetActions.test.ts index 2d0a7a291..dcf7712e7 100644 --- a/src/platform/assets/composables/useMediaAssetActions.test.ts +++ b/src/platform/assets/composables/useMediaAssetActions.test.ts @@ -35,9 +35,36 @@ vi.mock('vue-i18n', () => ({ }) })) +const mockShowDialog = vi.hoisted(() => vi.fn()) vi.mock('@/stores/dialogStore', () => ({ useDialogStore: () => ({ - showDialog: vi.fn() + showDialog: mockShowDialog + }) +})) + +const mockInvalidateModelsForCategory = vi.hoisted(() => vi.fn()) +const mockSetAssetDeleting = vi.hoisted(() => vi.fn()) +const mockUpdateHistory = vi.hoisted(() => vi.fn()) +const mockUpdateInputs = vi.hoisted(() => vi.fn()) +vi.mock('@/stores/assetsStore', () => ({ + useAssetsStore: () => ({ + setAssetDeleting: mockSetAssetDeleting, + updateHistory: mockUpdateHistory, + updateInputs: mockUpdateInputs, + invalidateModelsForCategory: mockInvalidateModelsForCategory + }) +})) + +const mockGetCategoryForNodeType = vi.hoisted(() => vi.fn()) +vi.mock('@/stores/modelToNodeStore', () => ({ + useModelToNodeStore: () => ({ + getCategoryForNodeType: mockGetCategoryForNodeType, + getAllNodeProviders: vi.fn((category: string) => { + if (category === 'checkpoints' || category === 'loras') { + return [{ nodeDef: { name: 'TestLoader' }, key: 'model_name' }] + } + return [] + }) }) })) @@ -93,14 +120,33 @@ vi.mock('@/utils/typeGuardUtil', () => ({ isResultItemType: vi.fn().mockReturnValue(true) })) +const mockGetAssetType = vi.hoisted(() => vi.fn()) vi.mock('@/platform/assets/utils/assetTypeUtil', () => ({ - getAssetType: vi.fn().mockReturnValue('input') + getAssetType: mockGetAssetType })) vi.mock('../schemas/assetMetadataSchema', () => ({ getOutputAssetMetadata: vi.fn().mockReturnValue(null) })) +const mockDeleteAsset = vi.hoisted(() => vi.fn()) +vi.mock('../services/assetService', () => ({ + assetService: { + deleteAsset: mockDeleteAsset + } +})) + +vi.mock('@/scripts/api', () => ({ + api: { + deleteItem: vi.fn(), + apiURL: vi.fn((path: string) => `http://localhost:8188/api${path}`), + internalURL: vi.fn((path: string) => `http://localhost:8188${path}`), + addEventListener: vi.fn(), + removeEventListener: vi.fn(), + user: 'test-user' + } +})) + function createMockAsset(overrides: Partial = {}): AssetItem { return { id: 'test-asset-id', @@ -218,4 +264,80 @@ describe('useMediaAssetActions', () => { }) }) }) + + describe('deleteAssets - model cache invalidation', () => { + beforeEach(() => { + mockIsCloud.value = true + mockGetAssetType.mockReturnValue('input') + mockDeleteAsset.mockResolvedValue(undefined) + mockInvalidateModelsForCategory.mockClear() + mockSetAssetDeleting.mockClear() + mockUpdateHistory.mockClear() + mockUpdateInputs.mockClear() + }) + + it('should invalidate model cache when deleting a model asset', async () => { + const actions = useMediaAssetActions() + + const modelAsset = createMockAsset({ + id: 'checkpoint-1', + name: 'model.safetensors', + tags: ['models', 'checkpoints'] + }) + + mockShowDialog.mockImplementation( + ({ props }: { props: { onConfirm: () => Promise } }) => { + void props.onConfirm() + } + ) + + await actions.deleteAssets(modelAsset) + + expect(mockInvalidateModelsForCategory).toHaveBeenCalledWith( + 'checkpoints' + ) + }) + + it('should invalidate multiple categories for multiple assets', async () => { + const actions = useMediaAssetActions() + + const assets = [ + createMockAsset({ id: '1', tags: ['models', 'checkpoints'] }), + createMockAsset({ id: '2', tags: ['models', 'loras'] }) + ] + + mockShowDialog.mockImplementation( + ({ props }: { props: { onConfirm: () => Promise } }) => { + void props.onConfirm() + } + ) + + await actions.deleteAssets(assets) + + expect(mockInvalidateModelsForCategory).toHaveBeenCalledWith( + 'checkpoints' + ) + expect(mockInvalidateModelsForCategory).toHaveBeenCalledWith('loras') + }) + + it('should not invalidate model cache for non-model assets', async () => { + const actions = useMediaAssetActions() + + const inputAsset = createMockAsset({ + id: 'input-1', + name: 'image.png', + tags: ['input'] + }) + + mockShowDialog.mockImplementation( + ({ props }: { props: { onConfirm: () => Promise } }) => { + void props.onConfirm() + } + ) + + await actions.deleteAssets(inputAsset) + + expect(mockInvalidateModelsForCategory).not.toHaveBeenCalled() + }) + }) }) diff --git a/src/platform/assets/composables/useMediaAssetActions.ts b/src/platform/assets/composables/useMediaAssetActions.ts index 98a9444ba..760af4359 100644 --- a/src/platform/assets/composables/useMediaAssetActions.ts +++ b/src/platform/assets/composables/useMediaAssetActions.ts @@ -14,6 +14,7 @@ import { useNodeDefStore } from '@/stores/nodeDefStore' import { getOutputAssetMetadata } from '../schemas/assetMetadataSchema' import { useAssetsStore } from '@/stores/assetsStore' import { useDialogStore } from '@/stores/dialogStore' +import { useModelToNodeStore } from '@/stores/modelToNodeStore' import { getAssetType } from '../utils/assetTypeUtil' import { getAssetUrl } from '../utils/assetUrlUtil' import { createAnnotatedPath } from '@/utils/createAnnotatedPath' @@ -586,6 +587,27 @@ export function useMediaAssetActions() { await assetsStore.updateInputs() } + // Invalidate model caches for affected categories + const modelToNodeStore = useModelToNodeStore() + const modelCategories = new Set() + const excludedTags = ['models', 'input', 'output'] + + for (const asset of assetArray) { + for (const tag of asset.tags ?? []) { + if (excludedTags.includes(tag)) continue + const providers = modelToNodeStore.getAllNodeProviders(tag) + if (providers.length > 0) { + modelCategories.add(tag) + } + } + } + + await Promise.allSettled( + [...modelCategories].map((category) => + assetsStore.invalidateModelsForCategory(category) + ) + ) + // Show appropriate feedback based on results if (failed.length === 0) { toast.add({ diff --git a/src/stores/assetsStore.test.ts b/src/stores/assetsStore.test.ts index 7f9822001..46754b232 100644 --- a/src/stores/assetsStore.test.ts +++ b/src/stores/assetsStore.test.ts @@ -36,6 +36,32 @@ vi.mock('@/platform/distribution/types', () => ({ } })) +// Mock modelToNodeStore with proper node providers +vi.mock('@/stores/modelToNodeStore', () => ({ + useModelToNodeStore: () => ({ + getAllNodeProviders: vi.fn((category: string) => { + const providers: Record< + string, + Array<{ nodeDef: { name: string }; key: string }> + > = { + checkpoints: [ + { nodeDef: { name: 'CheckpointLoaderSimple' }, key: 'ckpt_name' }, + { nodeDef: { name: 'ImageOnlyCheckpointLoader' }, key: 'ckpt_name' } + ], + loras: [ + { nodeDef: { name: 'LoraLoader' }, key: 'lora_name' }, + { nodeDef: { name: 'LoraLoaderModelOnly' }, key: 'lora_name' } + ], + vae: [{ nodeDef: { name: 'VAELoader' }, key: 'vae_name' }] + } + return providers[category] ?? [] + }), + getNodeProvider: vi.fn(), + getCategoryForNodeType: vi.fn(), + registerDefaults: vi.fn() + }) +})) + // Mock TaskItemImpl vi.mock('@/stores/queueStore', () => ({ TaskItemImpl: class { @@ -472,12 +498,12 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => { mockIsCloud.value = false }) - const createMockAsset = (id: string) => ({ + const createMockAsset = (id: string, tags: string[] = ['models']) => ({ id, name: `asset-${id}`, size: 100, created_at: new Date().toISOString(), - tags: ['models'], + tags, preview_url: `http://test.com/${id}` }) @@ -614,4 +640,81 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => { expect(loadingStates).toContain(false) }) }) + + describe('invalidateModelsForCategory', () => { + it('should invalidate model cache for all node types providing a category', async () => { + const store = useAssetsStore() + const nodeType = 'CheckpointLoaderSimple' + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([ + createMockAsset('existing-1'), + createMockAsset('existing-2') + ]) + await store.updateModelsForNodeType(nodeType) + expect(store.getAssets(nodeType)).toHaveLength(2) + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([ + createMockAsset('existing-1') + ]) + + await store.invalidateModelsForCategory('checkpoints') + + expect(store.getAssets(nodeType)).toHaveLength(1) + }) + + it('should invalidate multiple node types for same category', async () => { + const store = useAssetsStore() + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue([ + createMockAsset('asset-1') + ]) + + await store.updateModelsForNodeType('CheckpointLoaderSimple') + await store.updateModelsForNodeType('ImageOnlyCheckpointLoader') + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue([]) + + await store.invalidateModelsForCategory('checkpoints') + + expect(store.getAssets('CheckpointLoaderSimple')).toHaveLength(0) + expect(store.getAssets('ImageOnlyCheckpointLoader')).toHaveLength(0) + }) + + it('should also invalidate tag-based caches', async () => { + const store = useAssetsStore() + + vi.mocked(assetService.getAssetsByTag).mockResolvedValueOnce([ + createMockAsset('tag-asset-1') + ]) + + await store.updateModelsForTag('checkpoints') + expect(store.getAssets('tag:checkpoints')).toHaveLength(1) + + vi.mocked(assetService.getAssetsByTag).mockResolvedValueOnce([]) + + await store.invalidateModelsForCategory('checkpoints') + + expect(store.getAssets('tag:checkpoints')).toHaveLength(0) + }) + }) + + describe('removeAssetFromCache', () => { + it('should remove specific asset from node type cache', async () => { + const store = useAssetsStore() + const nodeType = 'LoraLoader' + + vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([ + createMockAsset('lora-1'), + createMockAsset('lora-2'), + createMockAsset('lora-3') + ]) + await store.updateModelsForNodeType(nodeType) + expect(store.getAssets(nodeType)).toHaveLength(3) + + store.removeAssetFromCache('lora-2', 'loras') + + expect(store.getAssets(nodeType)).toHaveLength(2) + expect(store.getAssets(nodeType).map((a) => a.id)).not.toContain('lora-2') + }) + }) }) diff --git a/src/stores/assetsStore.ts b/src/stores/assetsStore.ts index 41fb3e878..f951d7943 100644 --- a/src/stores/assetsStore.ts +++ b/src/stores/assetsStore.ts @@ -546,6 +546,54 @@ export const useAssetsStore = defineStore('assets', () => { } } + /** + * Invalidate model caches for a given category (e.g., 'checkpoints', 'loras') + * Refreshes all node types that provide this category plus tag-based caches + * @param category The model category to invalidate (e.g., 'checkpoints') + */ + async function invalidateModelsForCategory( + category: string + ): Promise { + const providers = modelToNodeStore + .getAllNodeProviders(category) + .filter((provider) => provider.nodeDef?.name) + + const nodeTypeUpdates = providers.map((provider) => + updateModelsForNodeType(provider.nodeDef.name) + ) + + const tagUpdates = [ + updateModelsForTag(category), + updateModelsForTag('models') + ] + + await Promise.allSettled([...nodeTypeUpdates, ...tagUpdates]) + } + + /** + * Remove a specific asset from all caches for a given category + * Used for optimistic updates after asset deletion + * @param assetId The asset ID to remove + * @param category The model category (e.g., 'loras') + */ + function removeAssetFromCache(assetId: string, category: string): void { + const providers = modelToNodeStore.getAllNodeProviders(category) + + for (const provider of providers) { + const nodeType = provider.nodeDef?.name + if (!nodeType) continue + + const state = modelStateByKey.value.get(nodeType) + if (!state) continue + + state.assets.delete(assetId) + assetsArrayCache.delete(nodeType) + } + + assetsArrayCache.delete(`tag:${category}`) + assetsArrayCache.delete('tag:models') + } + return { getAssets, isLoading, @@ -555,7 +603,9 @@ export const useAssetsStore = defineStore('assets', () => { updateModelsForNodeType, updateModelsForTag, updateAssetMetadata, - updateAssetTags + updateAssetTags, + invalidateModelsForCategory, + removeAssetFromCache } } @@ -569,7 +619,9 @@ export const useAssetsStore = defineStore('assets', () => { updateModelsForNodeType: async () => {}, updateModelsForTag: async () => {}, updateAssetMetadata: async () => {}, - updateAssetTags: async () => {} + updateAssetTags: async () => {}, + invalidateModelsForCategory: async () => {}, + removeAssetFromCache: () => {} } } @@ -582,7 +634,9 @@ export const useAssetsStore = defineStore('assets', () => { updateModelsForNodeType, updateModelsForTag, updateAssetMetadata, - updateAssetTags + updateAssetTags, + invalidateModelsForCategory, + removeAssetFromCache } = getModelState() // Watch for completed downloads and refresh model caches @@ -658,6 +712,8 @@ export const useAssetsStore = defineStore('assets', () => { updateModelsForNodeType, updateModelsForTag, updateAssetMetadata, - updateAssetTags + updateAssetTags, + invalidateModelsForCategory, + removeAssetFromCache } })