mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-02-08 09:00:05 +00:00
[feat] carve out path to call asset browser in combo widget (#5464)
* [ci] ignore local browser tests files this is where i have claude put its one off playwright scripts * [feat] carve out path to call asset browser in combo widget * [feat] use buttons on Model Loaders when Asset API setting is on
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -51,6 +51,7 @@ tests-ui/workflows/examples
|
||||
/blob-report/
|
||||
/playwright/.cache/
|
||||
browser_tests/**/*-win32.png
|
||||
browser-tests/local/
|
||||
|
||||
.env
|
||||
|
||||
|
||||
@@ -1780,6 +1780,9 @@
|
||||
"copiedTooltip": "Copied",
|
||||
"copyTooltip": "Copy message to clipboard"
|
||||
},
|
||||
"widgets": {
|
||||
"selectModel": "Select model"
|
||||
},
|
||||
"nodeHelpPage": {
|
||||
"inputs": "Inputs",
|
||||
"outputs": "Outputs",
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import { ref } from 'vue'
|
||||
|
||||
import MultiSelectWidget from '@/components/graph/widgets/MultiSelectWidget.vue'
|
||||
import { t } from '@/i18n'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { IComboWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
import type {
|
||||
IBaseWidget,
|
||||
IComboWidget
|
||||
} from '@/lib/litegraph/src/types/widgets'
|
||||
import { transformInputSpecV2ToV1 } from '@/schemas/nodeDef/migration'
|
||||
import {
|
||||
ComboInputSpec,
|
||||
@@ -18,6 +22,8 @@ import {
|
||||
type ComfyWidgetConstructorV2,
|
||||
addValueControlWidgets
|
||||
} from '@/scripts/widgets'
|
||||
import { assetService } from '@/services/assetService'
|
||||
import { useSettingStore } from '@/stores/settingStore'
|
||||
|
||||
import { useRemoteWidget } from './useRemoteWidget'
|
||||
|
||||
@@ -28,7 +34,10 @@ const getDefaultValue = (inputSpec: ComboInputSpec) => {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const addMultiSelectWidget = (node: LGraphNode, inputSpec: ComboInputSpec) => {
|
||||
const addMultiSelectWidget = (
|
||||
node: LGraphNode,
|
||||
inputSpec: ComboInputSpec
|
||||
): IBaseWidget => {
|
||||
const widgetValue = ref<string[]>([])
|
||||
const widget = new ComponentWidgetImpl({
|
||||
node,
|
||||
@@ -48,7 +57,36 @@ const addMultiSelectWidget = (node: LGraphNode, inputSpec: ComboInputSpec) => {
|
||||
return widget
|
||||
}
|
||||
|
||||
const addComboWidget = (node: LGraphNode, inputSpec: ComboInputSpec) => {
|
||||
const addComboWidget = (
|
||||
node: LGraphNode,
|
||||
inputSpec: ComboInputSpec
|
||||
): IBaseWidget => {
|
||||
const settingStore = useSettingStore()
|
||||
const isUsingAssetAPI = settingStore.get('Comfy.Assets.UseAssetAPI')
|
||||
const isEligible = assetService.isAssetBrowserEligible(
|
||||
inputSpec.name,
|
||||
node.comfyClass || ''
|
||||
)
|
||||
|
||||
if (isUsingAssetAPI && isEligible) {
|
||||
// Create button widget for Asset Browser
|
||||
const currentValue = getDefaultValue(inputSpec)
|
||||
|
||||
const widget = node.addWidget(
|
||||
'button',
|
||||
inputSpec.name,
|
||||
t('widgets.selectModel'),
|
||||
() => {
|
||||
console.log(
|
||||
`Asset Browser would open here for:\nNode: ${node.type}\nWidget: ${inputSpec.name}\nCurrent Value:${currentValue}`
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
return widget
|
||||
}
|
||||
|
||||
// Create normal combo widget
|
||||
const defaultValue = getDefaultValue(inputSpec)
|
||||
const comboOptions = inputSpec.options ?? []
|
||||
const widget = node.addWidget(
|
||||
@@ -59,14 +97,14 @@ const addComboWidget = (node: LGraphNode, inputSpec: ComboInputSpec) => {
|
||||
{
|
||||
values: comboOptions
|
||||
}
|
||||
) as IComboWidget
|
||||
)
|
||||
|
||||
if (inputSpec.remote) {
|
||||
const remoteWidget = useRemoteWidget({
|
||||
remoteConfig: inputSpec.remote,
|
||||
defaultValue,
|
||||
node,
|
||||
widget
|
||||
widget: widget as IComboWidget
|
||||
})
|
||||
if (inputSpec.remote.refresh_button) remoteWidget.addRefreshButton()
|
||||
|
||||
@@ -84,7 +122,7 @@ const addComboWidget = (node: LGraphNode, inputSpec: ComboInputSpec) => {
|
||||
if (inputSpec.control_after_generate) {
|
||||
widget.linkedWidgets = addValueControlWidgets(
|
||||
node,
|
||||
widget,
|
||||
widget as IComboWidget,
|
||||
undefined,
|
||||
undefined,
|
||||
transformInputSpecV2ToV1(inputSpec)
|
||||
|
||||
@@ -7,11 +7,17 @@ import {
|
||||
assetResponseSchema
|
||||
} from '@/schemas/assetSchema'
|
||||
import { api } from '@/scripts/api'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
|
||||
const ASSETS_ENDPOINT = '/assets'
|
||||
const MODELS_TAG = 'models'
|
||||
const MISSING_TAG = 'missing'
|
||||
|
||||
/**
|
||||
* Input names that are eligible for asset browser
|
||||
*/
|
||||
const WHITELISTED_INPUTS = new Set(['ckpt_name', 'lora_name', 'vae_name'])
|
||||
|
||||
/**
|
||||
* Validates asset response data using Zod schema
|
||||
*/
|
||||
@@ -102,9 +108,29 @@ function createAssetService() {
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a widget input should use the asset browser based on both input name and node comfyClass
|
||||
*
|
||||
* @param inputName - The input name (e.g., 'ckpt_name', 'lora_name')
|
||||
* @param nodeType - The ComfyUI node comfyClass (e.g., 'CheckpointLoaderSimple', 'LoraLoader')
|
||||
* @returns true if this input should use asset browser
|
||||
*/
|
||||
function isAssetBrowserEligible(
|
||||
inputName: string,
|
||||
nodeType: string
|
||||
): boolean {
|
||||
return (
|
||||
// Must be an approved input name
|
||||
WHITELISTED_INPUTS.has(inputName) &&
|
||||
// Must be a registered node type
|
||||
useModelToNodeStore().getRegisteredNodeTypes().has(nodeType)
|
||||
)
|
||||
}
|
||||
|
||||
return {
|
||||
getAssetModelFolders,
|
||||
getAssetModels
|
||||
getAssetModels,
|
||||
isAssetBrowserEligible
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -484,7 +484,18 @@ export const useLitegraphService = () => {
|
||||
) ?? {}
|
||||
|
||||
if (widget) {
|
||||
widget.label = st(nameKey, widget.label ?? inputName)
|
||||
// Check if this is an Asset Browser button widget
|
||||
const isAssetBrowserButton =
|
||||
widget.type === 'button' && widget.value === 'Select model'
|
||||
|
||||
if (isAssetBrowserButton) {
|
||||
// Preserve Asset Browser button label (don't translate)
|
||||
widget.label = String(widget.value)
|
||||
} else {
|
||||
// Apply normal translation for other widgets
|
||||
widget.label = st(nameKey, widget.label ?? inputName)
|
||||
}
|
||||
|
||||
widget.options ??= {}
|
||||
Object.assign(widget.options, {
|
||||
advanced: inputSpec.advanced,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { defineStore } from 'pinia'
|
||||
import { ref } from 'vue'
|
||||
import { computed, ref } from 'vue'
|
||||
|
||||
import { ComfyNodeDefImpl, useNodeDefStore } from '@/stores/nodeDefStore'
|
||||
|
||||
@@ -22,6 +22,22 @@ export const useModelToNodeStore = defineStore('modelToNode', () => {
|
||||
const modelToNodeMap = ref<Record<string, ModelNodeProvider[]>>({})
|
||||
const nodeDefStore = useNodeDefStore()
|
||||
const haveDefaultsLoaded = ref(false)
|
||||
|
||||
/** Internal computed for reactive caching of registered node types */
|
||||
const registeredNodeTypes = computed(() => {
|
||||
return new Set(
|
||||
Object.values(modelToNodeMap.value)
|
||||
.flat()
|
||||
.map((provider) => provider.nodeDef.name)
|
||||
)
|
||||
})
|
||||
|
||||
/** Get set of all registered node types for efficient lookup */
|
||||
function getRegisteredNodeTypes(): Set<string> {
|
||||
registerDefaults()
|
||||
return registeredNodeTypes.value
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the node provider for the given model type name.
|
||||
* @param modelType The name of the model type to get the node provider for.
|
||||
@@ -91,6 +107,7 @@ export const useModelToNodeStore = defineStore('modelToNode', () => {
|
||||
|
||||
return {
|
||||
modelToNodeMap,
|
||||
getRegisteredNodeTypes,
|
||||
getNodeProvider,
|
||||
getAllNodeProviders,
|
||||
registerNodeProvider,
|
||||
|
||||
@@ -1,39 +1,211 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
import { useComboWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useComboWidget'
|
||||
import type { InputSpec } from '@/schemas/nodeDef/nodeDefSchemaV2'
|
||||
import { assetService } from '@/services/assetService'
|
||||
|
||||
vi.mock('@/scripts/widgets', () => ({
|
||||
addValueControlWidgets: vi.fn()
|
||||
}))
|
||||
|
||||
const mockSettingStoreGet = vi.fn(() => false)
|
||||
vi.mock('@/stores/settingStore', () => ({
|
||||
useSettingStore: vi.fn(() => ({
|
||||
get: mockSettingStoreGet
|
||||
}))
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n', () => ({
|
||||
t: vi.fn((key: string) =>
|
||||
key === 'widgets.selectModel' ? 'Select model' : key
|
||||
)
|
||||
}))
|
||||
|
||||
vi.mock('@/services/assetService', () => ({
|
||||
assetService: {
|
||||
isAssetBrowserEligible: vi.fn(() => false)
|
||||
}
|
||||
}))
|
||||
|
||||
// Test factory functions
|
||||
function createMockWidget(overrides: Partial<IBaseWidget> = {}): IBaseWidget {
|
||||
return {
|
||||
type: 'combo',
|
||||
options: {},
|
||||
name: 'testWidget',
|
||||
value: undefined,
|
||||
...overrides
|
||||
} as IBaseWidget
|
||||
}
|
||||
|
||||
function createMockNode(comfyClass = 'TestNode'): LGraphNode {
|
||||
const node = new LGraphNode('TestNode')
|
||||
node.comfyClass = comfyClass
|
||||
|
||||
// Spy on the addWidget method
|
||||
vi.spyOn(node, 'addWidget').mockReturnValue(createMockWidget())
|
||||
|
||||
return node
|
||||
}
|
||||
|
||||
function createMockInputSpec(overrides: Partial<InputSpec> = {}): InputSpec {
|
||||
return {
|
||||
type: 'COMBO',
|
||||
name: 'testInput',
|
||||
...overrides
|
||||
} as InputSpec
|
||||
}
|
||||
|
||||
describe('useComboWidget', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
// Reset to defaults
|
||||
mockSettingStoreGet.mockReturnValue(false)
|
||||
vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(false)
|
||||
})
|
||||
|
||||
it('should handle undefined spec', () => {
|
||||
const constructor = useComboWidget()
|
||||
const mockNode = {
|
||||
addWidget: vi.fn().mockReturnValue({ options: {} } as any)
|
||||
}
|
||||
const mockWidget = createMockWidget()
|
||||
const mockNode = createMockNode()
|
||||
vi.mocked(mockNode.addWidget).mockReturnValue(mockWidget)
|
||||
const inputSpec = createMockInputSpec({ name: 'inputName' })
|
||||
|
||||
const inputSpec: InputSpec = {
|
||||
type: 'COMBO',
|
||||
name: 'inputName'
|
||||
}
|
||||
|
||||
const widget = constructor(mockNode as any, inputSpec)
|
||||
const widget = constructor(mockNode, inputSpec)
|
||||
|
||||
expect(mockNode.addWidget).toHaveBeenCalledWith(
|
||||
'combo',
|
||||
'inputName',
|
||||
undefined, // default value
|
||||
expect.any(Function), // callback
|
||||
undefined,
|
||||
expect.any(Function),
|
||||
expect.objectContaining({
|
||||
values: []
|
||||
})
|
||||
)
|
||||
expect(widget).toEqual({ options: {} })
|
||||
expect(widget).toBe(mockWidget)
|
||||
})
|
||||
|
||||
it('should create normal combo widget when asset API is disabled', () => {
|
||||
mockSettingStoreGet.mockReturnValue(false)
|
||||
vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(true)
|
||||
|
||||
const constructor = useComboWidget()
|
||||
const mockWidget = createMockWidget()
|
||||
const mockNode = createMockNode('CheckpointLoaderSimple')
|
||||
vi.mocked(mockNode.addWidget).mockReturnValue(mockWidget)
|
||||
const inputSpec = createMockInputSpec({
|
||||
name: 'ckpt_name',
|
||||
options: ['model1.safetensors', 'model2.safetensors']
|
||||
})
|
||||
|
||||
const widget = constructor(mockNode, inputSpec)
|
||||
|
||||
expect(mockNode.addWidget).toHaveBeenCalledWith(
|
||||
'combo',
|
||||
'ckpt_name',
|
||||
'model1.safetensors',
|
||||
expect.any(Function),
|
||||
{ values: ['model1.safetensors', 'model2.safetensors'] }
|
||||
)
|
||||
expect(mockSettingStoreGet).toHaveBeenCalledWith('Comfy.Assets.UseAssetAPI')
|
||||
expect(widget).toBe(mockWidget)
|
||||
})
|
||||
|
||||
it('should create normal combo widget when widget is not eligible for asset browser', () => {
|
||||
mockSettingStoreGet.mockReturnValue(true)
|
||||
vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(false)
|
||||
|
||||
const constructor = useComboWidget()
|
||||
const mockWidget = createMockWidget()
|
||||
const mockNode = createMockNode()
|
||||
vi.mocked(mockNode.addWidget).mockReturnValue(mockWidget)
|
||||
const inputSpec = createMockInputSpec({
|
||||
name: 'not_eligible_widget',
|
||||
options: ['option1', 'option2']
|
||||
})
|
||||
|
||||
const widget = constructor(mockNode, inputSpec)
|
||||
|
||||
expect(mockNode.addWidget).toHaveBeenCalledWith(
|
||||
'combo',
|
||||
'not_eligible_widget',
|
||||
'option1',
|
||||
expect.any(Function),
|
||||
{ values: ['option1', 'option2'] }
|
||||
)
|
||||
expect(vi.mocked(assetService.isAssetBrowserEligible)).toHaveBeenCalledWith(
|
||||
'not_eligible_widget',
|
||||
'TestNode'
|
||||
)
|
||||
expect(widget).toBe(mockWidget)
|
||||
})
|
||||
|
||||
it('should create asset browser button widget when API enabled and widget eligible', () => {
|
||||
mockSettingStoreGet.mockReturnValue(true)
|
||||
vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(true)
|
||||
|
||||
const constructor = useComboWidget()
|
||||
const mockWidget = createMockWidget({
|
||||
type: 'button',
|
||||
name: 'ckpt_name',
|
||||
value: 'Select model'
|
||||
})
|
||||
const mockNode = createMockNode('CheckpointLoaderSimple')
|
||||
vi.mocked(mockNode.addWidget).mockReturnValue(mockWidget)
|
||||
const inputSpec = createMockInputSpec({
|
||||
name: 'ckpt_name',
|
||||
options: ['model1.safetensors', 'model2.safetensors']
|
||||
})
|
||||
|
||||
const widget = constructor(mockNode, inputSpec)
|
||||
|
||||
expect(mockNode.addWidget).toHaveBeenCalledWith(
|
||||
'button',
|
||||
'ckpt_name',
|
||||
'Select model',
|
||||
expect.any(Function)
|
||||
)
|
||||
expect(mockSettingStoreGet).toHaveBeenCalledWith('Comfy.Assets.UseAssetAPI')
|
||||
expect(vi.mocked(assetService.isAssetBrowserEligible)).toHaveBeenCalledWith(
|
||||
'ckpt_name',
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(widget).toBe(mockWidget)
|
||||
})
|
||||
|
||||
it('should use asset browser button even when inputSpec has a default value but no options', () => {
|
||||
mockSettingStoreGet.mockReturnValue(true)
|
||||
vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(true)
|
||||
|
||||
const constructor = useComboWidget()
|
||||
const mockWidget = createMockWidget({
|
||||
type: 'button',
|
||||
name: 'ckpt_name',
|
||||
value: 'Select model'
|
||||
})
|
||||
const mockNode = createMockNode('CheckpointLoaderSimple')
|
||||
vi.mocked(mockNode.addWidget).mockReturnValue(mockWidget)
|
||||
const inputSpec = createMockInputSpec({
|
||||
name: 'ckpt_name',
|
||||
default: 'fallback.safetensors'
|
||||
// Note: no options array provided
|
||||
})
|
||||
|
||||
const widget = constructor(mockNode, inputSpec)
|
||||
|
||||
expect(mockNode.addWidget).toHaveBeenCalledWith(
|
||||
'button',
|
||||
'ckpt_name',
|
||||
'Select model',
|
||||
expect.any(Function)
|
||||
)
|
||||
expect(mockSettingStoreGet).toHaveBeenCalledWith('Comfy.Assets.UseAssetAPI')
|
||||
expect(vi.mocked(assetService.isAssetBrowserEligible)).toHaveBeenCalledWith(
|
||||
'ckpt_name',
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(widget).toBe(mockWidget)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -3,6 +3,20 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { api } from '@/scripts/api'
|
||||
import { assetService } from '@/services/assetService'
|
||||
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
useModelToNodeStore: vi.fn(() => ({
|
||||
getRegisteredNodeTypes: vi.fn(
|
||||
() =>
|
||||
new Set([
|
||||
'CheckpointLoaderSimple',
|
||||
'LoraLoader',
|
||||
'VAELoader',
|
||||
'TestNode'
|
||||
])
|
||||
)
|
||||
}))
|
||||
}))
|
||||
|
||||
// Test data constants
|
||||
const MOCK_ASSETS = {
|
||||
checkpoints: {
|
||||
@@ -147,4 +161,43 @@ describe('assetService', () => {
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isAssetBrowserEligible', () => {
|
||||
it('should return true for eligible widget names with registered node types', () => {
|
||||
expect(
|
||||
assetService.isAssetBrowserEligible(
|
||||
'ckpt_name',
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
).toBe(true)
|
||||
expect(
|
||||
assetService.isAssetBrowserEligible('lora_name', 'LoraLoader')
|
||||
).toBe(true)
|
||||
expect(assetService.isAssetBrowserEligible('vae_name', 'VAELoader')).toBe(
|
||||
true
|
||||
)
|
||||
})
|
||||
|
||||
it('should return false for non-eligible widget names', () => {
|
||||
expect(assetService.isAssetBrowserEligible('seed', 'TestNode')).toBe(
|
||||
false
|
||||
)
|
||||
expect(assetService.isAssetBrowserEligible('steps', 'TestNode')).toBe(
|
||||
false
|
||||
)
|
||||
expect(
|
||||
assetService.isAssetBrowserEligible('sampler_name', 'TestNode')
|
||||
).toBe(false)
|
||||
expect(assetService.isAssetBrowserEligible('', 'TestNode')).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for eligible widget names with unregistered node types', () => {
|
||||
expect(
|
||||
assetService.isAssetBrowserEligible('ckpt_name', 'UnknownNode')
|
||||
).toBe(false)
|
||||
expect(
|
||||
assetService.isAssetBrowserEligible('lora_name', 'UnknownNode')
|
||||
).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -21,45 +21,44 @@ const EXPECTED_DEFAULT_TYPES = [
|
||||
|
||||
type NodeDefStoreType = typeof import('@/stores/nodeDefStore')
|
||||
|
||||
// Create minimal but valid ComfyNodeDefImpl for testing
|
||||
function createMockNodeDef(name: string): ComfyNodeDefImpl {
|
||||
const def: ComfyNodeDefV1 = {
|
||||
name,
|
||||
display_name: name,
|
||||
category: 'test',
|
||||
python_module: 'nodes',
|
||||
description: '',
|
||||
input: { required: {}, optional: {} },
|
||||
output: [],
|
||||
output_name: [],
|
||||
output_is_list: [],
|
||||
output_node: false
|
||||
}
|
||||
return new ComfyNodeDefImpl(def)
|
||||
}
|
||||
|
||||
const MOCK_NODE_NAMES = [
|
||||
'CheckpointLoaderSimple',
|
||||
'ImageOnlyCheckpointLoader',
|
||||
'LoraLoader',
|
||||
'LoraLoaderModelOnly',
|
||||
'VAELoader',
|
||||
'ControlNetLoader',
|
||||
'UNETLoader',
|
||||
'UpscaleModelLoader',
|
||||
'StyleModelLoader',
|
||||
'GLIGENLoader'
|
||||
] as const
|
||||
|
||||
const mockNodeDefsByName = Object.fromEntries(
|
||||
MOCK_NODE_NAMES.map((name) => [name, createMockNodeDef(name)])
|
||||
)
|
||||
|
||||
// Mock nodeDefStore dependency - modelToNodeStore relies on this for registration
|
||||
// Most tests expect this to be populated; tests that need empty state can override
|
||||
vi.mock('@/stores/nodeDefStore', async (importOriginal) => {
|
||||
const original = await importOriginal<NodeDefStoreType>()
|
||||
const { ComfyNodeDefImpl } = original
|
||||
|
||||
// Create minimal but valid ComfyNodeDefImpl for testing
|
||||
function createMockNodeDef(name: string): ComfyNodeDefImpl {
|
||||
const def: ComfyNodeDefV1 = {
|
||||
name,
|
||||
display_name: name,
|
||||
category: 'test',
|
||||
python_module: 'nodes',
|
||||
description: '',
|
||||
input: { required: {}, optional: {} },
|
||||
output: [],
|
||||
output_name: [],
|
||||
output_is_list: [],
|
||||
output_node: false
|
||||
}
|
||||
return new ComfyNodeDefImpl(def)
|
||||
}
|
||||
|
||||
const MOCK_NODE_NAMES = [
|
||||
'CheckpointLoaderSimple',
|
||||
'ImageOnlyCheckpointLoader',
|
||||
'LoraLoader',
|
||||
'LoraLoaderModelOnly',
|
||||
'VAELoader',
|
||||
'ControlNetLoader',
|
||||
'UNETLoader',
|
||||
'UpscaleModelLoader',
|
||||
'StyleModelLoader',
|
||||
'GLIGENLoader'
|
||||
] as const
|
||||
|
||||
const mockNodeDefsByName = Object.fromEntries(
|
||||
MOCK_NODE_NAMES.map((name) => [name, createMockNodeDef(name)])
|
||||
)
|
||||
|
||||
return {
|
||||
...original,
|
||||
@@ -72,6 +71,7 @@ vi.mock('@/stores/nodeDefStore', async (importOriginal) => {
|
||||
describe('useModelToNodeStore', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('modelToNodeMap', () => {
|
||||
@@ -288,12 +288,58 @@ describe('useModelToNodeStore', () => {
|
||||
})
|
||||
|
||||
it('should not register when nodeDefStore is empty', () => {
|
||||
// Create fresh Pinia for this test to avoid state persistence
|
||||
setActivePinia(createPinia())
|
||||
|
||||
vi.mocked(useNodeDefStore, { partial: true }).mockReturnValue({
|
||||
nodeDefsByName: {}
|
||||
})
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
modelToNodeStore.registerDefaults()
|
||||
expect(modelToNodeStore.getNodeProvider('checkpoints')).toBeUndefined()
|
||||
|
||||
// Restore original mock for subsequent tests
|
||||
vi.mocked(useNodeDefStore, { partial: true }).mockReturnValue({
|
||||
nodeDefsByName: mockNodeDefsByName
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('getRegisteredNodeTypes', () => {
|
||||
it('should return a Set instance', () => {
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
const result = modelToNodeStore.getRegisteredNodeTypes()
|
||||
expect(result).toBeInstanceOf(Set)
|
||||
})
|
||||
|
||||
it('should return empty set when nodeDefStore is empty', () => {
|
||||
// Create fresh Pinia for this test to avoid state persistence
|
||||
setActivePinia(createPinia())
|
||||
|
||||
vi.mocked(useNodeDefStore, { partial: true }).mockReturnValue({
|
||||
nodeDefsByName: {}
|
||||
})
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
|
||||
const result = modelToNodeStore.getRegisteredNodeTypes()
|
||||
expect(result.size).toBe(0)
|
||||
|
||||
// Restore original mock for subsequent tests
|
||||
vi.mocked(useNodeDefStore, { partial: true }).mockReturnValue({
|
||||
nodeDefsByName: mockNodeDefsByName
|
||||
})
|
||||
})
|
||||
|
||||
it('should contain node types for efficient Set.has() lookups', () => {
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
modelToNodeStore.registerDefaults()
|
||||
|
||||
const result = modelToNodeStore.getRegisteredNodeTypes()
|
||||
|
||||
// Test Set.has() functionality which assetService depends on
|
||||
expect(result.has('CheckpointLoaderSimple')).toBe(true)
|
||||
expect(result.has('LoraLoader')).toBe(true)
|
||||
expect(result.has('NonExistentNode')).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user