mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-02-22 07:44:11 +00:00
fix: invalidate loader node dropdown cache after model asset deletion
Amp-Thread-ID: https://ampcode.com/threads/T-019c0830-033a-7224-85ec-766b1765426d Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
@@ -35,9 +35,36 @@ vi.mock('vue-i18n', () => ({
|
||||
})
|
||||
}))
|
||||
|
||||
const mockShowDialog = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/dialogStore', () => ({
|
||||
useDialogStore: () => ({
|
||||
showDialog: vi.fn()
|
||||
showDialog: mockShowDialog
|
||||
})
|
||||
}))
|
||||
|
||||
const mockInvalidateModelsForCategory = vi.hoisted(() => vi.fn())
|
||||
const mockSetAssetDeleting = vi.hoisted(() => vi.fn())
|
||||
const mockUpdateHistory = vi.hoisted(() => vi.fn())
|
||||
const mockUpdateInputs = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/assetsStore', () => ({
|
||||
useAssetsStore: () => ({
|
||||
setAssetDeleting: mockSetAssetDeleting,
|
||||
updateHistory: mockUpdateHistory,
|
||||
updateInputs: mockUpdateInputs,
|
||||
invalidateModelsForCategory: mockInvalidateModelsForCategory
|
||||
})
|
||||
}))
|
||||
|
||||
const mockGetCategoryForNodeType = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
useModelToNodeStore: () => ({
|
||||
getCategoryForNodeType: mockGetCategoryForNodeType,
|
||||
getAllNodeProviders: vi.fn((category: string) => {
|
||||
if (category === 'checkpoints' || category === 'loras') {
|
||||
return [{ nodeDef: { name: 'TestLoader' }, key: 'model_name' }]
|
||||
}
|
||||
return []
|
||||
})
|
||||
})
|
||||
}))
|
||||
|
||||
@@ -93,14 +120,33 @@ vi.mock('@/utils/typeGuardUtil', () => ({
|
||||
isResultItemType: vi.fn().mockReturnValue(true)
|
||||
}))
|
||||
|
||||
const mockGetAssetType = vi.hoisted(() => vi.fn())
|
||||
vi.mock('@/platform/assets/utils/assetTypeUtil', () => ({
|
||||
getAssetType: vi.fn().mockReturnValue('input')
|
||||
getAssetType: mockGetAssetType
|
||||
}))
|
||||
|
||||
vi.mock('../schemas/assetMetadataSchema', () => ({
|
||||
getOutputAssetMetadata: vi.fn().mockReturnValue(null)
|
||||
}))
|
||||
|
||||
const mockDeleteAsset = vi.hoisted(() => vi.fn())
|
||||
vi.mock('../services/assetService', () => ({
|
||||
assetService: {
|
||||
deleteAsset: mockDeleteAsset
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
deleteItem: vi.fn(),
|
||||
apiURL: vi.fn((path: string) => `http://localhost:8188/api${path}`),
|
||||
internalURL: vi.fn((path: string) => `http://localhost:8188${path}`),
|
||||
addEventListener: vi.fn(),
|
||||
removeEventListener: vi.fn(),
|
||||
user: 'test-user'
|
||||
}
|
||||
}))
|
||||
|
||||
function createMockAsset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
return {
|
||||
id: 'test-asset-id',
|
||||
@@ -218,4 +264,80 @@ describe('useMediaAssetActions', () => {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('deleteAssets - model cache invalidation', () => {
|
||||
beforeEach(() => {
|
||||
mockIsCloud.value = true
|
||||
mockGetAssetType.mockReturnValue('input')
|
||||
mockDeleteAsset.mockResolvedValue(undefined)
|
||||
mockInvalidateModelsForCategory.mockClear()
|
||||
mockSetAssetDeleting.mockClear()
|
||||
mockUpdateHistory.mockClear()
|
||||
mockUpdateInputs.mockClear()
|
||||
})
|
||||
|
||||
it('should invalidate model cache when deleting a model asset', async () => {
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
const modelAsset = createMockAsset({
|
||||
id: 'checkpoint-1',
|
||||
name: 'model.safetensors',
|
||||
tags: ['models', 'checkpoints']
|
||||
})
|
||||
|
||||
mockShowDialog.mockImplementation(
|
||||
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
|
||||
void props.onConfirm()
|
||||
}
|
||||
)
|
||||
|
||||
await actions.deleteAssets(modelAsset)
|
||||
|
||||
expect(mockInvalidateModelsForCategory).toHaveBeenCalledWith(
|
||||
'checkpoints'
|
||||
)
|
||||
})
|
||||
|
||||
it('should invalidate multiple categories for multiple assets', async () => {
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
const assets = [
|
||||
createMockAsset({ id: '1', tags: ['models', 'checkpoints'] }),
|
||||
createMockAsset({ id: '2', tags: ['models', 'loras'] })
|
||||
]
|
||||
|
||||
mockShowDialog.mockImplementation(
|
||||
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
|
||||
void props.onConfirm()
|
||||
}
|
||||
)
|
||||
|
||||
await actions.deleteAssets(assets)
|
||||
|
||||
expect(mockInvalidateModelsForCategory).toHaveBeenCalledWith(
|
||||
'checkpoints'
|
||||
)
|
||||
expect(mockInvalidateModelsForCategory).toHaveBeenCalledWith('loras')
|
||||
})
|
||||
|
||||
it('should not invalidate model cache for non-model assets', async () => {
|
||||
const actions = useMediaAssetActions()
|
||||
|
||||
const inputAsset = createMockAsset({
|
||||
id: 'input-1',
|
||||
name: 'image.png',
|
||||
tags: ['input']
|
||||
})
|
||||
|
||||
mockShowDialog.mockImplementation(
|
||||
({ props }: { props: { onConfirm: () => Promise<void> } }) => {
|
||||
void props.onConfirm()
|
||||
}
|
||||
)
|
||||
|
||||
await actions.deleteAssets(inputAsset)
|
||||
|
||||
expect(mockInvalidateModelsForCategory).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -14,6 +14,7 @@ import { useNodeDefStore } from '@/stores/nodeDefStore'
|
||||
import { getOutputAssetMetadata } from '../schemas/assetMetadataSchema'
|
||||
import { useAssetsStore } from '@/stores/assetsStore'
|
||||
import { useDialogStore } from '@/stores/dialogStore'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
import { getAssetType } from '../utils/assetTypeUtil'
|
||||
import { getAssetUrl } from '../utils/assetUrlUtil'
|
||||
import { createAnnotatedPath } from '@/utils/createAnnotatedPath'
|
||||
@@ -586,6 +587,27 @@ export function useMediaAssetActions() {
|
||||
await assetsStore.updateInputs()
|
||||
}
|
||||
|
||||
// Invalidate model caches for affected categories
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
const modelCategories = new Set<string>()
|
||||
const excludedTags = ['models', 'input', 'output']
|
||||
|
||||
for (const asset of assetArray) {
|
||||
for (const tag of asset.tags ?? []) {
|
||||
if (excludedTags.includes(tag)) continue
|
||||
const providers = modelToNodeStore.getAllNodeProviders(tag)
|
||||
if (providers.length > 0) {
|
||||
modelCategories.add(tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
await Promise.allSettled(
|
||||
[...modelCategories].map((category) =>
|
||||
assetsStore.invalidateModelsForCategory(category)
|
||||
)
|
||||
)
|
||||
|
||||
// Show appropriate feedback based on results
|
||||
if (failed.length === 0) {
|
||||
toast.add({
|
||||
|
||||
@@ -36,6 +36,32 @@ vi.mock('@/platform/distribution/types', () => ({
|
||||
}
|
||||
}))
|
||||
|
||||
// Mock modelToNodeStore with proper node providers
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
useModelToNodeStore: () => ({
|
||||
getAllNodeProviders: vi.fn((category: string) => {
|
||||
const providers: Record<
|
||||
string,
|
||||
Array<{ nodeDef: { name: string }; key: string }>
|
||||
> = {
|
||||
checkpoints: [
|
||||
{ nodeDef: { name: 'CheckpointLoaderSimple' }, key: 'ckpt_name' },
|
||||
{ nodeDef: { name: 'ImageOnlyCheckpointLoader' }, key: 'ckpt_name' }
|
||||
],
|
||||
loras: [
|
||||
{ nodeDef: { name: 'LoraLoader' }, key: 'lora_name' },
|
||||
{ nodeDef: { name: 'LoraLoaderModelOnly' }, key: 'lora_name' }
|
||||
],
|
||||
vae: [{ nodeDef: { name: 'VAELoader' }, key: 'vae_name' }]
|
||||
}
|
||||
return providers[category] ?? []
|
||||
}),
|
||||
getNodeProvider: vi.fn(),
|
||||
getCategoryForNodeType: vi.fn(),
|
||||
registerDefaults: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
// Mock TaskItemImpl
|
||||
vi.mock('@/stores/queueStore', () => ({
|
||||
TaskItemImpl: class {
|
||||
@@ -472,12 +498,12 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
|
||||
mockIsCloud.value = false
|
||||
})
|
||||
|
||||
const createMockAsset = (id: string) => ({
|
||||
const createMockAsset = (id: string, tags: string[] = ['models']) => ({
|
||||
id,
|
||||
name: `asset-${id}`,
|
||||
size: 100,
|
||||
created_at: new Date().toISOString(),
|
||||
tags: ['models'],
|
||||
tags,
|
||||
preview_url: `http://test.com/${id}`
|
||||
})
|
||||
|
||||
@@ -614,4 +640,81 @@ describe('assetsStore - Model Assets Cache (Cloud)', () => {
|
||||
expect(loadingStates).toContain(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('invalidateModelsForCategory', () => {
|
||||
it('should invalidate model cache for all node types providing a category', async () => {
|
||||
const store = useAssetsStore()
|
||||
const nodeType = 'CheckpointLoaderSimple'
|
||||
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
|
||||
createMockAsset('existing-1'),
|
||||
createMockAsset('existing-2')
|
||||
])
|
||||
await store.updateModelsForNodeType(nodeType)
|
||||
expect(store.getAssets(nodeType)).toHaveLength(2)
|
||||
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
|
||||
createMockAsset('existing-1')
|
||||
])
|
||||
|
||||
await store.invalidateModelsForCategory('checkpoints')
|
||||
|
||||
expect(store.getAssets(nodeType)).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('should invalidate multiple node types for same category', async () => {
|
||||
const store = useAssetsStore()
|
||||
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue([
|
||||
createMockAsset('asset-1')
|
||||
])
|
||||
|
||||
await store.updateModelsForNodeType('CheckpointLoaderSimple')
|
||||
await store.updateModelsForNodeType('ImageOnlyCheckpointLoader')
|
||||
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValue([])
|
||||
|
||||
await store.invalidateModelsForCategory('checkpoints')
|
||||
|
||||
expect(store.getAssets('CheckpointLoaderSimple')).toHaveLength(0)
|
||||
expect(store.getAssets('ImageOnlyCheckpointLoader')).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('should also invalidate tag-based caches', async () => {
|
||||
const store = useAssetsStore()
|
||||
|
||||
vi.mocked(assetService.getAssetsByTag).mockResolvedValueOnce([
|
||||
createMockAsset('tag-asset-1')
|
||||
])
|
||||
|
||||
await store.updateModelsForTag('checkpoints')
|
||||
expect(store.getAssets('tag:checkpoints')).toHaveLength(1)
|
||||
|
||||
vi.mocked(assetService.getAssetsByTag).mockResolvedValueOnce([])
|
||||
|
||||
await store.invalidateModelsForCategory('checkpoints')
|
||||
|
||||
expect(store.getAssets('tag:checkpoints')).toHaveLength(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('removeAssetFromCache', () => {
|
||||
it('should remove specific asset from node type cache', async () => {
|
||||
const store = useAssetsStore()
|
||||
const nodeType = 'LoraLoader'
|
||||
|
||||
vi.mocked(assetService.getAssetsForNodeType).mockResolvedValueOnce([
|
||||
createMockAsset('lora-1'),
|
||||
createMockAsset('lora-2'),
|
||||
createMockAsset('lora-3')
|
||||
])
|
||||
await store.updateModelsForNodeType(nodeType)
|
||||
expect(store.getAssets(nodeType)).toHaveLength(3)
|
||||
|
||||
store.removeAssetFromCache('lora-2', 'loras')
|
||||
|
||||
expect(store.getAssets(nodeType)).toHaveLength(2)
|
||||
expect(store.getAssets(nodeType).map((a) => a.id)).not.toContain('lora-2')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -546,6 +546,54 @@ export const useAssetsStore = defineStore('assets', () => {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Invalidate model caches for a given category (e.g., 'checkpoints', 'loras')
|
||||
* Refreshes all node types that provide this category plus tag-based caches
|
||||
* @param category The model category to invalidate (e.g., 'checkpoints')
|
||||
*/
|
||||
async function invalidateModelsForCategory(
|
||||
category: string
|
||||
): Promise<void> {
|
||||
const providers = modelToNodeStore
|
||||
.getAllNodeProviders(category)
|
||||
.filter((provider) => provider.nodeDef?.name)
|
||||
|
||||
const nodeTypeUpdates = providers.map((provider) =>
|
||||
updateModelsForNodeType(provider.nodeDef.name)
|
||||
)
|
||||
|
||||
const tagUpdates = [
|
||||
updateModelsForTag(category),
|
||||
updateModelsForTag('models')
|
||||
]
|
||||
|
||||
await Promise.allSettled([...nodeTypeUpdates, ...tagUpdates])
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a specific asset from all caches for a given category
|
||||
* Used for optimistic updates after asset deletion
|
||||
* @param assetId The asset ID to remove
|
||||
* @param category The model category (e.g., 'loras')
|
||||
*/
|
||||
function removeAssetFromCache(assetId: string, category: string): void {
|
||||
const providers = modelToNodeStore.getAllNodeProviders(category)
|
||||
|
||||
for (const provider of providers) {
|
||||
const nodeType = provider.nodeDef?.name
|
||||
if (!nodeType) continue
|
||||
|
||||
const state = modelStateByKey.value.get(nodeType)
|
||||
if (!state) continue
|
||||
|
||||
state.assets.delete(assetId)
|
||||
assetsArrayCache.delete(nodeType)
|
||||
}
|
||||
|
||||
assetsArrayCache.delete(`tag:${category}`)
|
||||
assetsArrayCache.delete('tag:models')
|
||||
}
|
||||
|
||||
return {
|
||||
getAssets,
|
||||
isLoading,
|
||||
@@ -555,7 +603,9 @@ export const useAssetsStore = defineStore('assets', () => {
|
||||
updateModelsForNodeType,
|
||||
updateModelsForTag,
|
||||
updateAssetMetadata,
|
||||
updateAssetTags
|
||||
updateAssetTags,
|
||||
invalidateModelsForCategory,
|
||||
removeAssetFromCache
|
||||
}
|
||||
}
|
||||
|
||||
@@ -569,7 +619,9 @@ export const useAssetsStore = defineStore('assets', () => {
|
||||
updateModelsForNodeType: async () => {},
|
||||
updateModelsForTag: async () => {},
|
||||
updateAssetMetadata: async () => {},
|
||||
updateAssetTags: async () => {}
|
||||
updateAssetTags: async () => {},
|
||||
invalidateModelsForCategory: async () => {},
|
||||
removeAssetFromCache: () => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -582,7 +634,9 @@ export const useAssetsStore = defineStore('assets', () => {
|
||||
updateModelsForNodeType,
|
||||
updateModelsForTag,
|
||||
updateAssetMetadata,
|
||||
updateAssetTags
|
||||
updateAssetTags,
|
||||
invalidateModelsForCategory,
|
||||
removeAssetFromCache
|
||||
} = getModelState()
|
||||
|
||||
// Watch for completed downloads and refresh model caches
|
||||
@@ -658,6 +712,8 @@ export const useAssetsStore = defineStore('assets', () => {
|
||||
updateModelsForNodeType,
|
||||
updateModelsForTag,
|
||||
updateAssetMetadata,
|
||||
updateAssetTags
|
||||
updateAssetTags,
|
||||
invalidateModelsForCategory,
|
||||
removeAssetFromCache
|
||||
}
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user