fix: use WidgetSelectDropdown for models (#6607)

## Summary

As the commit says, the model loaders were broken in cloud if you
enabled Vue Nodes (not a thing I think user does yet).

This fixes it by configuring the `WidgetSelectDropdown` to load so the
user load models like they would load a input or output asset.

## Review Focus

Probably `useAssetWidgetData` to make sure it's idomatic.

This part of
[assetsStore](https://github.com/Comfy-Org/ComfyUI_frontend/pull/6607/files#diff-18a5914c9f12c16d9c9c3a9f6d0e203a9c00598414d3d1c8637da9ca77339d83R158-R234)
as well.

## Screenshots

<img width="1196" height="1005" alt="Screenshot 2025-11-05 at 5 34
22 PM"
src="https://github.com/user-attachments/assets/804cd3c4-3370-4667-b606-bed52fcd6278"
/>

┆Issue is synchronized with this [Notion
page](https://www.notion.so/PR-6607-fix-use-WidgetSelectDropdown-for-models-2a36d73d36508143b185d06d736e4af9)
by [Unito](https://www.unito.io)

---------

Co-authored-by: GitHub Action <action@github.com>
This commit is contained in:
Arjan Singh
2025-11-05 19:34:17 -08:00
committed by GitHub
parent 63cb271509
commit 8849d54e20
13 changed files with 986 additions and 37 deletions

View File

@@ -0,0 +1,123 @@
import { createTestingPinia } from '@pinia/testing'
import { mount } from '@vue/test-utils'
import { describe, expect, it } from 'vitest'
import type {
SafeWidgetData,
VueNodeData
} from '@/composables/graph/useGraphNodeManager'
import NodeWidgets from '@/renderer/extensions/vueNodes/components/NodeWidgets.vue'
describe('NodeWidgets', () => {
const createMockWidget = (
overrides: Partial<SafeWidgetData> = {}
): SafeWidgetData => ({
name: 'test_widget',
type: 'combo',
value: 'test_value',
options: {
values: ['option1', 'option2']
},
callback: undefined,
spec: undefined,
label: undefined,
isDOMWidget: false,
slotMetadata: undefined,
...overrides
})
const createMockNodeData = (
nodeType: string = 'TestNode',
widgets: SafeWidgetData[] = []
): VueNodeData => ({
id: '1',
type: nodeType,
widgets,
title: 'Test Node',
mode: 0,
selected: false,
executing: false,
inputs: [],
outputs: []
})
const mountComponent = (nodeData?: VueNodeData) => {
return mount(NodeWidgets, {
props: {
nodeData
},
global: {
plugins: [createTestingPinia()],
stubs: {
// Stub InputSlot to avoid complex slot registration dependencies
InputSlot: true
},
mocks: {
$t: (key: string) => key
}
}
})
}
describe('node-type prop passing', () => {
it('passes node type to widget components', () => {
const widget = createMockWidget()
const nodeData = createMockNodeData('CheckpointLoaderSimple', [widget])
const wrapper = mountComponent(nodeData)
// Find the dynamically rendered widget component
const widgetComponent = wrapper.find('.lg-node-widget')
expect(widgetComponent.exists()).toBe(true)
// Verify node-type prop is passed
const component = widgetComponent.findComponent({ name: 'WidgetSelect' })
if (component.exists()) {
expect(component.props('nodeType')).toBe('CheckpointLoaderSimple')
}
})
it('passes empty string when nodeData is undefined', () => {
const wrapper = mountComponent(undefined)
// No widgets should be rendered
const widgetComponents = wrapper.findAll('.lg-node-widget')
expect(widgetComponents).toHaveLength(0)
})
it('passes empty string when nodeData.type is undefined', () => {
const widget = createMockWidget()
const nodeData = createMockNodeData('', [widget])
const wrapper = mountComponent(nodeData)
const widgetComponent = wrapper.find('.lg-node-widget')
if (widgetComponent.exists()) {
const component = widgetComponent.findComponent({
name: 'WidgetSelect'
})
if (component.exists()) {
expect(component.props('nodeType')).toBe('')
}
}
})
it.for(['CheckpointLoaderSimple', 'LoraLoader', 'VAELoader', 'KSampler'])(
'passes correct node type: %s',
(nodeType) => {
const widget = createMockWidget()
const nodeData = createMockNodeData(nodeType, [widget])
const wrapper = mountComponent(nodeData)
const widgetComponent = wrapper.find('.lg-node-widget')
expect(widgetComponent.exists()).toBe(true)
const component = widgetComponent.findComponent({
name: 'WidgetSelect'
})
if (component.exists()) {
expect(component.props('nodeType')).toBe(nodeType)
}
}
)
})
})

View File

@@ -0,0 +1,341 @@
import { createTestingPinia } from '@pinia/testing'
import { mount } from '@vue/test-utils'
import PrimeVue from 'primevue/config'
import Select from 'primevue/select'
import type { SelectProps } from 'primevue/select'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { ComboInputSpec } from '@/schemas/nodeDef/nodeDefSchemaV2'
import type { SimplifiedWidget } from '@/types/simplifiedWidget'
import WidgetSelect from '@/renderer/extensions/vueNodes/widgets/components/WidgetSelect.vue'
import WidgetSelectDefault from '@/renderer/extensions/vueNodes/widgets/components/WidgetSelectDefault.vue'
import WidgetSelectDropdown from '@/renderer/extensions/vueNodes/widgets/components/WidgetSelectDropdown.vue'
// Mock state for distribution and settings
const mockDistributionState = vi.hoisted(() => ({ isCloud: false }))
const mockSettingStoreGet = vi.hoisted(() => vi.fn(() => false))
const mockIsAssetBrowserEligible = vi.hoisted(() => vi.fn(() => false))
vi.mock('@/platform/distribution/types', () => ({
get isCloud() {
return mockDistributionState.isCloud
}
}))
vi.mock('@/platform/settings/settingStore', () => ({
useSettingStore: vi.fn(() => ({
get: mockSettingStoreGet
}))
}))
vi.mock('@/platform/assets/services/assetService', () => ({
assetService: {
isAssetBrowserEligible: mockIsAssetBrowserEligible
}
}))
describe('WidgetSelect Value Binding', () => {
beforeEach(() => {
// Reset all mocks before each test
mockDistributionState.isCloud = false
mockSettingStoreGet.mockReturnValue(false)
mockIsAssetBrowserEligible.mockReturnValue(false)
vi.clearAllMocks()
})
const createMockWidget = (
value: string = 'option1',
options: Partial<
SelectProps & { values?: string[]; return_index?: boolean }
> = {},
callback?: (value: string | number | undefined) => void,
spec?: ComboInputSpec
): SimplifiedWidget<string | number | undefined> => ({
name: 'test_select',
type: 'combo',
value,
options: {
values: ['option1', 'option2', 'option3'],
...options
},
callback,
spec
})
const mountComponent = (
widget: SimplifiedWidget<string | number | undefined>,
modelValue: string | number | undefined,
readonly = false
) => {
return mount(WidgetSelect, {
props: {
widget,
modelValue,
readonly
},
global: {
plugins: [PrimeVue, createTestingPinia()],
components: { Select }
}
})
}
const setSelectValueAndEmit = async (
wrapper: ReturnType<typeof mount>,
value: string
) => {
const select = wrapper.findComponent({ name: 'Select' })
await select.setValue(value)
return wrapper.emitted('update:modelValue')
}
describe('Vue Event Emission', () => {
it('emits Vue event when selection changes', async () => {
const widget = createMockWidget('option1')
const wrapper = mountComponent(widget, 'option1')
const emitted = await setSelectValueAndEmit(wrapper, 'option2')
expect(emitted).toBeDefined()
expect(emitted![0]).toContain('option2')
})
it('emits string value for different options', async () => {
const widget = createMockWidget('option1')
const wrapper = mountComponent(widget, 'option1')
const emitted = await setSelectValueAndEmit(wrapper, 'option3')
expect(emitted).toBeDefined()
// Should emit the string value
expect(emitted![0]).toContain('option3')
})
it('handles custom option values', async () => {
const customOptions = ['custom_a', 'custom_b', 'custom_c']
const widget = createMockWidget('custom_a', { values: customOptions })
const wrapper = mountComponent(widget, 'custom_a')
const emitted = await setSelectValueAndEmit(wrapper, 'custom_b')
expect(emitted).toBeDefined()
expect(emitted![0]).toContain('custom_b')
})
it('handles missing callback gracefully', async () => {
const widget = createMockWidget('option1', {}, undefined)
const wrapper = mountComponent(widget, 'option1')
const emitted = await setSelectValueAndEmit(wrapper, 'option2')
// Should emit Vue event
expect(emitted).toBeDefined()
expect(emitted![0]).toContain('option2')
})
it('handles value changes gracefully', async () => {
const widget = createMockWidget('option1')
const wrapper = mountComponent(widget, 'option1')
const emitted = await setSelectValueAndEmit(wrapper, 'option2')
expect(emitted).toBeDefined()
expect(emitted![0]).toContain('option2')
})
})
describe('Option Handling', () => {
it('handles empty options array', async () => {
const widget = createMockWidget('', { values: [] })
const wrapper = mountComponent(widget, '')
const select = wrapper.findComponent({ name: 'Select' })
expect(select.props('options')).toEqual([])
})
it('handles single option', async () => {
const widget = createMockWidget('only_option', {
values: ['only_option']
})
const wrapper = mountComponent(widget, 'only_option')
const select = wrapper.findComponent({ name: 'Select' })
const options = select.props('options')
expect(options).toHaveLength(1)
expect(options[0]).toEqual('only_option')
})
it('handles options with special characters', async () => {
const specialOptions = [
'option with spaces',
'option@#$%',
'option/with\\slashes'
]
const widget = createMockWidget(specialOptions[0], {
values: specialOptions
})
const wrapper = mountComponent(widget, specialOptions[0])
const emitted = await setSelectValueAndEmit(wrapper, specialOptions[1])
expect(emitted).toBeDefined()
expect(emitted![0]).toContain(specialOptions[1])
})
})
describe('Edge Cases', () => {
it('handles selection of non-existent option gracefully', async () => {
const widget = createMockWidget('option1')
const wrapper = mountComponent(widget, 'option1')
const emitted = await setSelectValueAndEmit(
wrapper,
'non_existent_option'
)
// Should still emit Vue event with the value
expect(emitted).toBeDefined()
expect(emitted![0]).toContain('non_existent_option')
})
it('handles numeric string options correctly', async () => {
const numericOptions = ['1', '2', '10', '100']
const widget = createMockWidget('1', { values: numericOptions })
const wrapper = mountComponent(widget, '1')
const emitted = await setSelectValueAndEmit(wrapper, '100')
// Should maintain string type in emitted event
expect(emitted).toBeDefined()
expect(emitted![0]).toContain('100')
})
})
describe('node-type prop passing', () => {
it('passes node-type prop to WidgetSelectDropdown', () => {
const spec: ComboInputSpec = {
type: 'COMBO',
name: 'test_select',
image_upload: true
}
const widget = createMockWidget('option1', {}, undefined, spec)
const wrapper = mount(WidgetSelect, {
props: {
widget,
modelValue: 'option1',
nodeType: 'CheckpointLoaderSimple'
},
global: {
plugins: [PrimeVue, createTestingPinia()],
components: { Select }
}
})
const dropdown = wrapper.findComponent(WidgetSelectDropdown)
expect(dropdown.exists()).toBe(true)
expect(dropdown.props('nodeType')).toBe('CheckpointLoaderSimple')
})
it('does not pass node-type prop to WidgetSelectDefault', () => {
const widget = createMockWidget('option1')
const wrapper = mount(WidgetSelect, {
props: {
widget,
modelValue: 'option1',
nodeType: 'KSampler'
},
global: {
plugins: [PrimeVue, createTestingPinia()],
components: { Select }
}
})
const defaultSelect = wrapper.findComponent(WidgetSelectDefault)
expect(defaultSelect.exists()).toBe(true)
})
})
describe('Asset mode detection', () => {
it('enables asset mode when all conditions are met', () => {
mockDistributionState.isCloud = true
mockSettingStoreGet.mockReturnValue(true)
mockIsAssetBrowserEligible.mockReturnValue(true)
const widget = createMockWidget('test.safetensors')
const wrapper = mount(WidgetSelect, {
props: {
widget,
modelValue: 'test.safetensors',
nodeType: 'CheckpointLoaderSimple'
},
global: {
plugins: [PrimeVue, createTestingPinia()],
components: { Select }
}
})
expect(wrapper.findComponent(WidgetSelectDropdown).exists()).toBe(true)
})
it('disables asset mode when conditions are not met', () => {
mockDistributionState.isCloud = false
const widget = createMockWidget('test.safetensors')
const wrapper = mount(WidgetSelect, {
props: {
widget,
modelValue: 'test.safetensors',
nodeType: 'CheckpointLoaderSimple'
},
global: {
plugins: [PrimeVue, createTestingPinia()],
components: { Select }
}
})
expect(wrapper.findComponent(WidgetSelectDefault).exists()).toBe(true)
})
})
describe('Spec-aware rendering', () => {
it('uses dropdown variant when combo spec enables image uploads', () => {
const spec: ComboInputSpec = {
type: 'COMBO',
name: 'test_select',
image_upload: true
}
const widget = createMockWidget('option1', {}, undefined, spec)
const wrapper = mountComponent(widget, 'option1')
expect(wrapper.findComponent(WidgetSelectDropdown).exists()).toBe(true)
expect(wrapper.findComponent(WidgetSelectDefault).exists()).toBe(false)
})
it('uses dropdown variant for audio uploads', (context) => {
context.skip('allowUpload is not false, should it be? needs diagnosis')
const spec: ComboInputSpec = {
type: 'COMBO',
name: 'test_select',
audio_upload: true
}
const widget = createMockWidget('clip.wav', {}, undefined, spec)
const wrapper = mountComponent(widget, 'clip.wav')
const dropdown = wrapper.findComponent(WidgetSelectDropdown)
expect(dropdown.exists()).toBe(true)
expect(dropdown.props('assetKind')).toBe('audio')
expect(dropdown.props('allowUpload')).toBe(false)
})
it('keeps default select when no spec or media hints are present', () => {
const widget = createMockWidget('plain', {
values: ['plain', 'text']
})
const wrapper = mountComponent(widget, 'plain')
expect(wrapper.findComponent(WidgetSelectDefault).exists()).toBe(true)
expect(wrapper.findComponent(WidgetSelectDropdown).exists()).toBe(false)
})
})
})

View File

@@ -0,0 +1,42 @@
import { describe, expect, it, vi } from 'vitest'
import { ref } from 'vue'
import { useAssetWidgetData } from '@/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData'
vi.mock('@/platform/distribution/types', () => ({
isCloud: false
}))
const mockUpdateModelsForNodeType = vi.fn()
const mockGetCategoryForNodeType = vi.fn()
vi.mock('@/stores/assetsStore', () => ({
useAssetsStore: () => ({
modelAssetsByNodeType: new Map(),
modelLoadingByNodeType: new Map(),
modelErrorByNodeType: new Map(),
updateModelsForNodeType: mockUpdateModelsForNodeType
})
}))
vi.mock('@/stores/modelToNodeStore', () => ({
useModelToNodeStore: () => ({
getCategoryForNodeType: mockGetCategoryForNodeType
})
}))
describe('useAssetWidgetData (desktop/isCloud=false)', () => {
it('returns empty/default values without calling stores', () => {
const nodeType = ref('CheckpointLoaderSimple')
const { category, assets, dropdownItems, isLoading, error } =
useAssetWidgetData(nodeType)
expect(category.value).toBeUndefined()
expect(assets.value).toEqual([])
expect(dropdownItems.value).toEqual([])
expect(isLoading.value).toBe(false)
expect(error.value).toBeNull()
expect(mockUpdateModelsForNodeType).not.toHaveBeenCalled()
expect(mockGetCategoryForNodeType).not.toHaveBeenCalled()
})
})

View File

@@ -0,0 +1,245 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { nextTick, ref } from 'vue'
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import { useAssetWidgetData } from '@/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData'
vi.mock('@/platform/distribution/types', () => ({
isCloud: true
}))
const mockModelAssetsByNodeType = new Map<string, AssetItem[]>()
const mockModelLoadingByNodeType = new Map<string, boolean>()
const mockModelErrorByNodeType = new Map<string, Error | null>()
const mockUpdateModelsForNodeType = vi.fn()
const mockGetCategoryForNodeType = vi.fn()
vi.mock('@/stores/assetsStore', () => ({
useAssetsStore: () => ({
modelAssetsByNodeType: mockModelAssetsByNodeType,
modelLoadingByNodeType: mockModelLoadingByNodeType,
modelErrorByNodeType: mockModelErrorByNodeType,
updateModelsForNodeType: mockUpdateModelsForNodeType
})
}))
vi.mock('@/stores/modelToNodeStore', () => ({
useModelToNodeStore: () => ({
getCategoryForNodeType: mockGetCategoryForNodeType
})
}))
describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
beforeEach(() => {
vi.clearAllMocks()
mockModelAssetsByNodeType.clear()
mockModelLoadingByNodeType.clear()
mockModelErrorByNodeType.clear()
mockGetCategoryForNodeType.mockReturnValue(undefined)
mockUpdateModelsForNodeType.mockImplementation(
async (): Promise<AssetItem[]> => {
return []
}
)
})
const createMockAsset = (
id: string,
name: string,
filename: string,
previewUrl?: string
): AssetItem => ({
id,
name,
size: 1024,
tags: ['models', 'checkpoints'],
created_at: '2025-01-01T00:00:00Z',
preview_url: previewUrl,
user_metadata: {
filename
}
})
it('fetches assets and transforms to dropdown items', async () => {
const mockAssets: AssetItem[] = [
createMockAsset(
'asset-1',
'Beautiful Model',
'models/beautiful_model.safetensors',
'/api/preview/asset-1'
),
createMockAsset('asset-2', 'Model B', 'model_b.safetensors', '/preview/2')
]
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
mockModelLoadingByNodeType.set(_nodeType, false)
return mockAssets
}
)
const nodeType = ref('CheckpointLoaderSimple')
const { category, assets, dropdownItems, isLoading } =
useAssetWidgetData(nodeType)
await nextTick()
await vi.waitFor(() => !isLoading.value)
expect(mockUpdateModelsForNodeType).toHaveBeenCalledWith(
'CheckpointLoaderSimple'
)
expect(category.value).toBe('checkpoints')
expect(assets.value).toEqual(mockAssets)
expect(dropdownItems.value).toHaveLength(2)
const item = dropdownItems.value[0]
expect(item.id).toBe('asset-1')
expect(item.name).toBe('models/beautiful_model.safetensors')
expect(item.label).toBe('Beautiful Model')
expect(item.mediaSrc).toBe('/api/preview/asset-1')
})
it('handles API errors gracefully', async () => {
const mockError = new Error('Network error')
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelErrorByNodeType.set(_nodeType, mockError)
mockModelAssetsByNodeType.set(_nodeType, [])
mockModelLoadingByNodeType.set(_nodeType, false)
return []
}
)
const nodeType = ref('CheckpointLoaderSimple')
const { assets, error, isLoading } = useAssetWidgetData(nodeType)
await nextTick()
await vi.waitFor(() => !isLoading.value)
expect(error.value).toBe(mockError)
expect(assets.value).toEqual([])
})
it('returns empty for unknown node type', async () => {
mockGetCategoryForNodeType.mockReturnValue(undefined)
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, [])
mockModelLoadingByNodeType.set(_nodeType, false)
return []
}
)
const nodeType = ref('UnknownNodeType')
const { category, assets } = useAssetWidgetData(nodeType)
await nextTick()
expect(category.value).toBeUndefined()
expect(assets.value).toEqual([])
})
describe('MaybeRefOrGetter parameter support', () => {
it('accepts plain string value', async () => {
const mockAssets: AssetItem[] = [
createMockAsset('asset-1', 'Model A', 'model_a.safetensors')
]
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
mockModelLoadingByNodeType.set(_nodeType, false)
return mockAssets
}
)
const { category, assets, isLoading } = useAssetWidgetData(
'CheckpointLoaderSimple'
)
await nextTick()
await vi.waitFor(() => !isLoading.value)
expect(mockUpdateModelsForNodeType).toHaveBeenCalledWith(
'CheckpointLoaderSimple'
)
expect(category.value).toBe('checkpoints')
expect(assets.value).toEqual(mockAssets)
})
it('accepts getter function', async () => {
const mockAssets: AssetItem[] = [
createMockAsset('asset-1', 'Model A', 'model_a.safetensors')
]
mockGetCategoryForNodeType.mockReturnValue('loras')
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
mockModelLoadingByNodeType.set(_nodeType, false)
return mockAssets
}
)
const nodeType = ref('LoraLoader')
const { category, assets, isLoading } = useAssetWidgetData(
() => nodeType.value
)
await nextTick()
await vi.waitFor(() => !isLoading.value)
expect(mockUpdateModelsForNodeType).toHaveBeenCalledWith('LoraLoader')
expect(category.value).toBe('loras')
expect(assets.value).toEqual(mockAssets)
})
it('accepts ref (backward compatibility)', async () => {
const mockAssets: AssetItem[] = [
createMockAsset('asset-1', 'Model A', 'model_a.safetensors')
]
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
mockUpdateModelsForNodeType.mockImplementation(
async (_nodeType: string): Promise<AssetItem[]> => {
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
mockModelLoadingByNodeType.set(_nodeType, false)
return mockAssets
}
)
const nodeTypeRef = ref('CheckpointLoaderSimple')
const { category, assets, isLoading } = useAssetWidgetData(nodeTypeRef)
await nextTick()
await vi.waitFor(() => !isLoading.value)
expect(mockUpdateModelsForNodeType).toHaveBeenCalledWith(
'CheckpointLoaderSimple'
)
expect(category.value).toBe('checkpoints')
expect(assets.value).toEqual(mockAssets)
})
it('handles undefined node type gracefully', async () => {
const { category, assets, dropdownItems, isLoading, error } =
useAssetWidgetData(undefined)
await nextTick()
expect(mockUpdateModelsForNodeType).not.toHaveBeenCalled()
expect(category.value).toBeUndefined()
expect(assets.value).toEqual([])
expect(dropdownItems.value).toEqual([])
expect(isLoading.value).toBe(false)
expect(error.value).toBeNull()
})
})
})

View File

@@ -0,0 +1,89 @@
import { createTestingPinia } from '@pinia/testing'
import { flushPromises, mount } from '@vue/test-utils'
import PrimeVue from 'primevue/config'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { SimplifiedWidget } from '@/types/simplifiedWidget'
import WidgetSelect from '@/renderer/extensions/vueNodes/widgets/components/WidgetSelect.vue'
// Mock modules
vi.mock('@/platform/distribution/types', () => ({
isCloud: true
}))
vi.mock('@/platform/assets/services/assetService', () => ({
assetService: {
isAssetBrowserEligible: vi.fn(() => true)
}
}))
const mockSettingStoreGet = vi.fn()
vi.mock('@/platform/settings/settingStore', () => ({
useSettingStore: vi.fn(() => ({
get: mockSettingStoreGet
}))
}))
// Import after mocks are defined
import { assetService } from '@/platform/assets/services/assetService'
const mockAssetServiceEligible = vi.mocked(assetService.isAssetBrowserEligible)
describe('WidgetSelect asset mode', () => {
const createWidget = (): SimplifiedWidget<string | number | undefined> => ({
name: 'ckpt_name',
type: 'combo',
value: undefined,
options: {
values: []
}
})
beforeEach(() => {
vi.clearAllMocks()
mockAssetServiceEligible.mockReturnValue(true)
mockSettingStoreGet.mockReturnValue(true) // Default to true for UseAssetAPI
})
// Helper to mount with common setup
const mountWidget = () => {
return mount(WidgetSelect, {
props: {
widget: createWidget(),
modelValue: undefined,
nodeType: 'CheckpointLoaderSimple'
},
global: {
plugins: [PrimeVue, createTestingPinia()]
}
})
}
it('uses dropdown when isCloud && UseAssetAPI && isEligible', async () => {
const wrapper = mountWidget()
await flushPromises()
expect(
wrapper.findComponent({ name: 'WidgetSelectDropdown' }).exists()
).toBe(true)
})
it('uses default widget when UseAssetAPI setting is false', () => {
mockSettingStoreGet.mockReturnValue(false)
const wrapper = mountWidget()
expect(
wrapper.findComponent({ name: 'WidgetSelectDefault' }).exists()
).toBe(true)
})
it('uses default widget when node is not eligible', () => {
mockAssetServiceEligible.mockReturnValue(false)
const wrapper = mountWidget()
expect(
wrapper.findComponent({ name: 'WidgetSelectDefault' }).exists()
).toBe(true)
})
})

View File

@@ -1,8 +1,23 @@
import { createPinia, setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it } from 'vitest'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
import { useAssetsStore } from '@/stores/assetsStore'
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
// Mock isCloud to be true for these tests
vi.mock('@/platform/distribution/types', () => ({
isCloud: true
}))
// Mock assetService
const mockGetAssetsForNodeType = vi.hoisted(() => vi.fn())
vi.mock('@/platform/assets/services/assetService', () => ({
assetService: {
getAssetsForNodeType: mockGetAssetsForNodeType
}
}))
const HASH_FILENAME =
'72e786ff2a44d682c4294db0b7098e569832bc394efc6dad644e6ec85a78efb7.png'
@@ -24,6 +39,7 @@ function createMockAssetItem(overrides: Partial<AssetItem> = {}): AssetItem {
describe('assetsStore', () => {
beforeEach(() => {
setActivePinia(createPinia())
vi.clearAllMocks()
})
describe('input asset mapping helpers', () => {
@@ -154,4 +170,56 @@ describe('assetsStore', () => {
expect(store.inputAssetsByFilename.size).toBe(0)
})
})
describe('model assets caching', () => {
beforeEach(() => {
const modelToNodeStore = useModelToNodeStore()
modelToNodeStore.registerDefaults()
})
it('should cache assets by node type', async () => {
const store = useAssetsStore()
const mockAssets: AssetItem[] = [
createMockAssetItem({ id: '1', name: 'model_a.safetensors' }),
createMockAssetItem({ id: '2', name: 'model_b.safetensors' })
]
mockGetAssetsForNodeType.mockResolvedValue(mockAssets)
await store.updateModelsForNodeType('CheckpointLoaderSimple')
expect(mockGetAssetsForNodeType).toHaveBeenCalledWith(
'CheckpointLoaderSimple'
)
expect(store.modelAssetsByNodeType.get('CheckpointLoaderSimple')).toEqual(
mockAssets
)
})
it('should track loading state', async () => {
const store = useAssetsStore()
mockGetAssetsForNodeType.mockImplementation(
() => new Promise((resolve) => setTimeout(() => resolve([]), 100))
)
const promise = store.updateModelsForNodeType('LoraLoader')
expect(store.modelLoadingByNodeType.get('LoraLoader')).toBe(true)
await promise
expect(store.modelLoadingByNodeType.get('LoraLoader')).toBe(false)
})
it('should handle errors gracefully', async () => {
const store = useAssetsStore()
const mockError = new Error('Network error')
mockGetAssetsForNodeType.mockRejectedValue(mockError)
await store.updateModelsForNodeType('VAELoader')
expect(store.modelErrorByNodeType.get('VAELoader')).toBe(mockError)
expect(store.modelAssetsByNodeType.get('VAELoader')).toEqual([])
expect(store.modelLoadingByNodeType.get('VAELoader')).toBe(false)
})
})
})