diff --git a/src/platform/assets/services/assetService.ts b/src/platform/assets/services/assetService.ts index 626a52d20..3c816fcd7 100644 --- a/src/platform/assets/services/assetService.ts +++ b/src/platform/assets/services/assetService.ts @@ -113,11 +113,16 @@ function createAssetService() { * Checks if a widget input should use the asset browser based on both input name and node comfyClass * * @param nodeType - The ComfyUI node comfyClass (e.g., 'CheckpointLoaderSimple', 'LoraLoader') + * @param widgetName - The name of the widget to check (e.g., 'ckpt_name') * @returns true if this input should use asset browser */ - function isAssetBrowserEligible(nodeType: string = ''): boolean { + function isAssetBrowserEligible( + nodeType: string | undefined, + widgetName: string + ): boolean { + if (!nodeType || !widgetName) return false return ( - !!nodeType && useModelToNodeStore().getRegisteredNodeTypes().has(nodeType) + useModelToNodeStore().getRegisteredNodeTypes()[nodeType] === widgetName ) } diff --git a/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts b/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts index dcd30fde5..24f38f502 100644 --- a/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts +++ b/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts @@ -356,7 +356,10 @@ const addComboWidget = ( ): IBaseWidget => { const settingStore = useSettingStore() const isUsingAssetAPI = settingStore.get('Comfy.Assets.UseAssetAPI') - const isEligible = assetService.isAssetBrowserEligible(node.comfyClass) + const isEligible = assetService.isAssetBrowserEligible( + node.comfyClass, + inputSpec.name + ) if (isUsingAssetAPI && isEligible) { const currentValue = getDefaultValue(inputSpec) diff --git a/src/stores/modelToNodeStore.ts b/src/stores/modelToNodeStore.ts index 436dac26d..cc9a911c8 100644 --- a/src/stores/modelToNodeStore.ts +++ b/src/stores/modelToNodeStore.ts @@ -25,12 +25,12 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { const haveDefaultsLoaded = ref(false) /** Internal computed for reactive caching of registered node types */ - const registeredNodeTypes = computed(() => { - return new Set( + const registeredNodeTypes = computed>(() => { + return Object.fromEntries( Object.values(modelToNodeMap.value) .flat() .filter((provider) => !!provider.nodeDef) - .map((provider) => provider.nodeDef.name) + .map((provider) => [provider.nodeDef.name, provider.key]) ) }) @@ -51,7 +51,7 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { }) /** Get set of all registered node types for efficient lookup */ - function getRegisteredNodeTypes(): Set { + function getRegisteredNodeTypes(): Record { registerDefaults() return registeredNodeTypes.value } diff --git a/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts b/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts index 5db6c5f00..3bb985a47 100644 --- a/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts +++ b/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts @@ -163,7 +163,8 @@ describe('useComboWidget', () => { ) expect(mockSettingStoreGet).toHaveBeenCalledWith('Comfy.Assets.UseAssetAPI') expect(vi.mocked(assetService.isAssetBrowserEligible)).toHaveBeenCalledWith( - 'CheckpointLoaderSimple' + 'CheckpointLoaderSimple', + 'ckpt_name' ) expect(widget).toBe(mockWidget) }) diff --git a/tests-ui/tests/services/assetService.test.ts b/tests-ui/tests/services/assetService.test.ts index b6386bc2b..e9b757cf0 100644 --- a/tests-ui/tests/services/assetService.test.ts +++ b/tests-ui/tests/services/assetService.test.ts @@ -16,15 +16,12 @@ const mockGetCategoryForNodeType = vi.fn() vi.mock('@/stores/modelToNodeStore', () => ({ useModelToNodeStore: vi.fn(() => ({ - getRegisteredNodeTypes: vi.fn( - () => - new Set([ - 'CheckpointLoaderSimple', - 'LoraLoader', - 'VAELoader', - 'TestNode' - ]) - ), + getRegisteredNodeTypes: vi.fn(() => ({ + CheckpointLoaderSimple: 'ckpt_name', + LoraLoader: 'lora_name', + VAELoader: 'vae_name', + TestNode: '' + })), getCategoryForNodeType: mockGetCategoryForNodeType, modelToNodeMap: { checkpoints: [{ nodeDef: { name: 'CheckpointLoaderSimple' } }], @@ -191,19 +188,19 @@ describe('assetService', () => { }) describe('isAssetBrowserEligible', () => { - it('should return true for registered node types', () => { - expect( - assetService.isAssetBrowserEligible('CheckpointLoaderSimple') - ).toBe(true) - expect(assetService.isAssetBrowserEligible('LoraLoader')).toBe(true) - expect(assetService.isAssetBrowserEligible('VAELoader')).toBe(true) - }) - - it('should return false for unregistered node types', () => { - expect(assetService.isAssetBrowserEligible('UnknownNode')).toBe(false) - expect(assetService.isAssetBrowserEligible('NotRegistered')).toBe(false) - expect(assetService.isAssetBrowserEligible('')).toBe(false) - }) + it.for<[string, string, boolean, string]>([ + ['CheckpointLoaderSimple', 'ckpt_name', true, 'valid inputs'], + ['LoraLoader', 'lora_name', true, 'valid inputs'], + ['VAELoader', 'vae_name', true, 'valid inputs'], + ['CheckpointLoaderSimple', 'type', false, 'other combo widgets'], + ['UnknownNode', 'widget', false, 'unregistered types'], + ['NotRegistered', 'widget', false, 'unregistered types'] + ])( + 'isAssetBrowserEligible("%s", "%s") should return %s for %s', + ([type, name, expected]) => { + expect(assetService.isAssetBrowserEligible(type, name)).toBe(expected) + } + ) }) describe('getAssetsForNodeType', () => { diff --git a/tests-ui/tests/store/modelToNodeStore.test.ts b/tests-ui/tests/store/modelToNodeStore.test.ts index c04f1764d..c79af19b5 100644 --- a/tests-ui/tests/store/modelToNodeStore.test.ts +++ b/tests-ui/tests/store/modelToNodeStore.test.ts @@ -288,7 +288,7 @@ describe('useModelToNodeStore', () => { // Non-existent nodes are filtered out from registered types const types = modelToNodeStore.getRegisteredNodeTypes() - expect(types.has('NonExistentLoader')).toBe(false) + expect(types['NonExistentLoader']).toBe(undefined) expect( modelToNodeStore.getCategoryForNodeType('NonExistentLoader') @@ -347,13 +347,13 @@ describe('useModelToNodeStore', () => { }) describe('getRegisteredNodeTypes', () => { - it('should return a Set instance', () => { + it('should return an object', () => { const modelToNodeStore = useModelToNodeStore() const result = modelToNodeStore.getRegisteredNodeTypes() - expect(result).toBeInstanceOf(Set) + expect(result).toBeTypeOf('object') }) - it('should return empty set when nodeDefStore is empty', () => { + it('should return empty Record when nodeDefStore is empty', () => { // Create fresh Pinia for this test to avoid state persistence setActivePinia(createPinia()) @@ -363,7 +363,7 @@ describe('useModelToNodeStore', () => { const modelToNodeStore = useModelToNodeStore() const result = modelToNodeStore.getRegisteredNodeTypes() - expect(result.size).toBe(0) + expect(result).toStrictEqual({}) // Restore original mock for subsequent tests vi.mocked(useNodeDefStore, { partial: true }).mockReturnValue({ @@ -371,16 +371,15 @@ describe('useModelToNodeStore', () => { }) }) - it('should contain node types for efficient Set.has() lookups', () => { + it('should contain node types to resolve widget name', () => { const modelToNodeStore = useModelToNodeStore() modelToNodeStore.registerDefaults() const result = modelToNodeStore.getRegisteredNodeTypes() - // Test Set.has() functionality which assetService depends on - expect(result.has('CheckpointLoaderSimple')).toBe(true) - expect(result.has('LoraLoader')).toBe(true) - expect(result.has('NonExistentNode')).toBe(false) + expect(result['CheckpointLoaderSimple']).toBe('ckpt_name') + expect(result['LoraLoader']).toBe('lora_name') + expect(result['NonExistentNode']).toBe(undefined) }) })