From 113bd5e5d3e1d363be8574a0c307b618f57b07f1 Mon Sep 17 00:00:00 2001 From: Arjan Singh <1598641+arjansingh@users.noreply.github.com> Date: Wed, 10 Sep 2025 22:26:07 -0700 Subject: [PATCH] [feat] carve out path to call asset browser in combo widget (#5464) * [ci] ignore local browser tests files this is where i have claude put its one off playwright scripts * [feat] carve out path to call asset browser in combo widget * [feat] use buttons on Model Loaders when Asset API setting is on --- .gitignore | 1 + src/locales/en/main.json | 3 + .../widgets/composables/useComboWidget.ts | 50 ++++- src/services/assetService.ts | 28 ++- src/services/litegraphService.ts | 13 +- src/stores/modelToNodeStore.ts | 19 +- .../composables/useComboWidget.test.ts | 196 ++++++++++++++++-- tests-ui/tests/services/assetService.test.ts | 53 +++++ tests-ui/tests/store/modelToNodeStore.test.ts | 116 +++++++---- 9 files changed, 423 insertions(+), 56 deletions(-) diff --git a/.gitignore b/.gitignore index db0b8454c..100bcd13e 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,7 @@ tests-ui/workflows/examples /blob-report/ /playwright/.cache/ browser_tests/**/*-win32.png +browser-tests/local/ .env diff --git a/src/locales/en/main.json b/src/locales/en/main.json index 0c9ebeffe..8a663489e 100644 --- a/src/locales/en/main.json +++ b/src/locales/en/main.json @@ -1780,6 +1780,9 @@ "copiedTooltip": "Copied", "copyTooltip": "Copy message to clipboard" }, + "widgets": { + "selectModel": "Select model" + }, "nodeHelpPage": { "inputs": "Inputs", "outputs": "Outputs", diff --git a/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts b/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts index ad973325d..586cfb27e 100644 --- a/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts +++ b/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts @@ -1,8 +1,12 @@ import { ref } from 'vue' import MultiSelectWidget from '@/components/graph/widgets/MultiSelectWidget.vue' +import { t } from '@/i18n' import type { LGraphNode } from '@/lib/litegraph/src/litegraph' -import type { IComboWidget } from '@/lib/litegraph/src/types/widgets' +import type { + IBaseWidget, + IComboWidget +} from '@/lib/litegraph/src/types/widgets' import { transformInputSpecV2ToV1 } from '@/schemas/nodeDef/migration' import { ComboInputSpec, @@ -18,6 +22,8 @@ import { type ComfyWidgetConstructorV2, addValueControlWidgets } from '@/scripts/widgets' +import { assetService } from '@/services/assetService' +import { useSettingStore } from '@/stores/settingStore' import { useRemoteWidget } from './useRemoteWidget' @@ -28,7 +34,10 @@ const getDefaultValue = (inputSpec: ComboInputSpec) => { return undefined } -const addMultiSelectWidget = (node: LGraphNode, inputSpec: ComboInputSpec) => { +const addMultiSelectWidget = ( + node: LGraphNode, + inputSpec: ComboInputSpec +): IBaseWidget => { const widgetValue = ref([]) const widget = new ComponentWidgetImpl({ node, @@ -48,7 +57,36 @@ const addMultiSelectWidget = (node: LGraphNode, inputSpec: ComboInputSpec) => { return widget } -const addComboWidget = (node: LGraphNode, inputSpec: ComboInputSpec) => { +const addComboWidget = ( + node: LGraphNode, + inputSpec: ComboInputSpec +): IBaseWidget => { + const settingStore = useSettingStore() + const isUsingAssetAPI = settingStore.get('Comfy.Assets.UseAssetAPI') + const isEligible = assetService.isAssetBrowserEligible( + inputSpec.name, + node.comfyClass || '' + ) + + if (isUsingAssetAPI && isEligible) { + // Create button widget for Asset Browser + const currentValue = getDefaultValue(inputSpec) + + const widget = node.addWidget( + 'button', + inputSpec.name, + t('widgets.selectModel'), + () => { + console.log( + `Asset Browser would open here for:\nNode: ${node.type}\nWidget: ${inputSpec.name}\nCurrent Value:${currentValue}` + ) + } + ) + + return widget + } + + // Create normal combo widget const defaultValue = getDefaultValue(inputSpec) const comboOptions = inputSpec.options ?? [] const widget = node.addWidget( @@ -59,14 +97,14 @@ const addComboWidget = (node: LGraphNode, inputSpec: ComboInputSpec) => { { values: comboOptions } - ) as IComboWidget + ) if (inputSpec.remote) { const remoteWidget = useRemoteWidget({ remoteConfig: inputSpec.remote, defaultValue, node, - widget + widget: widget as IComboWidget }) if (inputSpec.remote.refresh_button) remoteWidget.addRefreshButton() @@ -84,7 +122,7 @@ const addComboWidget = (node: LGraphNode, inputSpec: ComboInputSpec) => { if (inputSpec.control_after_generate) { widget.linkedWidgets = addValueControlWidgets( node, - widget, + widget as IComboWidget, undefined, undefined, transformInputSpecV2ToV1(inputSpec) diff --git a/src/services/assetService.ts b/src/services/assetService.ts index 144beb231..ffc825260 100644 --- a/src/services/assetService.ts +++ b/src/services/assetService.ts @@ -7,11 +7,17 @@ import { assetResponseSchema } from '@/schemas/assetSchema' import { api } from '@/scripts/api' +import { useModelToNodeStore } from '@/stores/modelToNodeStore' const ASSETS_ENDPOINT = '/assets' const MODELS_TAG = 'models' const MISSING_TAG = 'missing' +/** + * Input names that are eligible for asset browser + */ +const WHITELISTED_INPUTS = new Set(['ckpt_name', 'lora_name', 'vae_name']) + /** * Validates asset response data using Zod schema */ @@ -102,9 +108,29 @@ function createAssetService() { ) } + /** + * Checks if a widget input should use the asset browser based on both input name and node comfyClass + * + * @param inputName - The input name (e.g., 'ckpt_name', 'lora_name') + * @param nodeType - The ComfyUI node comfyClass (e.g., 'CheckpointLoaderSimple', 'LoraLoader') + * @returns true if this input should use asset browser + */ + function isAssetBrowserEligible( + inputName: string, + nodeType: string + ): boolean { + return ( + // Must be an approved input name + WHITELISTED_INPUTS.has(inputName) && + // Must be a registered node type + useModelToNodeStore().getRegisteredNodeTypes().has(nodeType) + ) + } + return { getAssetModelFolders, - getAssetModels + getAssetModels, + isAssetBrowserEligible } } diff --git a/src/services/litegraphService.ts b/src/services/litegraphService.ts index f43d3d8b9..89e33965d 100644 --- a/src/services/litegraphService.ts +++ b/src/services/litegraphService.ts @@ -484,7 +484,18 @@ export const useLitegraphService = () => { ) ?? {} if (widget) { - widget.label = st(nameKey, widget.label ?? inputName) + // Check if this is an Asset Browser button widget + const isAssetBrowserButton = + widget.type === 'button' && widget.value === 'Select model' + + if (isAssetBrowserButton) { + // Preserve Asset Browser button label (don't translate) + widget.label = String(widget.value) + } else { + // Apply normal translation for other widgets + widget.label = st(nameKey, widget.label ?? inputName) + } + widget.options ??= {} Object.assign(widget.options, { advanced: inputSpec.advanced, diff --git a/src/stores/modelToNodeStore.ts b/src/stores/modelToNodeStore.ts index 7be3a5ca3..fc78dc1e5 100644 --- a/src/stores/modelToNodeStore.ts +++ b/src/stores/modelToNodeStore.ts @@ -1,5 +1,5 @@ import { defineStore } from 'pinia' -import { ref } from 'vue' +import { computed, ref } from 'vue' import { ComfyNodeDefImpl, useNodeDefStore } from '@/stores/nodeDefStore' @@ -22,6 +22,22 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { const modelToNodeMap = ref>({}) const nodeDefStore = useNodeDefStore() const haveDefaultsLoaded = ref(false) + + /** Internal computed for reactive caching of registered node types */ + const registeredNodeTypes = computed(() => { + return new Set( + Object.values(modelToNodeMap.value) + .flat() + .map((provider) => provider.nodeDef.name) + ) + }) + + /** Get set of all registered node types for efficient lookup */ + function getRegisteredNodeTypes(): Set { + registerDefaults() + return registeredNodeTypes.value + } + /** * Get the node provider for the given model type name. * @param modelType The name of the model type to get the node provider for. @@ -91,6 +107,7 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { return { modelToNodeMap, + getRegisteredNodeTypes, getNodeProvider, getAllNodeProviders, registerNodeProvider, diff --git a/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts b/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts index 15936d1a4..ce3acffc5 100644 --- a/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts +++ b/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts @@ -1,39 +1,211 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' +import { LGraphNode } from '@/lib/litegraph/src/litegraph' +import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets' import { useComboWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useComboWidget' import type { InputSpec } from '@/schemas/nodeDef/nodeDefSchemaV2' +import { assetService } from '@/services/assetService' vi.mock('@/scripts/widgets', () => ({ addValueControlWidgets: vi.fn() })) +const mockSettingStoreGet = vi.fn(() => false) +vi.mock('@/stores/settingStore', () => ({ + useSettingStore: vi.fn(() => ({ + get: mockSettingStoreGet + })) +})) + +vi.mock('@/i18n', () => ({ + t: vi.fn((key: string) => + key === 'widgets.selectModel' ? 'Select model' : key + ) +})) + +vi.mock('@/services/assetService', () => ({ + assetService: { + isAssetBrowserEligible: vi.fn(() => false) + } +})) + +// Test factory functions +function createMockWidget(overrides: Partial = {}): IBaseWidget { + return { + type: 'combo', + options: {}, + name: 'testWidget', + value: undefined, + ...overrides + } as IBaseWidget +} + +function createMockNode(comfyClass = 'TestNode'): LGraphNode { + const node = new LGraphNode('TestNode') + node.comfyClass = comfyClass + + // Spy on the addWidget method + vi.spyOn(node, 'addWidget').mockReturnValue(createMockWidget()) + + return node +} + +function createMockInputSpec(overrides: Partial = {}): InputSpec { + return { + type: 'COMBO', + name: 'testInput', + ...overrides + } as InputSpec +} + describe('useComboWidget', () => { beforeEach(() => { vi.clearAllMocks() + // Reset to defaults + mockSettingStoreGet.mockReturnValue(false) + vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(false) }) it('should handle undefined spec', () => { const constructor = useComboWidget() - const mockNode = { - addWidget: vi.fn().mockReturnValue({ options: {} } as any) - } + const mockWidget = createMockWidget() + const mockNode = createMockNode() + vi.mocked(mockNode.addWidget).mockReturnValue(mockWidget) + const inputSpec = createMockInputSpec({ name: 'inputName' }) - const inputSpec: InputSpec = { - type: 'COMBO', - name: 'inputName' - } - - const widget = constructor(mockNode as any, inputSpec) + const widget = constructor(mockNode, inputSpec) expect(mockNode.addWidget).toHaveBeenCalledWith( 'combo', 'inputName', - undefined, // default value - expect.any(Function), // callback + undefined, + expect.any(Function), expect.objectContaining({ values: [] }) ) - expect(widget).toEqual({ options: {} }) + expect(widget).toBe(mockWidget) + }) + + it('should create normal combo widget when asset API is disabled', () => { + mockSettingStoreGet.mockReturnValue(false) + vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(true) + + const constructor = useComboWidget() + const mockWidget = createMockWidget() + const mockNode = createMockNode('CheckpointLoaderSimple') + vi.mocked(mockNode.addWidget).mockReturnValue(mockWidget) + const inputSpec = createMockInputSpec({ + name: 'ckpt_name', + options: ['model1.safetensors', 'model2.safetensors'] + }) + + const widget = constructor(mockNode, inputSpec) + + expect(mockNode.addWidget).toHaveBeenCalledWith( + 'combo', + 'ckpt_name', + 'model1.safetensors', + expect.any(Function), + { values: ['model1.safetensors', 'model2.safetensors'] } + ) + expect(mockSettingStoreGet).toHaveBeenCalledWith('Comfy.Assets.UseAssetAPI') + expect(widget).toBe(mockWidget) + }) + + it('should create normal combo widget when widget is not eligible for asset browser', () => { + mockSettingStoreGet.mockReturnValue(true) + vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(false) + + const constructor = useComboWidget() + const mockWidget = createMockWidget() + const mockNode = createMockNode() + vi.mocked(mockNode.addWidget).mockReturnValue(mockWidget) + const inputSpec = createMockInputSpec({ + name: 'not_eligible_widget', + options: ['option1', 'option2'] + }) + + const widget = constructor(mockNode, inputSpec) + + expect(mockNode.addWidget).toHaveBeenCalledWith( + 'combo', + 'not_eligible_widget', + 'option1', + expect.any(Function), + { values: ['option1', 'option2'] } + ) + expect(vi.mocked(assetService.isAssetBrowserEligible)).toHaveBeenCalledWith( + 'not_eligible_widget', + 'TestNode' + ) + expect(widget).toBe(mockWidget) + }) + + it('should create asset browser button widget when API enabled and widget eligible', () => { + mockSettingStoreGet.mockReturnValue(true) + vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(true) + + const constructor = useComboWidget() + const mockWidget = createMockWidget({ + type: 'button', + name: 'ckpt_name', + value: 'Select model' + }) + const mockNode = createMockNode('CheckpointLoaderSimple') + vi.mocked(mockNode.addWidget).mockReturnValue(mockWidget) + const inputSpec = createMockInputSpec({ + name: 'ckpt_name', + options: ['model1.safetensors', 'model2.safetensors'] + }) + + const widget = constructor(mockNode, inputSpec) + + expect(mockNode.addWidget).toHaveBeenCalledWith( + 'button', + 'ckpt_name', + 'Select model', + expect.any(Function) + ) + expect(mockSettingStoreGet).toHaveBeenCalledWith('Comfy.Assets.UseAssetAPI') + expect(vi.mocked(assetService.isAssetBrowserEligible)).toHaveBeenCalledWith( + 'ckpt_name', + 'CheckpointLoaderSimple' + ) + expect(widget).toBe(mockWidget) + }) + + it('should use asset browser button even when inputSpec has a default value but no options', () => { + mockSettingStoreGet.mockReturnValue(true) + vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(true) + + const constructor = useComboWidget() + const mockWidget = createMockWidget({ + type: 'button', + name: 'ckpt_name', + value: 'Select model' + }) + const mockNode = createMockNode('CheckpointLoaderSimple') + vi.mocked(mockNode.addWidget).mockReturnValue(mockWidget) + const inputSpec = createMockInputSpec({ + name: 'ckpt_name', + default: 'fallback.safetensors' + // Note: no options array provided + }) + + const widget = constructor(mockNode, inputSpec) + + expect(mockNode.addWidget).toHaveBeenCalledWith( + 'button', + 'ckpt_name', + 'Select model', + expect.any(Function) + ) + expect(mockSettingStoreGet).toHaveBeenCalledWith('Comfy.Assets.UseAssetAPI') + expect(vi.mocked(assetService.isAssetBrowserEligible)).toHaveBeenCalledWith( + 'ckpt_name', + 'CheckpointLoaderSimple' + ) + expect(widget).toBe(mockWidget) }) }) diff --git a/tests-ui/tests/services/assetService.test.ts b/tests-ui/tests/services/assetService.test.ts index 29ffee3b0..6a3981acb 100644 --- a/tests-ui/tests/services/assetService.test.ts +++ b/tests-ui/tests/services/assetService.test.ts @@ -3,6 +3,20 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { api } from '@/scripts/api' import { assetService } from '@/services/assetService' +vi.mock('@/stores/modelToNodeStore', () => ({ + useModelToNodeStore: vi.fn(() => ({ + getRegisteredNodeTypes: vi.fn( + () => + new Set([ + 'CheckpointLoaderSimple', + 'LoraLoader', + 'VAELoader', + 'TestNode' + ]) + ) + })) +})) + // Test data constants const MOCK_ASSETS = { checkpoints: { @@ -147,4 +161,43 @@ describe('assetService', () => { ) }) }) + + describe('isAssetBrowserEligible', () => { + it('should return true for eligible widget names with registered node types', () => { + expect( + assetService.isAssetBrowserEligible( + 'ckpt_name', + 'CheckpointLoaderSimple' + ) + ).toBe(true) + expect( + assetService.isAssetBrowserEligible('lora_name', 'LoraLoader') + ).toBe(true) + expect(assetService.isAssetBrowserEligible('vae_name', 'VAELoader')).toBe( + true + ) + }) + + it('should return false for non-eligible widget names', () => { + expect(assetService.isAssetBrowserEligible('seed', 'TestNode')).toBe( + false + ) + expect(assetService.isAssetBrowserEligible('steps', 'TestNode')).toBe( + false + ) + expect( + assetService.isAssetBrowserEligible('sampler_name', 'TestNode') + ).toBe(false) + expect(assetService.isAssetBrowserEligible('', 'TestNode')).toBe(false) + }) + + it('should return false for eligible widget names with unregistered node types', () => { + expect( + assetService.isAssetBrowserEligible('ckpt_name', 'UnknownNode') + ).toBe(false) + expect( + assetService.isAssetBrowserEligible('lora_name', 'UnknownNode') + ).toBe(false) + }) + }) }) diff --git a/tests-ui/tests/store/modelToNodeStore.test.ts b/tests-ui/tests/store/modelToNodeStore.test.ts index dc08e10fc..179c76b9e 100644 --- a/tests-ui/tests/store/modelToNodeStore.test.ts +++ b/tests-ui/tests/store/modelToNodeStore.test.ts @@ -21,45 +21,44 @@ const EXPECTED_DEFAULT_TYPES = [ type NodeDefStoreType = typeof import('@/stores/nodeDefStore') +// Create minimal but valid ComfyNodeDefImpl for testing +function createMockNodeDef(name: string): ComfyNodeDefImpl { + const def: ComfyNodeDefV1 = { + name, + display_name: name, + category: 'test', + python_module: 'nodes', + description: '', + input: { required: {}, optional: {} }, + output: [], + output_name: [], + output_is_list: [], + output_node: false + } + return new ComfyNodeDefImpl(def) +} + +const MOCK_NODE_NAMES = [ + 'CheckpointLoaderSimple', + 'ImageOnlyCheckpointLoader', + 'LoraLoader', + 'LoraLoaderModelOnly', + 'VAELoader', + 'ControlNetLoader', + 'UNETLoader', + 'UpscaleModelLoader', + 'StyleModelLoader', + 'GLIGENLoader' +] as const + +const mockNodeDefsByName = Object.fromEntries( + MOCK_NODE_NAMES.map((name) => [name, createMockNodeDef(name)]) +) + // Mock nodeDefStore dependency - modelToNodeStore relies on this for registration // Most tests expect this to be populated; tests that need empty state can override vi.mock('@/stores/nodeDefStore', async (importOriginal) => { const original = await importOriginal() - const { ComfyNodeDefImpl } = original - - // Create minimal but valid ComfyNodeDefImpl for testing - function createMockNodeDef(name: string): ComfyNodeDefImpl { - const def: ComfyNodeDefV1 = { - name, - display_name: name, - category: 'test', - python_module: 'nodes', - description: '', - input: { required: {}, optional: {} }, - output: [], - output_name: [], - output_is_list: [], - output_node: false - } - return new ComfyNodeDefImpl(def) - } - - const MOCK_NODE_NAMES = [ - 'CheckpointLoaderSimple', - 'ImageOnlyCheckpointLoader', - 'LoraLoader', - 'LoraLoaderModelOnly', - 'VAELoader', - 'ControlNetLoader', - 'UNETLoader', - 'UpscaleModelLoader', - 'StyleModelLoader', - 'GLIGENLoader' - ] as const - - const mockNodeDefsByName = Object.fromEntries( - MOCK_NODE_NAMES.map((name) => [name, createMockNodeDef(name)]) - ) return { ...original, @@ -72,6 +71,7 @@ vi.mock('@/stores/nodeDefStore', async (importOriginal) => { describe('useModelToNodeStore', () => { beforeEach(() => { setActivePinia(createPinia()) + vi.clearAllMocks() }) describe('modelToNodeMap', () => { @@ -288,12 +288,58 @@ describe('useModelToNodeStore', () => { }) it('should not register when nodeDefStore is empty', () => { + // Create fresh Pinia for this test to avoid state persistence + setActivePinia(createPinia()) + vi.mocked(useNodeDefStore, { partial: true }).mockReturnValue({ nodeDefsByName: {} }) const modelToNodeStore = useModelToNodeStore() modelToNodeStore.registerDefaults() expect(modelToNodeStore.getNodeProvider('checkpoints')).toBeUndefined() + + // Restore original mock for subsequent tests + vi.mocked(useNodeDefStore, { partial: true }).mockReturnValue({ + nodeDefsByName: mockNodeDefsByName + }) + }) + }) + + describe('getRegisteredNodeTypes', () => { + it('should return a Set instance', () => { + const modelToNodeStore = useModelToNodeStore() + const result = modelToNodeStore.getRegisteredNodeTypes() + expect(result).toBeInstanceOf(Set) + }) + + it('should return empty set when nodeDefStore is empty', () => { + // Create fresh Pinia for this test to avoid state persistence + setActivePinia(createPinia()) + + vi.mocked(useNodeDefStore, { partial: true }).mockReturnValue({ + nodeDefsByName: {} + }) + const modelToNodeStore = useModelToNodeStore() + + const result = modelToNodeStore.getRegisteredNodeTypes() + expect(result.size).toBe(0) + + // Restore original mock for subsequent tests + vi.mocked(useNodeDefStore, { partial: true }).mockReturnValue({ + nodeDefsByName: mockNodeDefsByName + }) + }) + + it('should contain node types for efficient Set.has() lookups', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + const result = modelToNodeStore.getRegisteredNodeTypes() + + // Test Set.has() functionality which assetService depends on + expect(result.has('CheckpointLoaderSimple')).toBe(true) + expect(result.has('LoraLoader')).toBe(true) + expect(result.has('NonExistentNode')).toBe(false) }) })