fix: invalidate loader node dropdown cache after model asset deletion

Amp-Thread-ID: https://ampcode.com/threads/T-019c0830-033a-7224-85ec-766b1765426d
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Subagent 5
2026-01-28 21:47:18 -08:00
parent fe7d89d1b1
commit 90d681ad31
4 changed files with 311 additions and 8 deletions

View File

@@ -35,9 +35,36 @@ vi.mock('vue-i18n', () => ({
})
}))
const mockShowDialog = vi.hoisted(() => vi.fn())
vi.mock('@/stores/dialogStore', () => ({
useDialogStore: () => ({
showDialog: vi.fn()
showDialog: mockShowDialog
})
}))
const mockInvalidateModelsForCategory = vi.hoisted(() => vi.fn())
const mockSetAssetDeleting = vi.hoisted(() => vi.fn())
const mockUpdateHistory = vi.hoisted(() => vi.fn())
const mockUpdateInputs = vi.hoisted(() => vi.fn())
vi.mock('@/stores/assetsStore', () => ({
useAssetsStore: () => ({
setAssetDeleting: mockSetAssetDeleting,
updateHistory: mockUpdateHistory,
updateInputs: mockUpdateInputs,
invalidateModelsForCategory: mockInvalidateModelsForCategory
})
}))
const mockGetCategoryForNodeType = vi.hoisted(() => vi.fn())
vi.mock('@/stores/modelToNodeStore', () => ({
useModelToNodeStore: () => ({
getCategoryForNodeType: mockGetCategoryForNodeType,
getAllNodeProviders: vi.fn((category: string) => {
if (category === 'checkpoints' || category === 'loras') {
return [{ nodeDef: { name: 'TestLoader' }, key: 'model_name' }]
}
return []
})
})
}))
@@ -93,14 +120,33 @@ vi.mock('@/utils/typeGuardUtil', () => ({
isResultItemType: vi.fn().mockReturnValue(true)
}))
const mockGetAssetType = vi.hoisted(() => vi.fn())
vi.mock('@/platform/assets/utils/assetTypeUtil', () => ({
getAssetType: vi.fn().mockReturnValue('input')
getAssetType: mockGetAssetType
}))
vi.mock('../schemas/assetMetadataSchema', () => ({
getOutputAssetMetadata: vi.fn().mockReturnValue(null)
}))
const mockDeleteAsset = vi.hoisted(() => vi.fn())
vi.mock('../services/assetService', () => ({
assetService: {
deleteAsset: mockDeleteAsset
}
}))
vi.mock('@/scripts/api', () => ({
api: {
deleteItem: vi.fn(),
apiURL: vi.fn((path: string) => `http://localhost:8188/api${path}`),
internalURL: vi.fn((path: string) => `http://localhost:8188${path}`),
addEventListener: vi.fn(),
removeEventListener: vi.fn(),
user: 'test-user'
}
}))
function createMockAsset(overrides: Partial<AssetItem> = {}): AssetItem {
return {
id: 'test-asset-id',
@@ -218,4 +264,80 @@ describe('useMediaAssetActions', () => {
})
})
})
describe('deleteAssets - model cache invalidation', () => {
beforeEach(() => {
mockIsCloud.value = true
mockGetAssetType.mockReturnValue('input')
mockDeleteAsset.mockResolvedValue(undefined)
mockInvalidateModelsForCategory.mockClear()
mockSetAssetDeleting.mockClear()
mockUpdateHistory.mockClear()
mockUpdateInputs.mockClear()
})
it('should invalidate model cache when deleting a model asset', async () => {
const actions = useMediaAssetActions()
const modelAsset = createMockAsset({
id: 'checkpoint-1',
name: 'model.safetensors',
tags: ['models', 'checkpoints']
})
mockShowDialog.mockImplementation(
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
void props.onConfirm()
}
)
await actions.deleteAssets(modelAsset)
expect(mockInvalidateModelsForCategory).toHaveBeenCalledWith(
'checkpoints'
)
})
it('should invalidate multiple categories for multiple assets', async () => {
const actions = useMediaAssetActions()
const assets = [
createMockAsset({ id: '1', tags: ['models', 'checkpoints'] }),
createMockAsset({ id: '2', tags: ['models', 'loras'] })
]
mockShowDialog.mockImplementation(
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
void props.onConfirm()
}
)
await actions.deleteAssets(assets)
expect(mockInvalidateModelsForCategory).toHaveBeenCalledWith(
'checkpoints'
)
expect(mockInvalidateModelsForCategory).toHaveBeenCalledWith('loras')
})
it('should not invalidate model cache for non-model assets', async () => {
const actions = useMediaAssetActions()
const inputAsset = createMockAsset({
id: 'input-1',
name: 'image.png',
tags: ['input']
})
mockShowDialog.mockImplementation(
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
void props.onConfirm()
}
)
await actions.deleteAssets(inputAsset)
expect(mockInvalidateModelsForCategory).not.toHaveBeenCalled()
})
})
})

View File

@@ -14,6 +14,7 @@ import { useNodeDefStore } from '@/stores/nodeDefStore'
import { getOutputAssetMetadata } from '../schemas/assetMetadataSchema'
import { useAssetsStore } from '@/stores/assetsStore'
import { useDialogStore } from '@/stores/dialogStore'
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
import { getAssetType } from '../utils/assetTypeUtil'
import { getAssetUrl } from '../utils/assetUrlUtil'
import { createAnnotatedPath } from '@/utils/createAnnotatedPath'
@@ -586,6 +587,27 @@ export function useMediaAssetActions() {
await assetsStore.updateInputs()
}
// Invalidate model caches for affected categories
const modelToNodeStore = useModelToNodeStore()
const modelCategories = new Set<string>()
const excludedTags = ['models', 'input', 'output']
for (const asset of assetArray) {
for (const tag of asset.tags ?? []) {
if (excludedTags.includes(tag)) continue
const providers = modelToNodeStore.getAllNodeProviders(tag)
if (providers.length > 0) {
modelCategories.add(tag)
}
}
}
await Promise.allSettled(
[...modelCategories].map((category) =>
assetsStore.invalidateModelsForCategory(category)
)
)
// Show appropriate feedback based on results
if (failed.length === 0) {
toast.add({

View File

@@ -36,6 +36,32 @@ vi.mock('@/platform/distribution/types', () => ({
}
}))
// Mock modelToNodeStore with proper node providers
vi.mock('@/stores/modelToNodeStore', () => ({
useModelToNodeStore: () => ({
getAllNodeProviders: vi.fn((category: string) => {
const providers: Record<
string,
Array<{ nodeDef: { name: string }; key: string }>
> = {
checkpoints: [
{ nodeDef: { name: 'CheckpointLoaderSimple' }, key: 'ckpt_name' },
{ nodeDef: { name: 'ImageOnlyCheckpointLoader' }, key: 'ckpt_name' }
],
loras: [
{ nodeDef: { name: 'LoraLoader' }, key: 'lora_name' },
{ nodeDef: { name: 'LoraLoaderModelOnly' }, key: 'lora_name' }
],
vae: [{ nodeDef: { name: 'VAELoader' }, key: 'vae_name' }]
}
return providers[category] ?? []
}),
getNodeProvider: vi.fn(),
getCategoryForNodeType: vi.fn(),
registerDefaults: vi.fn()
})
}))
// Mock TaskItemImpl
vi.mock('@/stores/queueStore', () => ({
TaskItemImpl: class {
@@ -472,12 +498,12 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
mockIsCloud.value = false
})
const createMockAsset = (id: string) => ({
const createMockAsset = (id: string, tags: string[] = ['models']) => ({
id,
name: `asset-${id}`,
size: 100,
created_at: new Date().toISOString(),
tags: ['models'],
tags,
preview_url: `http://test.com/${id}`
})
@@ -614,4 +640,81 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
expect(loadingStates).toContain(false)
})
})
describe('invalidateModelsForCategory', () => {
it('should invalidate model cache for all node types providing a category', async () => {
const store = useAssetsStore()
const nodeType = 'CheckpointLoaderSimple'
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
createMockAsset('existing-1'),
createMockAsset('existing-2')
])
await store.updateModelsForNodeType(nodeType)
expect(store.getAssets(nodeType)).toHaveLength(2)
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
createMockAsset('existing-1')
])
await store.invalidateModelsForCategory('checkpoints')
expect(store.getAssets(nodeType)).toHaveLength(1)
})
it('should invalidate multiple node types for same category', async () => {
const store = useAssetsStore()
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue([
createMockAsset('asset-1')
])
await store.updateModelsForNodeType('CheckpointLoaderSimple')
await store.updateModelsForNodeType('ImageOnlyCheckpointLoader')
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue([])
await store.invalidateModelsForCategory('checkpoints')
expect(store.getAssets('CheckpointLoaderSimple')).toHaveLength(0)
expect(store.getAssets('ImageOnlyCheckpointLoader')).toHaveLength(0)
})
it('should also invalidate tag-based caches', async () => {
const store = useAssetsStore()
vi.mocked(assetService.getAssetsByTag).mockResolvedValueOnce([
createMockAsset('tag-asset-1')
])
await store.updateModelsForTag('checkpoints')
expect(store.getAssets('tag:checkpoints')).toHaveLength(1)
vi.mocked(assetService.getAssetsByTag).mockResolvedValueOnce([])
await store.invalidateModelsForCategory('checkpoints')
expect(store.getAssets('tag:checkpoints')).toHaveLength(0)
})
})
describe('removeAssetFromCache', () => {
it('should remove specific asset from node type cache', async () => {
const store = useAssetsStore()
const nodeType = 'LoraLoader'
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
createMockAsset('lora-1'),
createMockAsset('lora-2'),
createMockAsset('lora-3')
])
await store.updateModelsForNodeType(nodeType)
expect(store.getAssets(nodeType)).toHaveLength(3)
store.removeAssetFromCache('lora-2', 'loras')
expect(store.getAssets(nodeType)).toHaveLength(2)
expect(store.getAssets(nodeType).map((a) => a.id)).not.toContain('lora-2')
})
})
})

View File

@@ -546,6 +546,54 @@ export const useAssetsStore = defineStore('assets', () => {
}
}
/**
* Invalidate model caches for a given category (e.g., 'checkpoints', 'loras')
* Refreshes all node types that provide this category plus tag-based caches
* @param category The model category to invalidate (e.g., 'checkpoints')
*/
async function invalidateModelsForCategory(
category: string
): Promise<void> {
const providers = modelToNodeStore
.getAllNodeProviders(category)
.filter((provider) => provider.nodeDef?.name)
const nodeTypeUpdates = providers.map((provider) =>
updateModelsForNodeType(provider.nodeDef.name)
)
const tagUpdates = [
updateModelsForTag(category),
updateModelsForTag('models')
]
await Promise.allSettled([...nodeTypeUpdates, ...tagUpdates])
}
/**
* Remove a specific asset from all caches for a given category
* Used for optimistic updates after asset deletion
* @param assetId The asset ID to remove
* @param category The model category (e.g., 'loras')
*/
function removeAssetFromCache(assetId: string, category: string): void {
const providers = modelToNodeStore.getAllNodeProviders(category)
for (const provider of providers) {
const nodeType = provider.nodeDef?.name
if (!nodeType) continue
const state = modelStateByKey.value.get(nodeType)
if (!state) continue
state.assets.delete(assetId)
assetsArrayCache.delete(nodeType)
}
assetsArrayCache.delete(`tag:${category}`)
assetsArrayCache.delete('tag:models')
}
return {
getAssets,
isLoading,
@@ -555,7 +603,9 @@ export const useAssetsStore = defineStore('assets', () => {
updateModelsForNodeType,
updateModelsForTag,
updateAssetMetadata,
updateAssetTags
updateAssetTags,
invalidateModelsForCategory,
removeAssetFromCache
}
}
@@ -569,7 +619,9 @@ export const useAssetsStore = defineStore('assets', () => {
updateModelsForNodeType: async () => {},
updateModelsForTag: async () => {},
updateAssetMetadata: async () => {},
updateAssetTags: async () => {}
updateAssetTags: async () => {},
invalidateModelsForCategory: async () => {},
removeAssetFromCache: () => {}
}
}
@@ -582,7 +634,9 @@ export const useAssetsStore = defineStore('assets', () => {
updateModelsForNodeType,
updateModelsForTag,
updateAssetMetadata,
updateAssetTags
updateAssetTags,
invalidateModelsForCategory,
removeAssetFromCache
} = getModelState()
// Watch for completed downloads and refresh model caches
@@ -658,6 +712,8 @@ export const useAssetsStore = defineStore('assets', () => {
updateModelsForNodeType,
updateModelsForTag,
updateAssetMetadata,
updateAssetTags
updateAssetTags,
invalidateModelsForCategory,
removeAssetFromCache
}
})