mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-01-26 10:59:53 +00:00
fix: use WidgetSelectDropdown for models (#6607)
## Summary As the commit says, the model loaders were broken in cloud if you enabled Vue Nodes (not a thing I think user does yet). This fixes it by configuring the `WidgetSelectDropdown` to load so the user load models like they would load a input or output asset. ## Review Focus Probably `useAssetWidgetData` to make sure it's idomatic. This part of [assetsStore](https://github.com/Comfy-Org/ComfyUI_frontend/pull/6607/files#diff-18a5914c9f12c16d9c9c3a9f6d0e203a9c00598414d3d1c8637da9ca77339d83R158-R234) as well. ## Screenshots <img width="1196" height="1005" alt="Screenshot 2025-11-05 at 5 34 22 PM" src="https://github.com/user-attachments/assets/804cd3c4-3370-4667-b606-bed52fcd6278" /> ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-6607-fix-use-WidgetSelectDropdown-for-models-2a36d73d36508143b185d06d736e4af9) by [Unito](https://www.unito.io) --------- Co-authored-by: GitHub Action <action@github.com>
This commit is contained in:
@@ -50,6 +50,7 @@
|
||||
:widget="widget.simplified"
|
||||
:model-value="widget.value"
|
||||
:node-id="nodeData?.id != null ? String(nodeData.id) : ''"
|
||||
:node-type="nodeType"
|
||||
class="flex-1"
|
||||
@update:model-value="widget.updateHandler"
|
||||
/>
|
||||
@@ -162,7 +163,9 @@ const processedWidgets = computed((): ProcessedWidget[] => {
|
||||
// Update the widget value directly
|
||||
widget.value = value as WidgetValue
|
||||
|
||||
if (widget.callback) {
|
||||
// Skip callback for asset widgets - their callback opens the modal,
|
||||
// but Vue asset mode handles selection through the dropdown
|
||||
if (widget.callback && widget.type !== 'asset') {
|
||||
widget.callback(value)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,11 +5,14 @@
|
||||
:asset-kind="assetKind"
|
||||
:allow-upload="allowUpload"
|
||||
:upload-folder="uploadFolder"
|
||||
:is-asset-mode="isAssetMode"
|
||||
:default-layout-mode="defaultLayoutMode"
|
||||
@update:model-value="handleUpdateModelValue"
|
||||
/>
|
||||
<WidgetSelectDefault
|
||||
v-else
|
||||
v-bind="props"
|
||||
:widget="widget"
|
||||
:model-value="modelValue"
|
||||
@update:model-value="handleUpdateModelValue"
|
||||
/>
|
||||
</template>
|
||||
@@ -17,18 +20,22 @@
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
|
||||
import { assetService } from '@/platform/assets/services/assetService'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import WidgetSelectDefault from '@/renderer/extensions/vueNodes/widgets/components/WidgetSelectDefault.vue'
|
||||
import WidgetSelectDropdown from '@/renderer/extensions/vueNodes/widgets/components/WidgetSelectDropdown.vue'
|
||||
import type { LayoutMode } from '@/renderer/extensions/vueNodes/widgets/components/form/dropdown/types'
|
||||
import type { ResultItemType } from '@/schemas/apiSchema'
|
||||
import { isComboInputSpec } from '@/schemas/nodeDef/nodeDefSchemaV2'
|
||||
import type { ComboInputSpec } from '@/schemas/nodeDef/nodeDefSchemaV2'
|
||||
import type { SimplifiedWidget } from '@/types/simplifiedWidget'
|
||||
import type { AssetKind } from '@/types/widgetTypes'
|
||||
|
||||
import WidgetSelectDefault from './WidgetSelectDefault.vue'
|
||||
import WidgetSelectDropdown from './WidgetSelectDropdown.vue'
|
||||
|
||||
const props = defineProps<{
|
||||
widget: SimplifiedWidget<string | number | undefined>
|
||||
modelValue: string | number | undefined
|
||||
nodeType?: string
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
@@ -90,10 +97,30 @@ const specDescriptor = computed<{
|
||||
}
|
||||
})
|
||||
|
||||
const isAssetMode = computed(() => {
|
||||
if (isCloud) {
|
||||
const settingStore = useSettingStore()
|
||||
const isUsingAssetAPI = settingStore.get('Comfy.Assets.UseAssetAPI')
|
||||
const isEligible = assetService.isAssetBrowserEligible(
|
||||
props.nodeType,
|
||||
props.widget.name
|
||||
)
|
||||
|
||||
return isUsingAssetAPI && isEligible
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
const assetKind = computed(() => specDescriptor.value.kind)
|
||||
const isDropdownUIWidget = computed(() => assetKind.value !== 'unknown')
|
||||
const isDropdownUIWidget = computed(
|
||||
() => isAssetMode.value || assetKind.value !== 'unknown'
|
||||
)
|
||||
const allowUpload = computed(() => specDescriptor.value.allowUpload)
|
||||
const uploadFolder = computed<ResultItemType>(() => {
|
||||
return specDescriptor.value.folder ?? 'input'
|
||||
})
|
||||
const defaultLayoutMode = computed<LayoutMode>(() => {
|
||||
return isAssetMode.value ? 'list' : 'grid'
|
||||
})
|
||||
</script>
|
||||
|
||||
@@ -1,10 +1,21 @@
|
||||
<script setup lang="ts">
|
||||
import { computed, provide, ref, watch } from 'vue'
|
||||
import { capitalize } from 'es-toolkit'
|
||||
import { computed, provide, ref, toRef, watch } from 'vue'
|
||||
|
||||
import { useWidgetValue } from '@/composables/graph/useWidgetValue'
|
||||
import { useTransformCompatOverlayProps } from '@/composables/useTransformCompatOverlayProps'
|
||||
import { t } from '@/i18n'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
import FormDropdown from '@/renderer/extensions/vueNodes/widgets/components/form/dropdown/FormDropdown.vue'
|
||||
import { AssetKindKey } from '@/renderer/extensions/vueNodes/widgets/components/form/dropdown/types'
|
||||
import type {
|
||||
DropdownItem,
|
||||
FilterOption,
|
||||
LayoutMode,
|
||||
SelectedKey
|
||||
} from '@/renderer/extensions/vueNodes/widgets/components/form/dropdown/types'
|
||||
import WidgetLayoutField from '@/renderer/extensions/vueNodes/widgets/components/layout/WidgetLayoutField.vue'
|
||||
import { useAssetWidgetData } from '@/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData'
|
||||
import type { ResultItemType } from '@/schemas/apiSchema'
|
||||
import { api } from '@/scripts/api'
|
||||
import { useAssetsStore } from '@/stores/assetsStore'
|
||||
@@ -16,21 +27,15 @@ import {
|
||||
filterWidgetProps
|
||||
} from '@/utils/widgetPropFilter'
|
||||
|
||||
import FormDropdown from './form/dropdown/FormDropdown.vue'
|
||||
import { AssetKindKey } from './form/dropdown/types'
|
||||
import type {
|
||||
DropdownItem,
|
||||
FilterOption,
|
||||
SelectedKey
|
||||
} from './form/dropdown/types'
|
||||
import WidgetLayoutField from './layout/WidgetLayoutField.vue'
|
||||
|
||||
const props = defineProps<{
|
||||
widget: SimplifiedWidget<string | number | undefined>
|
||||
modelValue: string | number | undefined
|
||||
nodeType?: string
|
||||
assetKind?: AssetKind
|
||||
allowUpload?: boolean
|
||||
uploadFolder?: ResultItemType
|
||||
isAssetMode?: boolean
|
||||
defaultLayoutMode?: LayoutMode
|
||||
}>()
|
||||
|
||||
provide(
|
||||
@@ -59,12 +64,26 @@ const combinedProps = computed(() => ({
|
||||
...transformCompatProps.value
|
||||
}))
|
||||
|
||||
const getAssetData = () => {
|
||||
if (props.isAssetMode && props.nodeType) {
|
||||
return useAssetWidgetData(toRef(() => props.nodeType))
|
||||
}
|
||||
return null
|
||||
}
|
||||
const assetData = getAssetData()
|
||||
|
||||
const filterSelected = ref('all')
|
||||
const filterOptions = ref<FilterOption[]>([
|
||||
{ id: 'all', name: 'All' },
|
||||
{ id: 'inputs', name: 'Inputs' },
|
||||
{ id: 'outputs', name: 'Outputs' }
|
||||
])
|
||||
const filterOptions = computed<FilterOption[]>(() => {
|
||||
if (props.isAssetMode) {
|
||||
const categoryName = assetData?.category.value ?? 'All'
|
||||
return [{ id: 'all', name: capitalize(categoryName) }]
|
||||
}
|
||||
return [
|
||||
{ id: 'all', name: 'All' },
|
||||
{ id: 'inputs', name: 'Inputs' },
|
||||
{ id: 'outputs', name: 'Outputs' }
|
||||
]
|
||||
})
|
||||
|
||||
const selectedSet = ref<Set<SelectedKey>>(new Set())
|
||||
|
||||
@@ -132,9 +151,16 @@ const outputItems = computed<DropdownItem[]>(() => {
|
||||
})
|
||||
|
||||
const allItems = computed<DropdownItem[]>(() => {
|
||||
if (props.isAssetMode && assetData) {
|
||||
return assetData.dropdownItems.value
|
||||
}
|
||||
return [...inputItems.value, ...outputItems.value]
|
||||
})
|
||||
const dropdownItems = computed<DropdownItem[]>(() => {
|
||||
if (props.isAssetMode) {
|
||||
return allItems.value
|
||||
}
|
||||
|
||||
switch (filterSelected.value) {
|
||||
case 'inputs':
|
||||
return inputItems.value
|
||||
@@ -169,7 +195,10 @@ const mediaPlaceholder = computed(() => {
|
||||
return t('widgets.uploadSelect.placeholder')
|
||||
})
|
||||
|
||||
const uploadable = computed(() => props.allowUpload === true)
|
||||
const uploadable = computed(() => {
|
||||
if (props.isAssetMode) return false
|
||||
return props.allowUpload === true
|
||||
})
|
||||
|
||||
const acceptTypes = computed(() => {
|
||||
// Be permissive with accept types because backend uses libraries
|
||||
@@ -186,6 +215,8 @@ const acceptTypes = computed(() => {
|
||||
}
|
||||
})
|
||||
|
||||
const layoutMode = ref<LayoutMode>(props.defaultLayoutMode ?? 'grid')
|
||||
|
||||
watch(
|
||||
localValue,
|
||||
(currentValue) => {
|
||||
@@ -313,6 +344,7 @@ function getMediaUrl(
|
||||
<FormDropdown
|
||||
v-model:selected="selectedSet"
|
||||
v-model:filter-selected="filterSelected"
|
||||
v-model:layout-mode="layoutMode"
|
||||
:items="dropdownItems"
|
||||
:placeholder="mediaPlaceholder"
|
||||
:multiple="false"
|
||||
|
||||
@@ -124,16 +124,17 @@ function handleVideoLoad(event: Event) {
|
||||
:class="
|
||||
cn('flex gap-1', {
|
||||
'flex-col': layout === 'grid',
|
||||
'flex-col px-4 py-1 w-full justify-center': layout === 'list',
|
||||
'flex-col px-4 py-1 w-full justify-center min-w-0': layout === 'list',
|
||||
'flex-row p-2 items-center justify-between w-full':
|
||||
layout === 'list-small'
|
||||
})
|
||||
"
|
||||
>
|
||||
<span
|
||||
v-tooltip="layout === 'grid' ? (label ?? name) : undefined"
|
||||
:class="
|
||||
cn(
|
||||
'block text-[15px] line-clamp-2 wrap-break-word',
|
||||
'block text-[15px] line-clamp-2 break-words overflow-hidden',
|
||||
'transition-colors duration-150',
|
||||
// selection
|
||||
!!selected && 'text-blue-500'
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
import { computed, toValue, watch } from 'vue'
|
||||
import type { MaybeRefOrGetter } from 'vue'
|
||||
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import type { DropdownItem } from '@/renderer/extensions/vueNodes/widgets/components/form/dropdown/types'
|
||||
import { useAssetsStore } from '@/stores/assetsStore'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
|
||||
/**
|
||||
* Composable for fetching and transforming asset data for Vue node widgets.
|
||||
* Provides reactive asset data based on node type with automatic category detection.
|
||||
* Uses store-based caching to avoid duplicate fetches across multiple instances.
|
||||
*
|
||||
* Cloud-only composable - returns empty data when not in cloud environment.
|
||||
*
|
||||
* @param nodeType - ComfyUI node type (ref, getter, or plain value). Can be undefined.
|
||||
* Accepts: ref('CheckpointLoaderSimple'), () => 'CheckpointLoaderSimple', or 'CheckpointLoaderSimple'
|
||||
* @returns Reactive data including category, assets, dropdown items, loading state, and errors
|
||||
*/
|
||||
export function useAssetWidgetData(
|
||||
nodeType: MaybeRefOrGetter<string | undefined>
|
||||
) {
|
||||
if (isCloud) {
|
||||
const assetsStore = useAssetsStore()
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
|
||||
const category = computed(() => {
|
||||
const resolvedType = toValue(nodeType)
|
||||
return resolvedType
|
||||
? modelToNodeStore.getCategoryForNodeType(resolvedType)
|
||||
: undefined
|
||||
})
|
||||
|
||||
const assets = computed<AssetItem[]>(() => {
|
||||
const resolvedType = toValue(nodeType)
|
||||
return resolvedType
|
||||
? (assetsStore.modelAssetsByNodeType.get(resolvedType) ?? [])
|
||||
: []
|
||||
})
|
||||
|
||||
const isLoading = computed(() => {
|
||||
const resolvedType = toValue(nodeType)
|
||||
return resolvedType
|
||||
? (assetsStore.modelLoadingByNodeType.get(resolvedType) ?? false)
|
||||
: false
|
||||
})
|
||||
|
||||
const error = computed<Error | null>(() => {
|
||||
const resolvedType = toValue(nodeType)
|
||||
return resolvedType
|
||||
? (assetsStore.modelErrorByNodeType.get(resolvedType) ?? null)
|
||||
: null
|
||||
})
|
||||
|
||||
const dropdownItems = computed<DropdownItem[]>(() => {
|
||||
return assets.value.map((asset) => ({
|
||||
id: asset.id,
|
||||
name:
|
||||
(asset.user_metadata?.filename as string | undefined) ?? asset.name,
|
||||
label: asset.name,
|
||||
mediaSrc: asset.preview_url ?? '',
|
||||
metadata: ''
|
||||
}))
|
||||
})
|
||||
|
||||
watch(
|
||||
() => toValue(nodeType),
|
||||
async (currentNodeType) => {
|
||||
if (!currentNodeType) {
|
||||
return
|
||||
}
|
||||
|
||||
const hasData = assetsStore.modelAssetsByNodeType.has(currentNodeType)
|
||||
|
||||
if (!hasData) {
|
||||
await assetsStore.updateModelsForNodeType(currentNodeType)
|
||||
}
|
||||
},
|
||||
{ immediate: true }
|
||||
)
|
||||
|
||||
return {
|
||||
category,
|
||||
assets,
|
||||
dropdownItems,
|
||||
isLoading,
|
||||
error
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
category: computed(() => undefined),
|
||||
assets: computed(() => []),
|
||||
dropdownItems: computed(() => []),
|
||||
isLoading: computed(() => false),
|
||||
error: computed(() => null)
|
||||
}
|
||||
}
|
||||
@@ -62,7 +62,10 @@ const coreWidgetDefinitions: Array<[string, WidgetDefinition]> = [
|
||||
essential: true
|
||||
}
|
||||
],
|
||||
['combo', { component: WidgetSelect, aliases: ['COMBO'], essential: true }],
|
||||
[
|
||||
'combo',
|
||||
{ component: WidgetSelect, aliases: ['COMBO', 'asset'], essential: true }
|
||||
],
|
||||
[
|
||||
'color',
|
||||
{ component: WidgetColorPicker, aliases: ['COLOR'], essential: false }
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useAsyncState } from '@vueuse/core'
|
||||
import { defineStore } from 'pinia'
|
||||
import { computed } from 'vue'
|
||||
import { computed, shallowReactive } from 'vue'
|
||||
|
||||
import {
|
||||
mapInputFileToAssetItem,
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { assetService } from '@/platform/assets/services/assetService'
|
||||
import { isCloud } from '@/platform/distribution/types'
|
||||
import type { HistoryTaskItem } from '@/schemas/apiSchema'
|
||||
import { api } from '@/scripts/api'
|
||||
|
||||
import { TaskItemImpl } from './queueStore'
|
||||
@@ -47,7 +48,7 @@ async function fetchInputFilesFromCloud(): Promise<AssetItem[]> {
|
||||
/**
|
||||
* Convert history task items to asset items
|
||||
*/
|
||||
function mapHistoryToAssets(historyItems: any[]): AssetItem[] {
|
||||
function mapHistoryToAssets(historyItems: HistoryTaskItem[]): AssetItem[] {
|
||||
const assetItems: AssetItem[] = []
|
||||
|
||||
for (const item of historyItems) {
|
||||
@@ -87,9 +88,13 @@ function mapHistoryToAssets(historyItems: any[]): AssetItem[] {
|
||||
export const useAssetsStore = defineStore('assets', () => {
|
||||
const maxHistoryItems = 200
|
||||
|
||||
const fetchInputFiles = isCloud
|
||||
? fetchInputFilesFromCloud
|
||||
: fetchInputFilesFromAPI
|
||||
const getFetchInputFiles = () => {
|
||||
if (isCloud) {
|
||||
return fetchInputFilesFromCloud
|
||||
}
|
||||
return fetchInputFilesFromAPI
|
||||
}
|
||||
const fetchInputFiles = getFetchInputFiles()
|
||||
|
||||
const {
|
||||
state: inputAssets,
|
||||
@@ -129,7 +134,6 @@ export const useAssetsStore = defineStore('assets', () => {
|
||||
const inputAssetsByFilename = computed(() => {
|
||||
const map = new Map<string, AssetItem>()
|
||||
for (const asset of inputAssets.value) {
|
||||
// Use asset_hash as the key (hash-based filename)
|
||||
if (asset.asset_hash) {
|
||||
map.set(asset.asset_hash, asset)
|
||||
}
|
||||
@@ -146,6 +150,96 @@ export const useAssetsStore = defineStore('assets', () => {
|
||||
return inputAssetsByFilename.value.get(filename)?.name ?? filename
|
||||
}
|
||||
|
||||
/**
|
||||
* Model assets cached by node type (e.g., 'CheckpointLoaderSimple', 'LoraLoader')
|
||||
* Used by multiple loader nodes to avoid duplicate fetches
|
||||
* Cloud-only feature - empty Maps in desktop builds
|
||||
*/
|
||||
const getModelState = () => {
|
||||
if (isCloud) {
|
||||
const modelAssetsByNodeType = shallowReactive(
|
||||
new Map<string, AssetItem[]>()
|
||||
)
|
||||
const modelLoadingByNodeType = shallowReactive(new Map<string, boolean>())
|
||||
const modelErrorByNodeType = shallowReactive(
|
||||
new Map<string, Error | null>()
|
||||
)
|
||||
|
||||
const stateByNodeType = shallowReactive(
|
||||
new Map<string, ReturnType<typeof useAsyncState<AssetItem[]>>>()
|
||||
)
|
||||
|
||||
/**
|
||||
* Fetch and cache model assets for a specific node type
|
||||
* Uses VueUse's useAsyncState for automatic loading/error tracking
|
||||
* @param nodeType The node type to fetch assets for (e.g., 'CheckpointLoaderSimple')
|
||||
* @returns Promise resolving to the fetched assets
|
||||
*/
|
||||
async function updateModelsForNodeType(
|
||||
nodeType: string
|
||||
): Promise<AssetItem[]> {
|
||||
if (!stateByNodeType.has(nodeType)) {
|
||||
stateByNodeType.set(
|
||||
nodeType,
|
||||
useAsyncState(
|
||||
() => assetService.getAssetsForNodeType(nodeType),
|
||||
[],
|
||||
{
|
||||
immediate: false,
|
||||
resetOnExecute: false,
|
||||
onError: (err) => {
|
||||
console.error(
|
||||
`Error fetching model assets for ${nodeType}:`,
|
||||
err
|
||||
)
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
const state = stateByNodeType.get(nodeType)!
|
||||
|
||||
modelLoadingByNodeType.set(nodeType, true)
|
||||
modelErrorByNodeType.set(nodeType, null)
|
||||
|
||||
try {
|
||||
await state.execute()
|
||||
const assets = state.state.value
|
||||
modelAssetsByNodeType.set(nodeType, assets)
|
||||
modelErrorByNodeType.set(
|
||||
nodeType,
|
||||
state.error.value instanceof Error ? state.error.value : null
|
||||
)
|
||||
return assets
|
||||
} finally {
|
||||
modelLoadingByNodeType.set(nodeType, state.isLoading.value)
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
modelAssetsByNodeType,
|
||||
modelLoadingByNodeType,
|
||||
modelErrorByNodeType,
|
||||
updateModelsForNodeType
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
modelAssetsByNodeType: shallowReactive(new Map<string, AssetItem[]>()),
|
||||
modelLoadingByNodeType: shallowReactive(new Map<string, boolean>()),
|
||||
modelErrorByNodeType: shallowReactive(new Map<string, Error | null>()),
|
||||
updateModelsForNodeType: async () => []
|
||||
}
|
||||
}
|
||||
|
||||
const {
|
||||
modelAssetsByNodeType,
|
||||
modelLoadingByNodeType,
|
||||
modelErrorByNodeType,
|
||||
updateModelsForNodeType
|
||||
} = getModelState()
|
||||
|
||||
return {
|
||||
// States
|
||||
inputAssets,
|
||||
@@ -161,6 +255,12 @@ export const useAssetsStore = defineStore('assets', () => {
|
||||
|
||||
// Input mapping helpers
|
||||
inputAssetsByFilename,
|
||||
getInputName
|
||||
getInputName,
|
||||
|
||||
// Model assets
|
||||
modelAssetsByNodeType,
|
||||
modelLoadingByNodeType,
|
||||
modelErrorByNodeType,
|
||||
updateModelsForNodeType
|
||||
}
|
||||
})
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { mount } from '@vue/test-utils'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import type {
|
||||
SafeWidgetData,
|
||||
VueNodeData
|
||||
} from '@/composables/graph/useGraphNodeManager'
|
||||
|
||||
import NodeWidgets from '@/renderer/extensions/vueNodes/components/NodeWidgets.vue'
|
||||
|
||||
describe('NodeWidgets', () => {
|
||||
const createMockWidget = (
|
||||
overrides: Partial<SafeWidgetData> = {}
|
||||
): SafeWidgetData => ({
|
||||
name: 'test_widget',
|
||||
type: 'combo',
|
||||
value: 'test_value',
|
||||
options: {
|
||||
values: ['option1', 'option2']
|
||||
},
|
||||
callback: undefined,
|
||||
spec: undefined,
|
||||
label: undefined,
|
||||
isDOMWidget: false,
|
||||
slotMetadata: undefined,
|
||||
...overrides
|
||||
})
|
||||
|
||||
const createMockNodeData = (
|
||||
nodeType: string = 'TestNode',
|
||||
widgets: SafeWidgetData[] = []
|
||||
): VueNodeData => ({
|
||||
id: '1',
|
||||
type: nodeType,
|
||||
widgets,
|
||||
title: 'Test Node',
|
||||
mode: 0,
|
||||
selected: false,
|
||||
executing: false,
|
||||
inputs: [],
|
||||
outputs: []
|
||||
})
|
||||
|
||||
const mountComponent = (nodeData?: VueNodeData) => {
|
||||
return mount(NodeWidgets, {
|
||||
props: {
|
||||
nodeData
|
||||
},
|
||||
global: {
|
||||
plugins: [createTestingPinia()],
|
||||
stubs: {
|
||||
// Stub InputSlot to avoid complex slot registration dependencies
|
||||
InputSlot: true
|
||||
},
|
||||
mocks: {
|
||||
$t: (key: string) => key
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
describe('node-type prop passing', () => {
|
||||
it('passes node type to widget components', () => {
|
||||
const widget = createMockWidget()
|
||||
const nodeData = createMockNodeData('CheckpointLoaderSimple', [widget])
|
||||
const wrapper = mountComponent(nodeData)
|
||||
|
||||
// Find the dynamically rendered widget component
|
||||
const widgetComponent = wrapper.find('.lg-node-widget')
|
||||
expect(widgetComponent.exists()).toBe(true)
|
||||
|
||||
// Verify node-type prop is passed
|
||||
const component = widgetComponent.findComponent({ name: 'WidgetSelect' })
|
||||
if (component.exists()) {
|
||||
expect(component.props('nodeType')).toBe('CheckpointLoaderSimple')
|
||||
}
|
||||
})
|
||||
|
||||
it('passes empty string when nodeData is undefined', () => {
|
||||
const wrapper = mountComponent(undefined)
|
||||
|
||||
// No widgets should be rendered
|
||||
const widgetComponents = wrapper.findAll('.lg-node-widget')
|
||||
expect(widgetComponents).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('passes empty string when nodeData.type is undefined', () => {
|
||||
const widget = createMockWidget()
|
||||
const nodeData = createMockNodeData('', [widget])
|
||||
const wrapper = mountComponent(nodeData)
|
||||
|
||||
const widgetComponent = wrapper.find('.lg-node-widget')
|
||||
if (widgetComponent.exists()) {
|
||||
const component = widgetComponent.findComponent({
|
||||
name: 'WidgetSelect'
|
||||
})
|
||||
if (component.exists()) {
|
||||
expect(component.props('nodeType')).toBe('')
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
it.for(['CheckpointLoaderSimple', 'LoraLoader', 'VAELoader', 'KSampler'])(
|
||||
'passes correct node type: %s',
|
||||
(nodeType) => {
|
||||
const widget = createMockWidget()
|
||||
const nodeData = createMockNodeData(nodeType, [widget])
|
||||
const wrapper = mountComponent(nodeData)
|
||||
|
||||
const widgetComponent = wrapper.find('.lg-node-widget')
|
||||
expect(widgetComponent.exists()).toBe(true)
|
||||
|
||||
const component = widgetComponent.findComponent({
|
||||
name: 'WidgetSelect'
|
||||
})
|
||||
if (component.exists()) {
|
||||
expect(component.props('nodeType')).toBe(nodeType)
|
||||
}
|
||||
}
|
||||
)
|
||||
})
|
||||
})
|
||||
@@ -3,16 +3,47 @@ import { mount } from '@vue/test-utils'
|
||||
import PrimeVue from 'primevue/config'
|
||||
import Select from 'primevue/select'
|
||||
import type { SelectProps } from 'primevue/select'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { ComboInputSpec } from '@/schemas/nodeDef/nodeDefSchemaV2'
|
||||
import type { SimplifiedWidget } from '@/types/simplifiedWidget'
|
||||
|
||||
import WidgetSelect from './WidgetSelect.vue'
|
||||
import WidgetSelectDefault from './WidgetSelectDefault.vue'
|
||||
import WidgetSelectDropdown from './WidgetSelectDropdown.vue'
|
||||
import WidgetSelect from '@/renderer/extensions/vueNodes/widgets/components/WidgetSelect.vue'
|
||||
import WidgetSelectDefault from '@/renderer/extensions/vueNodes/widgets/components/WidgetSelectDefault.vue'
|
||||
import WidgetSelectDropdown from '@/renderer/extensions/vueNodes/widgets/components/WidgetSelectDropdown.vue'
|
||||
|
||||
// Mock state for distribution and settings
|
||||
const mockDistributionState = vi.hoisted(() => ({ isCloud: false }))
|
||||
const mockSettingStoreGet = vi.hoisted(() => vi.fn(() => false))
|
||||
const mockIsAssetBrowserEligible = vi.hoisted(() => vi.fn(() => false))
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
get isCloud() {
|
||||
return mockDistributionState.isCloud
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: vi.fn(() => ({
|
||||
get: mockSettingStoreGet
|
||||
}))
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/assets/services/assetService', () => ({
|
||||
assetService: {
|
||||
isAssetBrowserEligible: mockIsAssetBrowserEligible
|
||||
}
|
||||
}))
|
||||
|
||||
describe('WidgetSelect Value Binding', () => {
|
||||
beforeEach(() => {
|
||||
// Reset all mocks before each test
|
||||
mockDistributionState.isCloud = false
|
||||
mockSettingStoreGet.mockReturnValue(false)
|
||||
mockIsAssetBrowserEligible.mockReturnValue(false)
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
const createMockWidget = (
|
||||
value: string = 'option1',
|
||||
options: Partial<
|
||||
@@ -181,6 +212,92 @@ describe('WidgetSelect Value Binding', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe('node-type prop passing', () => {
|
||||
it('passes node-type prop to WidgetSelectDropdown', () => {
|
||||
const spec: ComboInputSpec = {
|
||||
type: 'COMBO',
|
||||
name: 'test_select',
|
||||
image_upload: true
|
||||
}
|
||||
const widget = createMockWidget('option1', {}, undefined, spec)
|
||||
const wrapper = mount(WidgetSelect, {
|
||||
props: {
|
||||
widget,
|
||||
modelValue: 'option1',
|
||||
nodeType: 'CheckpointLoaderSimple'
|
||||
},
|
||||
global: {
|
||||
plugins: [PrimeVue, createTestingPinia()],
|
||||
components: { Select }
|
||||
}
|
||||
})
|
||||
|
||||
const dropdown = wrapper.findComponent(WidgetSelectDropdown)
|
||||
expect(dropdown.exists()).toBe(true)
|
||||
expect(dropdown.props('nodeType')).toBe('CheckpointLoaderSimple')
|
||||
})
|
||||
|
||||
it('does not pass node-type prop to WidgetSelectDefault', () => {
|
||||
const widget = createMockWidget('option1')
|
||||
const wrapper = mount(WidgetSelect, {
|
||||
props: {
|
||||
widget,
|
||||
modelValue: 'option1',
|
||||
nodeType: 'KSampler'
|
||||
},
|
||||
global: {
|
||||
plugins: [PrimeVue, createTestingPinia()],
|
||||
components: { Select }
|
||||
}
|
||||
})
|
||||
|
||||
const defaultSelect = wrapper.findComponent(WidgetSelectDefault)
|
||||
expect(defaultSelect.exists()).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Asset mode detection', () => {
|
||||
it('enables asset mode when all conditions are met', () => {
|
||||
mockDistributionState.isCloud = true
|
||||
mockSettingStoreGet.mockReturnValue(true)
|
||||
mockIsAssetBrowserEligible.mockReturnValue(true)
|
||||
|
||||
const widget = createMockWidget('test.safetensors')
|
||||
const wrapper = mount(WidgetSelect, {
|
||||
props: {
|
||||
widget,
|
||||
modelValue: 'test.safetensors',
|
||||
nodeType: 'CheckpointLoaderSimple'
|
||||
},
|
||||
global: {
|
||||
plugins: [PrimeVue, createTestingPinia()],
|
||||
components: { Select }
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.findComponent(WidgetSelectDropdown).exists()).toBe(true)
|
||||
})
|
||||
|
||||
it('disables asset mode when conditions are not met', () => {
|
||||
mockDistributionState.isCloud = false
|
||||
|
||||
const widget = createMockWidget('test.safetensors')
|
||||
const wrapper = mount(WidgetSelect, {
|
||||
props: {
|
||||
widget,
|
||||
modelValue: 'test.safetensors',
|
||||
nodeType: 'CheckpointLoaderSimple'
|
||||
},
|
||||
global: {
|
||||
plugins: [PrimeVue, createTestingPinia()],
|
||||
components: { Select }
|
||||
}
|
||||
})
|
||||
|
||||
expect(wrapper.findComponent(WidgetSelectDefault).exists()).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Spec-aware rendering', () => {
|
||||
it('uses dropdown variant when combo spec enables image uploads', () => {
|
||||
const spec: ComboInputSpec = {
|
||||
@@ -0,0 +1,42 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { ref } from 'vue'
|
||||
|
||||
import { useAssetWidgetData } from '@/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData'
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
isCloud: false
|
||||
}))
|
||||
|
||||
const mockUpdateModelsForNodeType = vi.fn()
|
||||
const mockGetCategoryForNodeType = vi.fn()
|
||||
|
||||
vi.mock('@/stores/assetsStore', () => ({
|
||||
useAssetsStore: () => ({
|
||||
modelAssetsByNodeType: new Map(),
|
||||
modelLoadingByNodeType: new Map(),
|
||||
modelErrorByNodeType: new Map(),
|
||||
updateModelsForNodeType: mockUpdateModelsForNodeType
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
useModelToNodeStore: () => ({
|
||||
getCategoryForNodeType: mockGetCategoryForNodeType
|
||||
})
|
||||
}))
|
||||
|
||||
describe('useAssetWidgetData (desktop/isCloud=false)', () => {
|
||||
it('returns empty/default values without calling stores', () => {
|
||||
const nodeType = ref('CheckpointLoaderSimple')
|
||||
const { category, assets, dropdownItems, isLoading, error } =
|
||||
useAssetWidgetData(nodeType)
|
||||
|
||||
expect(category.value).toBeUndefined()
|
||||
expect(assets.value).toEqual([])
|
||||
expect(dropdownItems.value).toEqual([])
|
||||
expect(isLoading.value).toBe(false)
|
||||
expect(error.value).toBeNull()
|
||||
expect(mockUpdateModelsForNodeType).not.toHaveBeenCalled()
|
||||
expect(mockGetCategoryForNodeType).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
245
tests-ui/tests/composables/useAssetWidgetData.test.ts
Normal file
245
tests-ui/tests/composables/useAssetWidgetData.test.ts
Normal file
@@ -0,0 +1,245 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { nextTick, ref } from 'vue'
|
||||
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { useAssetWidgetData } from '@/renderer/extensions/vueNodes/widgets/composables/useAssetWidgetData'
|
||||
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
isCloud: true
|
||||
}))
|
||||
|
||||
const mockModelAssetsByNodeType = new Map<string, AssetItem[]>()
|
||||
const mockModelLoadingByNodeType = new Map<string, boolean>()
|
||||
const mockModelErrorByNodeType = new Map<string, Error | null>()
|
||||
const mockUpdateModelsForNodeType = vi.fn()
|
||||
const mockGetCategoryForNodeType = vi.fn()
|
||||
|
||||
vi.mock('@/stores/assetsStore', () => ({
|
||||
useAssetsStore: () => ({
|
||||
modelAssetsByNodeType: mockModelAssetsByNodeType,
|
||||
modelLoadingByNodeType: mockModelLoadingByNodeType,
|
||||
modelErrorByNodeType: mockModelErrorByNodeType,
|
||||
updateModelsForNodeType: mockUpdateModelsForNodeType
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
useModelToNodeStore: () => ({
|
||||
getCategoryForNodeType: mockGetCategoryForNodeType
|
||||
})
|
||||
}))
|
||||
|
||||
describe('useAssetWidgetData (cloud mode, isCloud=true)', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockModelAssetsByNodeType.clear()
|
||||
mockModelLoadingByNodeType.clear()
|
||||
mockModelErrorByNodeType.clear()
|
||||
mockGetCategoryForNodeType.mockReturnValue(undefined)
|
||||
|
||||
mockUpdateModelsForNodeType.mockImplementation(
|
||||
async (): Promise<AssetItem[]> => {
|
||||
return []
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
const createMockAsset = (
|
||||
id: string,
|
||||
name: string,
|
||||
filename: string,
|
||||
previewUrl?: string
|
||||
): AssetItem => ({
|
||||
id,
|
||||
name,
|
||||
size: 1024,
|
||||
tags: ['models', 'checkpoints'],
|
||||
created_at: '2025-01-01T00:00:00Z',
|
||||
preview_url: previewUrl,
|
||||
user_metadata: {
|
||||
filename
|
||||
}
|
||||
})
|
||||
|
||||
it('fetches assets and transforms to dropdown items', async () => {
|
||||
const mockAssets: AssetItem[] = [
|
||||
createMockAsset(
|
||||
'asset-1',
|
||||
'Beautiful Model',
|
||||
'models/beautiful_model.safetensors',
|
||||
'/api/preview/asset-1'
|
||||
),
|
||||
createMockAsset('asset-2', 'Model B', 'model_b.safetensors', '/preview/2')
|
||||
]
|
||||
|
||||
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
|
||||
|
||||
mockUpdateModelsForNodeType.mockImplementation(
|
||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
|
||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
||||
return mockAssets
|
||||
}
|
||||
)
|
||||
|
||||
const nodeType = ref('CheckpointLoaderSimple')
|
||||
const { category, assets, dropdownItems, isLoading } =
|
||||
useAssetWidgetData(nodeType)
|
||||
|
||||
await nextTick()
|
||||
await vi.waitFor(() => !isLoading.value)
|
||||
|
||||
expect(mockUpdateModelsForNodeType).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(category.value).toBe('checkpoints')
|
||||
expect(assets.value).toEqual(mockAssets)
|
||||
|
||||
expect(dropdownItems.value).toHaveLength(2)
|
||||
const item = dropdownItems.value[0]
|
||||
expect(item.id).toBe('asset-1')
|
||||
expect(item.name).toBe('models/beautiful_model.safetensors')
|
||||
expect(item.label).toBe('Beautiful Model')
|
||||
expect(item.mediaSrc).toBe('/api/preview/asset-1')
|
||||
})
|
||||
|
||||
it('handles API errors gracefully', async () => {
|
||||
const mockError = new Error('Network error')
|
||||
|
||||
mockUpdateModelsForNodeType.mockImplementation(
|
||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||
mockModelErrorByNodeType.set(_nodeType, mockError)
|
||||
mockModelAssetsByNodeType.set(_nodeType, [])
|
||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
||||
return []
|
||||
}
|
||||
)
|
||||
|
||||
const nodeType = ref('CheckpointLoaderSimple')
|
||||
const { assets, error, isLoading } = useAssetWidgetData(nodeType)
|
||||
|
||||
await nextTick()
|
||||
await vi.waitFor(() => !isLoading.value)
|
||||
|
||||
expect(error.value).toBe(mockError)
|
||||
expect(assets.value).toEqual([])
|
||||
})
|
||||
|
||||
it('returns empty for unknown node type', async () => {
|
||||
mockGetCategoryForNodeType.mockReturnValue(undefined)
|
||||
|
||||
mockUpdateModelsForNodeType.mockImplementation(
|
||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||
mockModelAssetsByNodeType.set(_nodeType, [])
|
||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
||||
return []
|
||||
}
|
||||
)
|
||||
|
||||
const nodeType = ref('UnknownNodeType')
|
||||
const { category, assets } = useAssetWidgetData(nodeType)
|
||||
|
||||
await nextTick()
|
||||
|
||||
expect(category.value).toBeUndefined()
|
||||
expect(assets.value).toEqual([])
|
||||
})
|
||||
|
||||
describe('MaybeRefOrGetter parameter support', () => {
|
||||
it('accepts plain string value', async () => {
|
||||
const mockAssets: AssetItem[] = [
|
||||
createMockAsset('asset-1', 'Model A', 'model_a.safetensors')
|
||||
]
|
||||
|
||||
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
|
||||
mockUpdateModelsForNodeType.mockImplementation(
|
||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
|
||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
||||
return mockAssets
|
||||
}
|
||||
)
|
||||
|
||||
const { category, assets, isLoading } = useAssetWidgetData(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
|
||||
await nextTick()
|
||||
await vi.waitFor(() => !isLoading.value)
|
||||
|
||||
expect(mockUpdateModelsForNodeType).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(category.value).toBe('checkpoints')
|
||||
expect(assets.value).toEqual(mockAssets)
|
||||
})
|
||||
|
||||
it('accepts getter function', async () => {
|
||||
const mockAssets: AssetItem[] = [
|
||||
createMockAsset('asset-1', 'Model A', 'model_a.safetensors')
|
||||
]
|
||||
|
||||
mockGetCategoryForNodeType.mockReturnValue('loras')
|
||||
mockUpdateModelsForNodeType.mockImplementation(
|
||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
|
||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
||||
return mockAssets
|
||||
}
|
||||
)
|
||||
|
||||
const nodeType = ref('LoraLoader')
|
||||
const { category, assets, isLoading } = useAssetWidgetData(
|
||||
() => nodeType.value
|
||||
)
|
||||
|
||||
await nextTick()
|
||||
await vi.waitFor(() => !isLoading.value)
|
||||
|
||||
expect(mockUpdateModelsForNodeType).toHaveBeenCalledWith('LoraLoader')
|
||||
expect(category.value).toBe('loras')
|
||||
expect(assets.value).toEqual(mockAssets)
|
||||
})
|
||||
|
||||
it('accepts ref (backward compatibility)', async () => {
|
||||
const mockAssets: AssetItem[] = [
|
||||
createMockAsset('asset-1', 'Model A', 'model_a.safetensors')
|
||||
]
|
||||
|
||||
mockGetCategoryForNodeType.mockReturnValue('checkpoints')
|
||||
mockUpdateModelsForNodeType.mockImplementation(
|
||||
async (_nodeType: string): Promise<AssetItem[]> => {
|
||||
mockModelAssetsByNodeType.set(_nodeType, mockAssets)
|
||||
mockModelLoadingByNodeType.set(_nodeType, false)
|
||||
return mockAssets
|
||||
}
|
||||
)
|
||||
|
||||
const nodeTypeRef = ref('CheckpointLoaderSimple')
|
||||
const { category, assets, isLoading } = useAssetWidgetData(nodeTypeRef)
|
||||
|
||||
await nextTick()
|
||||
await vi.waitFor(() => !isLoading.value)
|
||||
|
||||
expect(mockUpdateModelsForNodeType).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(category.value).toBe('checkpoints')
|
||||
expect(assets.value).toEqual(mockAssets)
|
||||
})
|
||||
|
||||
it('handles undefined node type gracefully', async () => {
|
||||
const { category, assets, dropdownItems, isLoading, error } =
|
||||
useAssetWidgetData(undefined)
|
||||
|
||||
await nextTick()
|
||||
|
||||
expect(mockUpdateModelsForNodeType).not.toHaveBeenCalled()
|
||||
expect(category.value).toBeUndefined()
|
||||
expect(assets.value).toEqual([])
|
||||
expect(dropdownItems.value).toEqual([])
|
||||
expect(isLoading.value).toBe(false)
|
||||
expect(error.value).toBeNull()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,89 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { flushPromises, mount } from '@vue/test-utils'
|
||||
import PrimeVue from 'primevue/config'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { SimplifiedWidget } from '@/types/simplifiedWidget'
|
||||
|
||||
import WidgetSelect from '@/renderer/extensions/vueNodes/widgets/components/WidgetSelect.vue'
|
||||
|
||||
// Mock modules
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
isCloud: true
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/assets/services/assetService', () => ({
|
||||
assetService: {
|
||||
isAssetBrowserEligible: vi.fn(() => true)
|
||||
}
|
||||
}))
|
||||
|
||||
const mockSettingStoreGet = vi.fn()
|
||||
|
||||
vi.mock('@/platform/settings/settingStore', () => ({
|
||||
useSettingStore: vi.fn(() => ({
|
||||
get: mockSettingStoreGet
|
||||
}))
|
||||
}))
|
||||
|
||||
// Import after mocks are defined
|
||||
import { assetService } from '@/platform/assets/services/assetService'
|
||||
const mockAssetServiceEligible = vi.mocked(assetService.isAssetBrowserEligible)
|
||||
|
||||
describe('WidgetSelect asset mode', () => {
|
||||
const createWidget = (): SimplifiedWidget<string | number | undefined> => ({
|
||||
name: 'ckpt_name',
|
||||
type: 'combo',
|
||||
value: undefined,
|
||||
options: {
|
||||
values: []
|
||||
}
|
||||
})
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockAssetServiceEligible.mockReturnValue(true)
|
||||
mockSettingStoreGet.mockReturnValue(true) // Default to true for UseAssetAPI
|
||||
})
|
||||
|
||||
// Helper to mount with common setup
|
||||
const mountWidget = () => {
|
||||
return mount(WidgetSelect, {
|
||||
props: {
|
||||
widget: createWidget(),
|
||||
modelValue: undefined,
|
||||
nodeType: 'CheckpointLoaderSimple'
|
||||
},
|
||||
global: {
|
||||
plugins: [PrimeVue, createTestingPinia()]
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
it('uses dropdown when isCloud && UseAssetAPI && isEligible', async () => {
|
||||
const wrapper = mountWidget()
|
||||
await flushPromises()
|
||||
|
||||
expect(
|
||||
wrapper.findComponent({ name: 'WidgetSelectDropdown' }).exists()
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('uses default widget when UseAssetAPI setting is false', () => {
|
||||
mockSettingStoreGet.mockReturnValue(false)
|
||||
const wrapper = mountWidget()
|
||||
|
||||
expect(
|
||||
wrapper.findComponent({ name: 'WidgetSelectDefault' }).exists()
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('uses default widget when node is not eligible', () => {
|
||||
mockAssetServiceEligible.mockReturnValue(false)
|
||||
const wrapper = mountWidget()
|
||||
|
||||
expect(
|
||||
wrapper.findComponent({ name: 'WidgetSelectDefault' }).exists()
|
||||
).toBe(true)
|
||||
})
|
||||
})
|
||||
@@ -1,8 +1,23 @@
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it } from 'vitest'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { useAssetsStore } from '@/stores/assetsStore'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
|
||||
// Mock isCloud to be true for these tests
|
||||
vi.mock('@/platform/distribution/types', () => ({
|
||||
isCloud: true
|
||||
}))
|
||||
|
||||
// Mock assetService
|
||||
const mockGetAssetsForNodeType = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/platform/assets/services/assetService', () => ({
|
||||
assetService: {
|
||||
getAssetsForNodeType: mockGetAssetsForNodeType
|
||||
}
|
||||
}))
|
||||
|
||||
const HASH_FILENAME =
|
||||
'72e786ff2a44d682c4294db0b7098e569832bc394efc6dad644e6ec85a78efb7.png'
|
||||
@@ -24,6 +39,7 @@ function createMockAssetItem(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
describe('assetsStore', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('input asset mapping helpers', () => {
|
||||
@@ -154,4 +170,56 @@ describe('assetsStore', () => {
|
||||
expect(store.inputAssetsByFilename.size).toBe(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('model assets caching', () => {
|
||||
beforeEach(() => {
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
modelToNodeStore.registerDefaults()
|
||||
})
|
||||
|
||||
it('should cache assets by node type', async () => {
|
||||
const store = useAssetsStore()
|
||||
const mockAssets: AssetItem[] = [
|
||||
createMockAssetItem({ id: '1', name: 'model_a.safetensors' }),
|
||||
createMockAssetItem({ id: '2', name: 'model_b.safetensors' })
|
||||
]
|
||||
mockGetAssetsForNodeType.mockResolvedValue(mockAssets)
|
||||
|
||||
await store.updateModelsForNodeType('CheckpointLoaderSimple')
|
||||
|
||||
expect(mockGetAssetsForNodeType).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(store.modelAssetsByNodeType.get('CheckpointLoaderSimple')).toEqual(
|
||||
mockAssets
|
||||
)
|
||||
})
|
||||
|
||||
it('should track loading state', async () => {
|
||||
const store = useAssetsStore()
|
||||
mockGetAssetsForNodeType.mockImplementation(
|
||||
() => new Promise((resolve) => setTimeout(() => resolve([]), 100))
|
||||
)
|
||||
|
||||
const promise = store.updateModelsForNodeType('LoraLoader')
|
||||
|
||||
expect(store.modelLoadingByNodeType.get('LoraLoader')).toBe(true)
|
||||
|
||||
await promise
|
||||
|
||||
expect(store.modelLoadingByNodeType.get('LoraLoader')).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle errors gracefully', async () => {
|
||||
const store = useAssetsStore()
|
||||
const mockError = new Error('Network error')
|
||||
mockGetAssetsForNodeType.mockRejectedValue(mockError)
|
||||
|
||||
await store.updateModelsForNodeType('VAELoader')
|
||||
|
||||
expect(store.modelErrorByNodeType.get('VAELoader')).toBe(mockError)
|
||||
expect(store.modelAssetsByNodeType.get('VAELoader')).toEqual([])
|
||||
expect(store.modelLoadingByNodeType.get('VAELoader')).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user