mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-06-11 00:38:37 +00:00
Compare commits
5 Commits
main
...
pysssss/mo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
307e8a0d04 | ||
|
|
e2c9550664 | ||
|
|
a87e107b01 | ||
|
|
c9f1cc42ad | ||
|
|
6cc1f20bd4 |
61
browser_tests/tests/browseModelAssets.spec.ts
Normal file
61
browser_tests/tests/browseModelAssets.spec.ts
Normal file
@@ -0,0 +1,61 @@
|
||||
import { expect } from '@playwright/test'
|
||||
|
||||
import type { Asset } from '@comfyorg/ingest-types'
|
||||
import { createCloudAssetsFixture } from '@e2e/fixtures/assetApiFixture'
|
||||
import { STABLE_CHECKPOINT } from '@e2e/fixtures/data/assetFixtures'
|
||||
|
||||
const CLOUD_ASSETS: Asset[] = [STABLE_CHECKPOINT]
|
||||
|
||||
const test = createCloudAssetsFixture(CLOUD_ASSETS)
|
||||
|
||||
test.describe('Browse Model Assets - Use button', { tag: '@cloud' }, () => {
|
||||
test.beforeEach(async ({ comfyPage }) => {
|
||||
await comfyPage.settings.setSetting('Comfy.Assets.UseAssetAPI', true)
|
||||
await comfyPage.nodeOps.clearGraph()
|
||||
})
|
||||
|
||||
test.afterEach(async ({ comfyPage }) => {
|
||||
await comfyPage.nodeOps.clearGraph()
|
||||
})
|
||||
|
||||
test('Use button ghost-places a loader populated with the model', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
await comfyPage.command.executeCommand('Comfy.BrowseModelAssets')
|
||||
|
||||
const modal = comfyPage.page.locator(
|
||||
'[data-component-id="AssetBrowserModal"]'
|
||||
)
|
||||
await expect(modal).toBeVisible()
|
||||
|
||||
const card = comfyPage.page.locator(
|
||||
`[data-component-id="AssetCard"][data-asset-id="${STABLE_CHECKPOINT.id}"]`
|
||||
)
|
||||
await expect(card).toBeVisible()
|
||||
await card.getByRole('button', { name: 'Use' }).click()
|
||||
|
||||
// Dialog closes and the ghost is armed; the node is not placed until the
|
||||
// user clicks the canvas.
|
||||
await expect(modal).toBeHidden()
|
||||
await expect
|
||||
.poll(() => comfyPage.nodeOps.getGraphNodesCount(), { timeout: 1000 })
|
||||
.toBe(0)
|
||||
|
||||
const canvasBox = (await comfyPage.canvas.boundingBox())!
|
||||
await comfyPage.canvas.click({
|
||||
position: { x: canvasBox.width / 2, y: canvasBox.height / 2 }
|
||||
})
|
||||
|
||||
await expect.poll(() => comfyPage.nodeOps.getGraphNodesCount()).toBe(1)
|
||||
await expect
|
||||
.poll(() => comfyPage.nodeOps.getSelectedGraphNodesCount())
|
||||
.toBe(1)
|
||||
|
||||
const [loader] = await comfyPage.nodeOps.getNodeRefsByType(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(loader).toBeDefined()
|
||||
const widget = await loader.getWidgetByName('ckpt_name')
|
||||
expect(await widget.getValue()).toBe(STABLE_CHECKPOINT.name)
|
||||
})
|
||||
})
|
||||
@@ -233,4 +233,64 @@ test.describe('Model library sidebar - empty state', () => {
|
||||
await expect(tab.folderNodes).toHaveCount(0)
|
||||
await expect(tab.leafNodes).toHaveCount(0)
|
||||
})
|
||||
|
||||
test.describe('Model library sidebar - add node', () => {
|
||||
test.beforeEach(async ({ comfyPage }) => {
|
||||
await comfyPage.modelLibrary.mockFoldersWithFiles(MOCK_FOLDERS)
|
||||
await comfyPage.setup()
|
||||
await comfyPage.nodeOps.clearGraph()
|
||||
})
|
||||
|
||||
test.afterEach(async ({ comfyPage }) => {
|
||||
await comfyPage.modelLibrary.clearMocks()
|
||||
})
|
||||
|
||||
test('Clicking a model defers creation until placed on the canvas', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const tab = comfyPage.menu.modelLibraryTab
|
||||
await tab.open()
|
||||
await tab.getFolderByLabel('checkpoints').click()
|
||||
await expect(tab.getLeafByLabel('sd_xl_base_1.0')).toBeVisible()
|
||||
|
||||
await tab.getLeafByLabel('sd_xl_base_1.0').click()
|
||||
|
||||
await expect
|
||||
.poll(() => comfyPage.nodeOps.getGraphNodesCount(), { timeout: 1000 })
|
||||
.toBe(0)
|
||||
|
||||
const canvasBox = (await comfyPage.canvas.boundingBox())!
|
||||
await comfyPage.canvas.click({
|
||||
position: { x: canvasBox.width / 2, y: canvasBox.height / 2 }
|
||||
})
|
||||
|
||||
await expect.poll(() => comfyPage.nodeOps.getGraphNodesCount()).toBe(1)
|
||||
await expect
|
||||
.poll(() => comfyPage.nodeOps.getSelectedGraphNodesCount())
|
||||
.toBe(1)
|
||||
|
||||
const [loader] = await comfyPage.nodeOps.getNodeRefsByType(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(loader).toBeDefined()
|
||||
const widget = await loader.getWidgetByName('ckpt_name')
|
||||
expect(await widget.getValue()).toBe('sd_xl_base_1.0.safetensors')
|
||||
})
|
||||
|
||||
test('Ghost preview shows the model in the loader widget before placing', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const tab = comfyPage.menu.modelLibraryTab
|
||||
await tab.open()
|
||||
await tab.getFolderByLabel('checkpoints').click()
|
||||
await expect(tab.getLeafByLabel('sd_xl_base_1.0')).toBeVisible()
|
||||
|
||||
await tab.getLeafByLabel('sd_xl_base_1.0').click()
|
||||
|
||||
const ghost = comfyPage.page.locator(
|
||||
'[data-node-id="preview-CheckpointLoaderSimple"]'
|
||||
)
|
||||
await expect(ghost).toContainText('sd_xl_base_1.0.safetensors')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -355,7 +355,7 @@ describe('TreeExplorerV2Node', () => {
|
||||
const nodeDiv = getTreeNode(container)
|
||||
await fireEvent.dragStart(nodeDiv)
|
||||
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockData, 'native')
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockData, { mode: 'native' })
|
||||
})
|
||||
|
||||
it('does not call startDrag for folder items on dragstart', async () => {
|
||||
|
||||
@@ -93,6 +93,7 @@
|
||||
|
||||
<NodeTooltip v-if="tooltipEnabled" />
|
||||
<NodeSearchboxPopover ref="nodeSearchboxPopoverRef" />
|
||||
<NodeDragPreview />
|
||||
<VueNodeSwitchPopup />
|
||||
|
||||
<!-- Initialize components after comfyApp is ready. useAbsolutePosition requires
|
||||
@@ -136,6 +137,7 @@ import GraphCanvasMenu from '@/components/graph/GraphCanvasMenu.vue'
|
||||
import LinkOverlayCanvas from '@/components/graph/LinkOverlayCanvas.vue'
|
||||
import NodeTooltip from '@/components/graph/NodeTooltip.vue'
|
||||
import NodeContextMenu from '@/components/graph/NodeContextMenu.vue'
|
||||
import NodeDragPreview from '@/components/graph/NodeDragPreview.vue'
|
||||
import SelectionToolbox from '@/components/graph/SelectionToolbox.vue'
|
||||
import TitleEditor from '@/components/graph/TitleEditor.vue'
|
||||
import NodePropertiesPanel from '@/components/rightSidePanel/RightSidePanel.vue'
|
||||
|
||||
57
src/components/graph/NodeDragPreview.vue
Normal file
57
src/components/graph/NodeDragPreview.vue
Normal file
@@ -0,0 +1,57 @@
|
||||
<template>
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="showGhost && rafPosition"
|
||||
class="pointer-events-none fixed top-0 left-0 z-10000 will-change-transform"
|
||||
:style="{
|
||||
transform: `translate(${rafPosition.x + 12}px, ${rafPosition.y + 12}px)`
|
||||
}"
|
||||
>
|
||||
<div class="origin-top-left scale-50 opacity-80">
|
||||
<LGraphNodePreview
|
||||
:node-def="draggedNode!"
|
||||
:widget-values="pendingWidgetValues"
|
||||
position="relative"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Teleport>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { useMouse, useRafFn } from '@vueuse/core'
|
||||
import { computed, shallowRef, watch } from 'vue'
|
||||
|
||||
import { useNodeDragToCanvas } from '@/composables/node/useNodeDragToCanvas'
|
||||
import LGraphNodePreview from '@/renderer/extensions/vueNodes/components/LGraphNodePreview.vue'
|
||||
|
||||
const { isDragging, draggedNode, pendingWidgetValues } = useNodeDragToCanvas()
|
||||
|
||||
const { x, y, sourceType } = useMouse({ type: 'client' })
|
||||
|
||||
const showGhost = computed(() => Boolean(isDragging.value && draggedNode.value))
|
||||
const rafPosition = shallowRef<{ x: number; y: number }>()
|
||||
|
||||
const { pause, resume } = useRafFn(
|
||||
() => {
|
||||
if (sourceType.value === null) return
|
||||
const pos = rafPosition.value
|
||||
if (pos && pos.x === x.value && pos.y === y.value) return
|
||||
rafPosition.value = { x: x.value, y: y.value }
|
||||
},
|
||||
{ immediate: false }
|
||||
)
|
||||
|
||||
watch(
|
||||
showGhost,
|
||||
(show) => {
|
||||
if (show) {
|
||||
resume()
|
||||
} else {
|
||||
pause()
|
||||
rafPosition.value = undefined
|
||||
}
|
||||
},
|
||||
{ immediate: true }
|
||||
)
|
||||
</script>
|
||||
@@ -14,7 +14,7 @@ const {
|
||||
captureRoot,
|
||||
getRoot,
|
||||
resetRoot,
|
||||
mockAddNodeOnGraph,
|
||||
mockStartDrag,
|
||||
mockGetNodeProvider,
|
||||
mockToggleNodeOnEvent,
|
||||
mockRefreshModelFolder,
|
||||
@@ -29,7 +29,7 @@ const {
|
||||
resetRoot: () => {
|
||||
capturedRoot = null
|
||||
},
|
||||
mockAddNodeOnGraph: vi.fn(),
|
||||
mockStartDrag: vi.fn(),
|
||||
mockGetNodeProvider: vi.fn(),
|
||||
mockToggleNodeOnEvent: vi.fn(),
|
||||
mockRefreshModelFolder: vi.fn().mockResolvedValue(undefined),
|
||||
@@ -37,8 +37,8 @@ const {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/services/litegraphService', () => ({
|
||||
useLitegraphService: () => ({ addNodeOnGraph: mockAddNodeOnGraph })
|
||||
vi.mock('@/composables/node/useNodeDragToCanvas', () => ({
|
||||
useNodeDragToCanvas: () => ({ startDrag: mockStartDrag })
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
@@ -173,16 +173,13 @@ describe('ModelLibrarySidebarTab', () => {
|
||||
expect(screen.getByTestId('search-input')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('handles model click and adds node to graph', async () => {
|
||||
it('starts a ghost drag carrying the widget value to fill on placement', async () => {
|
||||
const mockNodeDef = { name: 'CheckpointLoaderSimple' }
|
||||
const mockWidget = { name: 'ckpt_name', value: '' }
|
||||
const mockGraphNode = { widgets: [mockWidget] }
|
||||
|
||||
mockGetNodeProvider.mockReturnValue({
|
||||
nodeDef: mockNodeDef,
|
||||
key: 'ckpt_name'
|
||||
})
|
||||
mockAddNodeOnGraph.mockReturnValue(mockGraphNode)
|
||||
|
||||
renderComponent()
|
||||
await nextTick()
|
||||
@@ -198,8 +195,9 @@ describe('ModelLibrarySidebarTab', () => {
|
||||
await modelLeaf?.handleClick?.(mockEvent)
|
||||
|
||||
expect(mockGetNodeProvider).toHaveBeenCalledWith('checkpoints')
|
||||
expect(mockAddNodeOnGraph).toHaveBeenCalledWith(mockNodeDef)
|
||||
expect(mockWidget.value).toBe('model.safetensors')
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockNodeDef, {
|
||||
widgetValues: { ckpt_name: 'model.safetensors' }
|
||||
})
|
||||
})
|
||||
|
||||
it('toggles folder expansion on click', async () => {
|
||||
|
||||
@@ -63,10 +63,9 @@ import SidebarTabTemplate from '@/components/sidebar/tabs/SidebarTabTemplate.vue
|
||||
import ElectronDownloadItems from '@/components/sidebar/tabs/modelLibrary/ElectronDownloadItems.vue'
|
||||
import ModelTreeLeaf from '@/components/sidebar/tabs/modelLibrary/ModelTreeLeaf.vue'
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { startModelLoaderDrag } from '@/composables/node/startModelNodeDragFromAsset'
|
||||
import { useTreeExpansion } from '@/composables/useTreeExpansion'
|
||||
import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import { withNodeAddSource } from '@/platform/telemetry/nodeAdded/nodeAddSource'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import { useAssetDownloadStore } from '@/stores/assetDownloadStore'
|
||||
import type { ComfyModelDef, ModelFolder } from '@/stores/modelStore'
|
||||
import { ResourceState, useModelStore } from '@/stores/modelStore'
|
||||
@@ -156,15 +155,7 @@ const renderedRoot = computed<TreeExplorerNode<ModelOrFolder>>(() => {
|
||||
if (this.leaf && model) {
|
||||
const provider = modelToNodeStore.getNodeProvider(model.directory)
|
||||
if (provider) {
|
||||
const graphNode = withNodeAddSource('sidebar_drag', () =>
|
||||
useLitegraphService().addNodeOnGraph(provider.nodeDef)
|
||||
)
|
||||
const widget = graphNode?.widgets?.find(
|
||||
(widget) => widget.name === provider.key
|
||||
)
|
||||
if (widget) {
|
||||
widget.value = model.file_name
|
||||
}
|
||||
startModelLoaderDrag(provider, model.file_name)
|
||||
}
|
||||
} else {
|
||||
toggleNodeOnEvent(e, node)
|
||||
|
||||
@@ -31,11 +31,8 @@ vi.mock('@/composables/node/useNodeDragToCanvas', () => ({
|
||||
useNodeDragToCanvas: () => ({
|
||||
isDragging: { value: false },
|
||||
draggedNode: { value: null },
|
||||
cursorPosition: { value: { x: 0, y: 0 } },
|
||||
startDrag: vi.fn(),
|
||||
cancelDrag: vi.fn(),
|
||||
setupGlobalListeners: vi.fn(),
|
||||
cleanupGlobalListeners: vi.fn()
|
||||
cancelDrag: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
|
||||
@@ -115,7 +115,6 @@
|
||||
</div>
|
||||
</template>
|
||||
<template #body>
|
||||
<NodeDragPreview />
|
||||
<div class="flex h-full flex-col">
|
||||
<div
|
||||
v-if="hasNoMatches"
|
||||
@@ -215,7 +214,6 @@ import type {
|
||||
import AllNodesPanel from './nodeLibrary/AllNodesPanel.vue'
|
||||
import BlueprintsPanel from './nodeLibrary/BlueprintsPanel.vue'
|
||||
import EssentialNodesPanel from './nodeLibrary/EssentialNodesPanel.vue'
|
||||
import NodeDragPreview from './nodeLibrary/NodeDragPreview.vue'
|
||||
import SidebarTabTemplate from './SidebarTabTemplate.vue'
|
||||
|
||||
const { flags } = useFeatureFlags()
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
<template>
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="isDragging && draggedNode && showPreview"
|
||||
class="pointer-events-none fixed z-10000"
|
||||
:style="{
|
||||
left: `${previewPosition.x + 12}px`,
|
||||
top: `${previewPosition.y + 12}px`
|
||||
}"
|
||||
>
|
||||
<div class="origin-top-left scale-50 opacity-80">
|
||||
<LGraphNodePreview :node-def="draggedNode" position="relative" />
|
||||
</div>
|
||||
</div>
|
||||
</Teleport>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, onMounted, onUnmounted, ref } from 'vue'
|
||||
|
||||
import { useNodeDragToCanvas } from '@/composables/node/useNodeDragToCanvas'
|
||||
import LGraphNodePreview from '@/renderer/extensions/vueNodes/components/LGraphNodePreview.vue'
|
||||
|
||||
const {
|
||||
isDragging,
|
||||
draggedNode,
|
||||
cursorPosition,
|
||||
dragMode,
|
||||
setupGlobalListeners,
|
||||
cleanupGlobalListeners
|
||||
} = useNodeDragToCanvas()
|
||||
|
||||
const nativeDragPosition = ref({ x: 0, y: 0 })
|
||||
|
||||
const previewPosition = computed(() => {
|
||||
if (dragMode.value === 'native') {
|
||||
return nativeDragPosition.value
|
||||
}
|
||||
return cursorPosition.value
|
||||
})
|
||||
|
||||
const showPreview = computed(() => {
|
||||
if (dragMode.value === 'native') {
|
||||
return nativeDragPosition.value.x > 0 || nativeDragPosition.value.y > 0
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
function handleDrag(e: DragEvent) {
|
||||
if (e.clientX === 0 && e.clientY === 0) return
|
||||
nativeDragPosition.value = { x: e.clientX, y: e.clientY }
|
||||
}
|
||||
|
||||
function handleDragEnd() {
|
||||
nativeDragPosition.value = { x: 0, y: 0 }
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
setupGlobalListeners()
|
||||
document.addEventListener('drag', handleDrag)
|
||||
document.addEventListener('dragend', handleDragEnd)
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
cleanupGlobalListeners()
|
||||
document.removeEventListener('drag', handleDrag)
|
||||
document.removeEventListener('dragend', handleDragEnd)
|
||||
})
|
||||
</script>
|
||||
73
src/composables/node/startModelNodeDragFromAsset.test.ts
Normal file
73
src/composables/node/startModelNodeDragFromAsset.test.ts
Normal file
@@ -0,0 +1,73 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { startModelNodeDragFromAsset } from '@/composables/node/startModelNodeDragFromAsset'
|
||||
|
||||
const { mockStartDrag, mockGetNodeProvider } = vi.hoisted(() => ({
|
||||
mockStartDrag: vi.fn(),
|
||||
mockGetNodeProvider: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/node/useNodeDragToCanvas', () => ({
|
||||
useNodeDragToCanvas: () => ({ startDrag: mockStartDrag })
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
useModelToNodeStore: () => ({ getNodeProvider: mockGetNodeProvider })
|
||||
}))
|
||||
|
||||
function createAsset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
return {
|
||||
id: 'asset-123',
|
||||
name: 'sd_xl_base_1.0.safetensors',
|
||||
size: 1024,
|
||||
created_at: '2025-10-01T00:00:00Z',
|
||||
tags: ['models', 'checkpoints'],
|
||||
user_metadata: { filename: 'sd_xl_base_1.0.safetensors' },
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
describe('startModelNodeDragFromAsset', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
})
|
||||
|
||||
it('starts a ghost drag for the resolved node carrying the widget value', () => {
|
||||
const nodeDef = { name: 'CheckpointLoaderSimple' }
|
||||
mockGetNodeProvider.mockReturnValue({ nodeDef, key: 'ckpt_name' })
|
||||
|
||||
const error = startModelNodeDragFromAsset(createAsset())
|
||||
|
||||
expect(error).toBeUndefined()
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(nodeDef, {
|
||||
widgetValues: { ckpt_name: 'sd_xl_base_1.0.safetensors' }
|
||||
})
|
||||
})
|
||||
|
||||
it('carries no widget value when the provider has no key', () => {
|
||||
const nodeDef = { name: 'FL_ChatterboxVC' }
|
||||
mockGetNodeProvider.mockReturnValue({ nodeDef, key: '' })
|
||||
|
||||
startModelNodeDragFromAsset(
|
||||
createAsset({
|
||||
tags: ['models', 'chatterbox/chatterbox_vc'],
|
||||
user_metadata: { filename: 'chatterbox_vc_model.pt' }
|
||||
})
|
||||
)
|
||||
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(nodeDef, {
|
||||
widgetValues: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('returns the resolution error and does not start a drag for an invalid asset', () => {
|
||||
mockGetNodeProvider.mockReturnValue(null)
|
||||
|
||||
const error = startModelNodeDragFromAsset(createAsset())
|
||||
|
||||
expect(error?.code).toBe('NO_PROVIDER')
|
||||
expect(mockStartDrag).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
35
src/composables/node/startModelNodeDragFromAsset.ts
Normal file
35
src/composables/node/startModelNodeDragFromAsset.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import { useNodeDragToCanvas } from '@/composables/node/useNodeDragToCanvas'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { resolveModelNodeFromAsset } from '@/platform/assets/utils/resolveModelNodeFromAsset'
|
||||
import type { ResolveModelNodeError } from '@/platform/assets/utils/resolveModelNodeFromAsset'
|
||||
import type { ModelNodeProvider } from '@/stores/modelToNodeStore'
|
||||
|
||||
/**
|
||||
* Arms a ghost drag for a model loader node. Providers with no widget key
|
||||
* (auto-load nodes) start the drag without widget values.
|
||||
*/
|
||||
export function startModelLoaderDrag(
|
||||
provider: ModelNodeProvider,
|
||||
filename: string
|
||||
) {
|
||||
const widgetValues = provider.key ? { [provider.key]: filename } : undefined
|
||||
useNodeDragToCanvas().startDrag(provider.nodeDef, { widgetValues })
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts a ghost drag for the model loader node described by an asset. The
|
||||
* node is created where the user next clicks the canvas, with the asset's
|
||||
* filename written into the loader widget.
|
||||
*
|
||||
* @returns the resolution error when the asset cannot be mapped to a node,
|
||||
* otherwise `undefined`.
|
||||
*/
|
||||
export function startModelNodeDragFromAsset(
|
||||
asset: AssetItem
|
||||
): ResolveModelNodeError | undefined {
|
||||
const resolved = resolveModelNodeFromAsset(asset)
|
||||
if (!resolved.success) return resolved.error
|
||||
|
||||
const { provider, filename } = resolved.value
|
||||
startModelLoaderDrag(provider, filename)
|
||||
}
|
||||
@@ -7,7 +7,8 @@ const {
|
||||
mockAddNodeOnGraph,
|
||||
mockConvertEventToCanvasOffset,
|
||||
mockSelectItems,
|
||||
mockCanvas
|
||||
mockCanvas,
|
||||
mockToastAdd
|
||||
} = vi.hoisted(() => {
|
||||
const mockConvertEventToCanvasOffset = vi.fn()
|
||||
const mockSelectItems = vi.fn()
|
||||
@@ -15,6 +16,7 @@ const {
|
||||
mockAddNodeOnGraph: vi.fn(),
|
||||
mockConvertEventToCanvasOffset,
|
||||
mockSelectItems,
|
||||
mockToastAdd: vi.fn(),
|
||||
mockCanvas: {
|
||||
canvas: {
|
||||
getBoundingClientRect: vi.fn()
|
||||
@@ -37,6 +39,12 @@ vi.mock('@/services/litegraphService', () => ({
|
||||
}))
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: vi.fn(() => ({ add: mockToastAdd }))
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n', () => ({ t: (key: string) => key }))
|
||||
|
||||
describe('useNodeDragToCanvas', () => {
|
||||
let useNodeDragToCanvas: typeof UseNodeDragToCanvasType
|
||||
|
||||
@@ -54,8 +62,8 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
const { cleanupGlobalListeners } = useNodeDragToCanvas()
|
||||
cleanupGlobalListeners()
|
||||
const { cancelDrag } = useNodeDragToCanvas()
|
||||
cancelDrag()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
@@ -71,22 +79,6 @@ describe('useNodeDragToCanvas', () => {
|
||||
expect(isDragging.value).toBe(true)
|
||||
expect(draggedNode.value).toBe(mockNodeDef)
|
||||
})
|
||||
|
||||
it('should set dragMode to click by default', () => {
|
||||
const { dragMode, startDrag } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(dragMode.value).toBe('click')
|
||||
})
|
||||
|
||||
it('should set dragMode to native when specified', () => {
|
||||
const { dragMode, startDrag } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, 'native')
|
||||
|
||||
expect(dragMode.value).toBe('native')
|
||||
})
|
||||
})
|
||||
|
||||
describe('cancelDrag', () => {
|
||||
@@ -102,30 +94,15 @@ describe('useNodeDragToCanvas', () => {
|
||||
expect(isDragging.value).toBe(false)
|
||||
expect(draggedNode.value).toBeNull()
|
||||
})
|
||||
|
||||
it('should reset dragMode to click', () => {
|
||||
const { dragMode, startDrag, cancelDrag } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, 'native')
|
||||
expect(dragMode.value).toBe('native')
|
||||
|
||||
cancelDrag()
|
||||
|
||||
expect(dragMode.value).toBe('click')
|
||||
})
|
||||
})
|
||||
|
||||
describe('setupGlobalListeners', () => {
|
||||
it('should add event listeners to document', () => {
|
||||
describe('drag listener lifecycle', () => {
|
||||
it('should attach document listeners on startDrag', () => {
|
||||
const addEventListenerSpy = vi.spyOn(document, 'addEventListener')
|
||||
const { setupGlobalListeners } = useNodeDragToCanvas()
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith(
|
||||
'pointermove',
|
||||
expect.any(Function)
|
||||
)
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith(
|
||||
'pointerdown',
|
||||
expect.any(Function),
|
||||
@@ -142,35 +119,53 @@ describe('useNodeDragToCanvas', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('should only setup listeners once', () => {
|
||||
it('should not attach drag listeners until a drag starts', () => {
|
||||
const addEventListenerSpy = vi.spyOn(document, 'addEventListener')
|
||||
const { setupGlobalListeners } = useNodeDragToCanvas()
|
||||
useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
expect(addEventListenerSpy).not.toHaveBeenCalledWith(
|
||||
'pointerup',
|
||||
expect.any(Function),
|
||||
true
|
||||
)
|
||||
expect(addEventListenerSpy).not.toHaveBeenCalledWith(
|
||||
'keydown',
|
||||
expect.any(Function)
|
||||
)
|
||||
})
|
||||
|
||||
it('should detach document listeners on cancelDrag', () => {
|
||||
const removeEventListenerSpy = vi.spyOn(document, 'removeEventListener')
|
||||
const { startDrag, cancelDrag } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef)
|
||||
cancelDrag()
|
||||
|
||||
expect(removeEventListenerSpy).toHaveBeenCalledWith(
|
||||
'pointerdown',
|
||||
expect.any(Function),
|
||||
true
|
||||
)
|
||||
expect(removeEventListenerSpy).toHaveBeenCalledWith(
|
||||
'pointerup',
|
||||
expect.any(Function),
|
||||
true
|
||||
)
|
||||
})
|
||||
|
||||
it('should only attach listeners once across re-arms', () => {
|
||||
const addEventListenerSpy = vi.spyOn(document, 'addEventListener')
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef)
|
||||
const callCount = addEventListenerSpy.mock.calls.length
|
||||
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(addEventListenerSpy.mock.calls.length).toBe(callCount)
|
||||
})
|
||||
})
|
||||
|
||||
describe('cursorPosition', () => {
|
||||
it('should update on pointermove', () => {
|
||||
const { cursorPosition, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
|
||||
const pointerEvent = new PointerEvent('pointermove', {
|
||||
clientX: 100,
|
||||
clientY: 200
|
||||
})
|
||||
document.dispatchEvent(pointerEvent)
|
||||
|
||||
expect(cursorPosition.value).toEqual({ x: 100, y: 200 })
|
||||
})
|
||||
})
|
||||
|
||||
describe('endDrag behavior', () => {
|
||||
it('should add node when pointer is over canvas', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
@@ -181,9 +176,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
const pointerEvent = new PointerEvent('pointerup', {
|
||||
@@ -206,10 +199,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
bottom: 500
|
||||
})
|
||||
|
||||
const { startDrag, setupGlobalListeners, isDragging } =
|
||||
useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
const { startDrag, isDragging } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
const pointerEvent = new PointerEvent('pointerup', {
|
||||
@@ -224,10 +214,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should cancel drag on Escape key', () => {
|
||||
const { startDrag, setupGlobalListeners, isDragging } =
|
||||
useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
const { startDrag, isDragging } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(isDragging.value).toBe(true)
|
||||
@@ -239,10 +226,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should not cancel drag on other keys', () => {
|
||||
const { startDrag, setupGlobalListeners, isDragging } =
|
||||
useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
const { startDrag, isDragging } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
const keyEvent = new KeyboardEvent('keydown', { key: 'Enter' })
|
||||
@@ -262,8 +246,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
const placedNode = { id: 1 }
|
||||
mockAddNodeOnGraph.mockReturnValue(placedNode)
|
||||
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
document.dispatchEvent(
|
||||
@@ -277,6 +260,97 @@ describe('useNodeDragToCanvas', () => {
|
||||
expect(mockSelectItems).toHaveBeenCalledWith([placedNode])
|
||||
})
|
||||
|
||||
it('should apply the requested widget values to the placed node', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
left: 0,
|
||||
right: 500,
|
||||
top: 0,
|
||||
bottom: 500
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
const widget = { name: 'ckpt_name', value: '' }
|
||||
mockAddNodeOnGraph.mockReturnValue({ id: 1, widgets: [widget] })
|
||||
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, {
|
||||
widgetValues: { ckpt_name: 'model.safetensors' }
|
||||
})
|
||||
|
||||
document.dispatchEvent(
|
||||
new PointerEvent('pointerup', {
|
||||
clientX: 250,
|
||||
clientY: 250,
|
||||
bubbles: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(widget.value).toBe('model.safetensors')
|
||||
})
|
||||
|
||||
it('should still place the node when a requested widget is missing', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
left: 0,
|
||||
right: 500,
|
||||
top: 0,
|
||||
bottom: 500
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
const placedNode = { id: 1, widgets: [] }
|
||||
mockAddNodeOnGraph.mockReturnValue(placedNode)
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {})
|
||||
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, {
|
||||
widgetValues: { ckpt_name: 'model.safetensors' }
|
||||
})
|
||||
|
||||
document.dispatchEvent(
|
||||
new PointerEvent('pointerup', {
|
||||
clientX: 250,
|
||||
clientY: 250,
|
||||
bubbles: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(mockSelectItems).toHaveBeenCalledWith([placedNode])
|
||||
expect(mockToastAdd).not.toHaveBeenCalled()
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('ckpt_name')
|
||||
)
|
||||
})
|
||||
|
||||
it('should show an error toast when the graph fails to add the node', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
left: 0,
|
||||
right: 500,
|
||||
top: 0,
|
||||
bottom: 500
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
mockAddNodeOnGraph.mockReturnValue(null)
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
document.dispatchEvent(
|
||||
new PointerEvent('pointerup', {
|
||||
clientX: 250,
|
||||
clientY: 250,
|
||||
bubbles: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
severity: 'error',
|
||||
detail: 'assetBrowser.failedToCreateNode'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should not call selectItems when graph returns no node', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
left: 0,
|
||||
@@ -286,9 +360,9 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
mockAddNodeOnGraph.mockReturnValue(null)
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
document.dispatchEvent(
|
||||
@@ -311,11 +385,8 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
|
||||
const { startDrag, setupGlobalListeners, isDragging } =
|
||||
useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
const { startDrag, isDragging } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
|
||||
const pointerEvent = new PointerEvent('pointerup', {
|
||||
clientX: 250,
|
||||
@@ -341,7 +412,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, 'native')
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
handleNativeDrop(250, 250)
|
||||
|
||||
expect(mockAddNodeOnGraph).toHaveBeenCalledWith(mockNodeDef, {
|
||||
@@ -359,7 +430,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
|
||||
const { startDrag, handleNativeDrop, isDragging } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, 'native')
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
handleNativeDrop(600, 250)
|
||||
|
||||
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
|
||||
@@ -377,7 +448,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, 'click')
|
||||
startDrag(mockNodeDef)
|
||||
handleNativeDrop(250, 250)
|
||||
|
||||
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
|
||||
@@ -392,14 +463,12 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([200, 200])
|
||||
|
||||
const { startDrag, handleNativeDrop, isDragging, dragMode } =
|
||||
useNodeDragToCanvas()
|
||||
const { startDrag, handleNativeDrop, isDragging } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, 'native')
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
handleNativeDrop(250, 250)
|
||||
|
||||
expect(isDragging.value).toBe(false)
|
||||
expect(dragMode.value).toBe('click')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -426,31 +495,29 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should stop propagation when in click-drag mode over canvas', () => {
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(dispatchPointerDown(250, 250)).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not stop propagation when not dragging', () => {
|
||||
const { setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
it('should not stop propagation once the drag is cancelled', () => {
|
||||
const { startDrag, cancelDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
cancelDrag()
|
||||
|
||||
expect(dispatchPointerDown(250, 250)).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not stop propagation in native drag mode', () => {
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
|
||||
expect(dispatchPointerDown(250, 250)).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not stop propagation when pointer is outside canvas', () => {
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(dispatchPointerDown(600, 250)).not.toHaveBeenCalled()
|
||||
@@ -477,10 +544,8 @@ describe('useNodeDragToCanvas', () => {
|
||||
}
|
||||
|
||||
it('should prefer tracked drag position over dragend coordinates', () => {
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
|
||||
fireDrag(250, 250)
|
||||
// dragend supplies a bad position (the Firefox bug); the tracked one
|
||||
@@ -494,10 +559,8 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should ignore drag events with (0, 0)', () => {
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
|
||||
fireDrag(250, 250)
|
||||
fireDrag(0, 0)
|
||||
@@ -510,10 +573,8 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should fall back to dragend coordinates when no drag fired', () => {
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
|
||||
handleNativeDrop(250, 250)
|
||||
|
||||
@@ -523,32 +584,14 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should ignore dragover events fired before startDrag', () => {
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
|
||||
fireDrag(250, 250)
|
||||
startDrag(mockNodeDef, 'native')
|
||||
handleNativeDrop(300, 300)
|
||||
|
||||
expect(mockConvertEventToCanvasOffset).toHaveBeenCalledWith({
|
||||
clientX: 300,
|
||||
clientY: 300
|
||||
})
|
||||
})
|
||||
|
||||
it('should clear tracked position between drags', () => {
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
|
||||
startDrag(mockNodeDef, 'native')
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
fireDrag(250, 250)
|
||||
handleNativeDrop(1505, 102)
|
||||
|
||||
// Second drag - no drag events, so we should fall back to args.
|
||||
startDrag(mockNodeDef, 'native')
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
handleNativeDrop(300, 300)
|
||||
|
||||
expect(mockConvertEventToCanvasOffset).toHaveBeenLastCalledWith({
|
||||
|
||||
@@ -1,23 +1,29 @@
|
||||
import { ref, shallowRef } from 'vue'
|
||||
|
||||
import { t } from '@/i18n'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { withNodeAddSource } from '@/platform/telemetry/nodeAdded/nodeAddSource'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import type { ComfyNodeDefImpl } from '@/stores/nodeDefStore'
|
||||
|
||||
type DragMode = 'click' | 'native'
|
||||
type WidgetValues = Record<string, string>
|
||||
type Position = { x: number; y: number }
|
||||
|
||||
interface StartDragOptions {
|
||||
mode?: DragMode
|
||||
widgetValues?: WidgetValues
|
||||
}
|
||||
|
||||
const isDragging = ref(false)
|
||||
const draggedNode = shallowRef<ComfyNodeDefImpl | null>(null)
|
||||
const cursorPosition = ref({ x: 0, y: 0 })
|
||||
const dragMode = ref<DragMode>('click')
|
||||
const lastNativeDragPosition = shallowRef<{ x: number; y: number }>()
|
||||
const lastNativeDragPosition = shallowRef<Position>()
|
||||
const pendingWidgetValues = shallowRef<WidgetValues>()
|
||||
let listenersSetup = false
|
||||
|
||||
function updatePosition(e: PointerEvent) {
|
||||
cursorPosition.value = { x: e.clientX, y: e.clientY }
|
||||
}
|
||||
|
||||
// Firefox dragend can report stale clientX/Y and `drag` can fire with
|
||||
// (0, 0). dragover on the target reliably reports real client coords.
|
||||
// https://bugzilla.mozilla.org/show_bug.cgi?id=1773886
|
||||
@@ -27,11 +33,15 @@ function trackNativeDragPosition(e: DragEvent) {
|
||||
lastNativeDragPosition.value = { x: e.clientX, y: e.clientY }
|
||||
}
|
||||
|
||||
function cancelDrag() {
|
||||
isDragging.value = false
|
||||
draggedNode.value = null
|
||||
dragMode.value = 'click'
|
||||
lastNativeDragPosition.value = undefined
|
||||
function applyWidgetValues(node: LGraphNode, values: WidgetValues) {
|
||||
for (const [name, value] of Object.entries(values)) {
|
||||
const widget = node.widgets?.find((w) => w.name === name)
|
||||
if (!widget) {
|
||||
console.error(`Widget ${name} not found on node ${node.type}`)
|
||||
continue
|
||||
}
|
||||
widget.value = value
|
||||
}
|
||||
}
|
||||
|
||||
function isOverCanvas(clientX: number, clientY: number): boolean {
|
||||
@@ -62,7 +72,19 @@ function addNodeAtPosition(clientX: number, clientY: number): boolean {
|
||||
const node = withNodeAddSource('sidebar_drag', () =>
|
||||
useLitegraphService().addNodeOnGraph(nodeDef, { pos })
|
||||
)
|
||||
if (node) canvas.selectItems([node])
|
||||
if (!node) {
|
||||
console.error(`Failed to add node to graph: ${nodeDef.name}`)
|
||||
useToastStore().add({
|
||||
severity: 'error',
|
||||
summary: t('g.error'),
|
||||
detail: t('assetBrowser.failedToCreateNode')
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
if (pendingWidgetValues.value)
|
||||
applyWidgetValues(node, pendingWidgetValues.value)
|
||||
canvas.selectItems([node])
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -92,7 +114,6 @@ function setupGlobalListeners() {
|
||||
if (listenersSetup) return
|
||||
listenersSetup = true
|
||||
|
||||
document.addEventListener('pointermove', updatePosition)
|
||||
document.addEventListener('pointerdown', blockCommitPointerDown, true)
|
||||
document.addEventListener('pointerup', endDrag, true)
|
||||
document.addEventListener('keydown', handleKeydown)
|
||||
@@ -103,22 +124,31 @@ function cleanupGlobalListeners() {
|
||||
if (!listenersSetup) return
|
||||
listenersSetup = false
|
||||
|
||||
document.removeEventListener('pointermove', updatePosition)
|
||||
document.removeEventListener('pointerdown', blockCommitPointerDown, true)
|
||||
document.removeEventListener('pointerup', endDrag, true)
|
||||
document.removeEventListener('keydown', handleKeydown)
|
||||
document.removeEventListener('dragover', trackNativeDragPosition)
|
||||
}
|
||||
|
||||
if (isDragging.value && dragMode.value === 'click') {
|
||||
cancelDrag()
|
||||
}
|
||||
function cancelDrag() {
|
||||
isDragging.value = false
|
||||
draggedNode.value = null
|
||||
dragMode.value = 'click'
|
||||
lastNativeDragPosition.value = undefined
|
||||
pendingWidgetValues.value = undefined
|
||||
cleanupGlobalListeners()
|
||||
}
|
||||
|
||||
export function useNodeDragToCanvas() {
|
||||
function startDrag(nodeDef: ComfyNodeDefImpl, mode: DragMode = 'click') {
|
||||
function startDrag(
|
||||
nodeDef: ComfyNodeDefImpl,
|
||||
{ mode = 'click', widgetValues }: StartDragOptions = {}
|
||||
) {
|
||||
isDragging.value = true
|
||||
draggedNode.value = nodeDef
|
||||
dragMode.value = mode
|
||||
pendingWidgetValues.value = widgetValues
|
||||
setupGlobalListeners()
|
||||
}
|
||||
|
||||
function handleNativeDrop(clientX: number, clientY: number) {
|
||||
@@ -134,12 +164,9 @@ export function useNodeDragToCanvas() {
|
||||
return {
|
||||
isDragging,
|
||||
draggedNode,
|
||||
cursorPosition,
|
||||
dragMode,
|
||||
pendingWidgetValues,
|
||||
startDrag,
|
||||
cancelDrag,
|
||||
handleNativeDrop,
|
||||
setupGlobalListeners,
|
||||
cleanupGlobalListeners
|
||||
handleNativeDrop
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,7 +124,9 @@ describe('useNodePreviewAndDrag', () => {
|
||||
|
||||
expect(result.isDragging.value).toBe(true)
|
||||
expect(result.isHovered.value).toBe(false)
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockNodeDef, 'native')
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockNodeDef, {
|
||||
mode: 'native'
|
||||
})
|
||||
expect(mockDataTransfer.effectAllowed).toBe('copy')
|
||||
expect(mockDataTransfer.setData).toHaveBeenCalledWith(
|
||||
'application/x-comfy-node',
|
||||
|
||||
@@ -112,7 +112,7 @@ export function useNodePreviewAndDrag(
|
||||
isDragging.value = true
|
||||
isHovered.value = false
|
||||
|
||||
startDrag(nodeDef.value, 'native')
|
||||
startDrag(nodeDef.value, { mode: 'native' })
|
||||
|
||||
if (e.dataTransfer) {
|
||||
e.dataTransfer.effectAllowed = 'copy'
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useCurrentUser } from '@/composables/auth/useCurrentUser'
|
||||
import { useAuthActions } from '@/composables/auth/useAuthActions'
|
||||
import { useSelectedLiteGraphItems } from '@/composables/canvas/useSelectedLiteGraphItems'
|
||||
import { useSubgraphOperations } from '@/composables/graph/useSubgraphOperations'
|
||||
import { startModelNodeDragFromAsset } from '@/composables/node/startModelNodeDragFromAsset'
|
||||
import { useExternalLink } from '@/composables/useExternalLink'
|
||||
import { useModelSelectorDialog } from '@/composables/useModelSelectorDialog'
|
||||
import {
|
||||
@@ -20,7 +21,6 @@ import {
|
||||
import type { Point } from '@/lib/litegraph/src/litegraph'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { useAssetBrowserDialog } from '@/platform/assets/composables/useAssetBrowserDialog'
|
||||
import { createModelNodeFromAsset } from '@/platform/assets/utils/createModelNodeFromAsset'
|
||||
import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import { buildSupportUrl } from '@/platform/support/config'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
@@ -1305,14 +1305,14 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
assetType: 'models',
|
||||
title: t('sideToolbar.modelLibrary'),
|
||||
onAssetSelected: (asset) => {
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
if (!result.success) {
|
||||
const error = startModelNodeDragFromAsset(asset)
|
||||
if (error) {
|
||||
toastStore.add({
|
||||
severity: 'error',
|
||||
summary: t('g.error'),
|
||||
detail: t('assetBrowser.failedToCreateNode')
|
||||
})
|
||||
console.error('Node creation failed:', result.error)
|
||||
console.error('Node creation failed:', error)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,439 +0,0 @@
|
||||
// oxlint-disable no-misused-spread
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { markRaw } from 'vue'
|
||||
import type { Raw } from 'vue'
|
||||
|
||||
import type { LGraphNode, Subgraph } from '@/lib/litegraph/src/litegraph'
|
||||
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type * as LitegraphModule from '@/lib/litegraph/src/litegraph'
|
||||
import type * as ModelToNodeStoreModule from '@/stores/modelToNodeStore'
|
||||
import type * as WorkflowStoreModule from '@/platform/workflow/management/stores/workflowStore'
|
||||
import type * as LitegraphServiceModule from '@/services/litegraphService'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { createModelNodeFromAsset } from '@/platform/assets/utils/createModelNodeFromAsset'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
import { app } from '@/scripts/app'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
apiURL: vi.fn((path: string) => `http://localhost:8188${path}`)
|
||||
}
|
||||
}))
|
||||
vi.mock('@/stores/modelToNodeStore', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof ModelToNodeStoreModule>()
|
||||
return {
|
||||
...actual,
|
||||
useModelToNodeStore: vi.fn()
|
||||
}
|
||||
})
|
||||
vi.mock(
|
||||
'@/platform/workflow/management/stores/workflowStore',
|
||||
async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof WorkflowStoreModule>()
|
||||
return {
|
||||
...actual,
|
||||
useWorkflowStore: vi.fn()
|
||||
}
|
||||
}
|
||||
)
|
||||
vi.mock('@/services/litegraphService', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof LitegraphServiceModule>()
|
||||
return {
|
||||
...actual,
|
||||
useLitegraphService: vi.fn()
|
||||
}
|
||||
})
|
||||
vi.mock('@/lib/litegraph/src/litegraph', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof LitegraphModule>()
|
||||
return {
|
||||
...actual,
|
||||
LiteGraph: {
|
||||
...actual.LiteGraph,
|
||||
createNode: vi.fn()
|
||||
}
|
||||
}
|
||||
})
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
canvas: {
|
||||
graph: {
|
||||
add: vi.fn()
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
function createMockAsset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
return {
|
||||
id: 'asset-123',
|
||||
name: 'test-model.safetensors',
|
||||
size: 1024,
|
||||
created_at: '2025-10-01T00:00:00Z',
|
||||
tags: ['models', 'checkpoints'],
|
||||
user_metadata: {
|
||||
filename: 'models/checkpoints/test-model.safetensors'
|
||||
},
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
async function createMockNode(overrides?: {
|
||||
widgetName?: string
|
||||
widgetValue?: string
|
||||
hasWidgets?: boolean
|
||||
}): Promise<LGraphNode> {
|
||||
const {
|
||||
widgetName = 'ckpt_name',
|
||||
widgetValue = '',
|
||||
hasWidgets = true
|
||||
} = overrides || {}
|
||||
|
||||
const { LGraphNode: ActualLGraphNode } = await vi.importActual<
|
||||
typeof LitegraphModule
|
||||
>('@/lib/litegraph/src/litegraph')
|
||||
|
||||
if (!hasWidgets) {
|
||||
return Object.create(ActualLGraphNode.prototype)
|
||||
}
|
||||
|
||||
type Widget = NonNullable<LGraphNode['widgets']>[number]
|
||||
const widget: Pick<Widget, 'name' | 'value' | 'type' | 'options' | 'y'> = {
|
||||
name: widgetName,
|
||||
value: widgetValue,
|
||||
type: 'string',
|
||||
options: {},
|
||||
y: 0
|
||||
}
|
||||
|
||||
return Object.create(ActualLGraphNode.prototype, {
|
||||
widgets: { value: [widget], writable: true }
|
||||
})
|
||||
}
|
||||
function createMockNodeProvider(
|
||||
overrides: {
|
||||
nodeDef?: { name: string; display_name: string }
|
||||
key?: string
|
||||
} = {}
|
||||
) {
|
||||
return {
|
||||
nodeDef: {
|
||||
name: 'CheckpointLoaderSimple',
|
||||
display_name: 'Load Checkpoint',
|
||||
...overrides.nodeDef
|
||||
},
|
||||
key: overrides.key ?? 'ckpt_name'
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Configures all mocked dependencies with sensible defaults.
|
||||
* Uses semantic parameters for clearer test intent.
|
||||
* For error paths or edge cases, pass null values or specific overrides.
|
||||
*/
|
||||
async function setupMocks(
|
||||
overrides: {
|
||||
nodeProvider?: ReturnType<typeof createMockNodeProvider> | null
|
||||
canvasCenter?: [number, number]
|
||||
activeSubgraph?: Raw<Subgraph>
|
||||
createdNode?: Awaited<ReturnType<typeof createMockNode>> | null
|
||||
} = {}
|
||||
) {
|
||||
const {
|
||||
nodeProvider = createMockNodeProvider(),
|
||||
canvasCenter = [100, 200],
|
||||
activeSubgraph = undefined,
|
||||
createdNode = await createMockNode()
|
||||
} = overrides
|
||||
|
||||
vi.mocked(useModelToNodeStore).mockReturnValue({
|
||||
...useModelToNodeStore(),
|
||||
getNodeProvider: vi.fn().mockReturnValue(nodeProvider)
|
||||
})
|
||||
|
||||
vi.mocked(useLitegraphService).mockReturnValue({
|
||||
...useLitegraphService(),
|
||||
getCanvasCenter: vi.fn().mockReturnValue(canvasCenter)
|
||||
})
|
||||
|
||||
vi.mocked(useWorkflowStore).mockReturnValue({
|
||||
...useWorkflowStore(),
|
||||
activeSubgraph,
|
||||
isSubgraphActive: !!activeSubgraph
|
||||
})
|
||||
vi.mocked(LiteGraph.createNode).mockReturnValue(createdNode)
|
||||
}
|
||||
describe('createModelNodeFromAsset', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
})
|
||||
describe('when creating nodes from valid assets', () => {
|
||||
it('should create the appropriate loader node for the asset category', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks()
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(
|
||||
vi.mocked(useModelToNodeStore)().getNodeProvider
|
||||
).toHaveBeenCalledWith('checkpoints')
|
||||
expect(LiteGraph.createNode).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple',
|
||||
'Load Checkpoint',
|
||||
{ pos: [100, 200] }
|
||||
)
|
||||
}
|
||||
})
|
||||
it('should place node at canvas center by default', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks({
|
||||
canvasCenter: [150, 250]
|
||||
})
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(
|
||||
vi.mocked(useLitegraphService)().getCanvasCenter
|
||||
).toHaveBeenCalled()
|
||||
expect(LiteGraph.createNode).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple',
|
||||
'Load Checkpoint',
|
||||
{ pos: [150, 250] }
|
||||
)
|
||||
})
|
||||
it('should place node at specified position when position is provided', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks()
|
||||
const result = createModelNodeFromAsset(asset, { position: [300, 400] })
|
||||
expect(result.success).toBe(true)
|
||||
expect(
|
||||
vi.mocked(useLitegraphService)().getCanvasCenter
|
||||
).not.toHaveBeenCalled()
|
||||
expect(LiteGraph.createNode).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple',
|
||||
'Load Checkpoint',
|
||||
{ pos: [300, 400] }
|
||||
)
|
||||
})
|
||||
it('should populate the loader widget with the asset file path', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(mockNode.widgets?.[0].value).toBe(
|
||||
'models/checkpoints/test-model.safetensors'
|
||||
)
|
||||
})
|
||||
it('should add node to root graph when no subgraph is active', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(vi.mocked(app).canvas.graph!.add).toHaveBeenCalledWith(mockNode)
|
||||
})
|
||||
it('should fallback to asset.metadata.filename when user_metadata.filename missing', async () => {
|
||||
const asset = createMockAsset({
|
||||
user_metadata: {},
|
||||
metadata: { filename: 'models/checkpoints/from-metadata.safetensors' }
|
||||
})
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(mockNode.widgets?.[0].value).toBe(
|
||||
'models/checkpoints/from-metadata.safetensors'
|
||||
)
|
||||
})
|
||||
it('should fallback to asset.name when both filename sources missing', async () => {
|
||||
const asset = createMockAsset({
|
||||
user_metadata: {},
|
||||
metadata: undefined
|
||||
})
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(mockNode.widgets?.[0].value).toBe('test-model.safetensors')
|
||||
})
|
||||
it('should add node to active subgraph when present', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
const { Subgraph } = await vi.importActual<typeof LitegraphModule>(
|
||||
'@/lib/litegraph/src/litegraph'
|
||||
)
|
||||
const mockSubgraph = markRaw(
|
||||
Object.create(Subgraph.prototype, {
|
||||
add: { value: vi.fn() }
|
||||
})
|
||||
)
|
||||
await setupMocks({
|
||||
createdNode: mockNode,
|
||||
activeSubgraph: mockSubgraph
|
||||
})
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(mockSubgraph.add).toHaveBeenCalledWith(mockNode)
|
||||
expect(vi.mocked(app).canvas.graph!.add).not.toHaveBeenCalled()
|
||||
})
|
||||
it('should succeed when provider has empty key (auto-load nodes)', async () => {
|
||||
const asset = createMockAsset({
|
||||
tags: ['models', 'chatterbox/chatterbox_vc'],
|
||||
user_metadata: { filename: 'chatterbox_vc_model.pt' }
|
||||
})
|
||||
const mockNode = await createMockNode({ hasWidgets: false })
|
||||
const nodeProvider = createMockNodeProvider({
|
||||
nodeDef: {
|
||||
name: 'FL_ChatterboxVC',
|
||||
display_name: 'FL Chatterbox VC'
|
||||
},
|
||||
key: ''
|
||||
})
|
||||
await setupMocks({ createdNode: mockNode, nodeProvider })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(vi.mocked(app).canvas.graph!.add).toHaveBeenCalledWith(mockNode)
|
||||
})
|
||||
})
|
||||
describe('when asset data is incomplete or invalid', () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
})
|
||||
it.for([
|
||||
{
|
||||
case: 'missing user_metadata with no fallback',
|
||||
overrides: { user_metadata: undefined, metadata: undefined, name: '' },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorPattern: /Invalid filename.*expected non-empty string/
|
||||
},
|
||||
{
|
||||
case: 'empty filename with no fallback',
|
||||
overrides: {
|
||||
user_metadata: { filename: '' },
|
||||
metadata: undefined,
|
||||
name: ''
|
||||
},
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorPattern: /Invalid filename.*expected non-empty string/
|
||||
}
|
||||
])(
|
||||
'should fail when asset has $case',
|
||||
({ overrides, expectedCode, errorPattern }) => {
|
||||
const asset = createMockAsset(overrides)
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe(expectedCode)
|
||||
expect(result.error.message).toMatch(errorPattern)
|
||||
expect(result.error.assetId).toBe('asset-123')
|
||||
}
|
||||
}
|
||||
)
|
||||
it.for([
|
||||
{
|
||||
case: 'no tags',
|
||||
overrides: { tags: undefined },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorMessage: 'Asset has no tags defined'
|
||||
},
|
||||
{
|
||||
case: 'only excluded tags',
|
||||
overrides: { tags: ['models', 'missing'] },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorMessage: 'Asset has no valid category tag'
|
||||
},
|
||||
{
|
||||
case: 'only the models tag',
|
||||
overrides: { tags: ['models'] },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorMessage: 'Asset has no valid category tag'
|
||||
}
|
||||
])(
|
||||
'should fail when asset has $case',
|
||||
({ overrides, expectedCode, errorMessage }) => {
|
||||
const asset = createMockAsset(overrides)
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe(expectedCode)
|
||||
expect(result.error.message).toBe(errorMessage)
|
||||
}
|
||||
}
|
||||
)
|
||||
})
|
||||
describe('when system resources are unavailable', () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
})
|
||||
it('should fail when no provider registered for category', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks({ nodeProvider: null })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NO_PROVIDER')
|
||||
expect(result.error.message).toContain('checkpoints')
|
||||
expect(result.error.details?.category).toBe('checkpoints')
|
||||
}
|
||||
})
|
||||
it('should fail when node creation fails', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks()
|
||||
vi.mocked(LiteGraph.createNode).mockReturnValue(null)
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NODE_CREATION_FAILED')
|
||||
expect(result.error.message).toContain('CheckpointLoaderSimple')
|
||||
}
|
||||
})
|
||||
it('should fail when widget is missing from node', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode({ widgetName: 'wrong_widget' })
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('MISSING_WIDGET')
|
||||
expect(result.error.message).toContain('ckpt_name')
|
||||
expect(result.error.message).toContain('CheckpointLoaderSimple')
|
||||
expect(result.error.details?.widgetName).toBe('ckpt_name')
|
||||
}
|
||||
})
|
||||
it('should fail when node has no widgets array', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode({ hasWidgets: false })
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('MISSING_WIDGET')
|
||||
expect(result.error.message).toContain('ckpt_name not found')
|
||||
}
|
||||
})
|
||||
it('should not add node to graph when widget validation fails', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode({ hasWidgets: false })
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
createModelNodeFromAsset(asset)
|
||||
expect(vi.mocked(app).canvas.graph!.add).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
describe('when graph is null', () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
vi.mocked(app).canvas.graph = null
|
||||
})
|
||||
it('should fail when no graph is available', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NO_GRAPH')
|
||||
expect(result.error.message).toBe('No active graph available')
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,198 +0,0 @@
|
||||
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { LGraphNode, Point } from '@/lib/litegraph/src/litegraph'
|
||||
import { assetItemSchema } from '@/platform/assets/schemas/assetSchema'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import {
|
||||
MISSING_TAG,
|
||||
MODELS_TAG
|
||||
} from '@/platform/assets/services/assetService'
|
||||
import { getAssetFilename } from '@/platform/assets/utils/assetMetadataUtils'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
import { app } from '@/scripts/app'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
|
||||
interface ModelNodeCreateOptions {
|
||||
position?: Point
|
||||
}
|
||||
|
||||
type NodeCreationErrorCode =
|
||||
| 'INVALID_ASSET'
|
||||
| 'NO_PROVIDER'
|
||||
| 'NODE_CREATION_FAILED'
|
||||
| 'MISSING_WIDGET'
|
||||
| 'NO_GRAPH'
|
||||
|
||||
interface NodeCreationError {
|
||||
code: NodeCreationErrorCode
|
||||
message: string
|
||||
assetId: string
|
||||
details?: Record<string, unknown>
|
||||
}
|
||||
|
||||
type Result<T, E> = { success: true; value: T } | { success: false; error: E }
|
||||
|
||||
/**
|
||||
* Creates a LiteGraph node from an asset item.
|
||||
*
|
||||
* **Boundary Function**: Bridges Vue reactive domain with LiteGraph canvas domain.
|
||||
*
|
||||
* @param asset - Asset item to create node from (Vue domain)
|
||||
* @param options - Optional position and configuration
|
||||
* @returns Result with LiteGraph node (Canvas domain) or error details
|
||||
*
|
||||
* @remarks
|
||||
* This function performs side effects on the canvas graph. Validation failures
|
||||
* return error results rather than throwing to allow graceful degradation in UI contexts.
|
||||
* Widget validation occurs before graph mutation to prevent orphaned nodes.
|
||||
*/
|
||||
export function createModelNodeFromAsset(
|
||||
asset: AssetItem,
|
||||
options?: ModelNodeCreateOptions
|
||||
): Result<LGraphNode, NodeCreationError> {
|
||||
const validatedAsset = assetItemSchema.safeParse(asset)
|
||||
|
||||
if (!validatedAsset.success) {
|
||||
const errorMessage = validatedAsset.error.errors
|
||||
.map((e) => `${e.path.join('.')}: ${e.message}`)
|
||||
.join(', ')
|
||||
console.error('Invalid asset item:', errorMessage)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset schema validation failed',
|
||||
assetId: asset.id,
|
||||
details: { validationErrors: errorMessage }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const validAsset = validatedAsset.data
|
||||
|
||||
const filename = getAssetFilename(validAsset)
|
||||
if (filename.length === 0) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has invalid user_metadata.filename (expected non-empty string, got ${typeof filename})`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: `Invalid filename (expected non-empty string, got ${typeof filename})`,
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (validAsset.tags.length === 0) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has no tags defined (expected at least one category tag)`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset has no tags defined',
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const category = validAsset.tags.find(
|
||||
(tag) => tag !== MODELS_TAG && tag !== MISSING_TAG
|
||||
)
|
||||
if (!category) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has no valid category tag. Available tags: ${validAsset.tags.join(', ')} (expected tag other than '${MODELS_TAG}' or '${MISSING_TAG}')`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset has no valid category tag',
|
||||
assetId: validAsset.id,
|
||||
details: { availableTags: validAsset.tags }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
const provider = modelToNodeStore.getNodeProvider(category)
|
||||
if (!provider) {
|
||||
console.error(`No node provider registered for category: ${category}`)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'NO_PROVIDER',
|
||||
message: `No node provider registered for category: ${category}`,
|
||||
assetId: validAsset.id,
|
||||
details: { category }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const litegraphService = useLitegraphService()
|
||||
const pos = options?.position ?? litegraphService.getCanvasCenter()
|
||||
|
||||
const node = LiteGraph.createNode(
|
||||
provider.nodeDef.name,
|
||||
provider.nodeDef.display_name,
|
||||
{ pos }
|
||||
)
|
||||
|
||||
if (!node) {
|
||||
console.error(`Failed to create node for type: ${provider.nodeDef.name}`)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'NODE_CREATION_FAILED',
|
||||
message: `Failed to create node for type: ${provider.nodeDef.name}`,
|
||||
assetId: validAsset.id,
|
||||
details: { nodeType: provider.nodeDef.name }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const workflowStore = useWorkflowStore()
|
||||
const targetGraph = workflowStore.isSubgraphActive
|
||||
? workflowStore.activeSubgraph
|
||||
: app.canvas.graph
|
||||
|
||||
if (!targetGraph) {
|
||||
console.error('No active graph available')
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'NO_GRAPH',
|
||||
message: 'No active graph available',
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set widget value if provider specifies a key (some nodes auto-load models without a widget)
|
||||
if (provider.key) {
|
||||
const widget = node.widgets?.find((w) => w.name === provider.key)
|
||||
if (!widget) {
|
||||
console.error(
|
||||
`Widget ${provider.key} not found on node ${provider.nodeDef.name}`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'MISSING_WIDGET',
|
||||
message: `Widget ${provider.key} not found on node ${provider.nodeDef.name}`,
|
||||
assetId: validAsset.id,
|
||||
details: { widgetName: provider.key, nodeType: provider.nodeDef.name }
|
||||
}
|
||||
}
|
||||
}
|
||||
widget.value = filename
|
||||
}
|
||||
|
||||
// Add the node to the graph
|
||||
targetGraph.add(node)
|
||||
|
||||
return { success: true, value: node }
|
||||
}
|
||||
208
src/platform/assets/utils/resolveModelNodeFromAsset.test.ts
Normal file
208
src/platform/assets/utils/resolveModelNodeFromAsset.test.ts
Normal file
@@ -0,0 +1,208 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { resolveModelNodeFromAsset } from '@/platform/assets/utils/resolveModelNodeFromAsset'
|
||||
|
||||
const mockGetNodeProvider = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
useModelToNodeStore: () => ({ getNodeProvider: mockGetNodeProvider })
|
||||
}))
|
||||
|
||||
function createMockAsset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
return {
|
||||
id: 'asset-123',
|
||||
name: 'test-model.safetensors',
|
||||
size: 1024,
|
||||
created_at: '2025-10-01T00:00:00Z',
|
||||
tags: ['models', 'checkpoints'],
|
||||
user_metadata: {
|
||||
filename: 'models/checkpoints/test-model.safetensors'
|
||||
},
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function createMockNodeProvider(
|
||||
overrides: {
|
||||
nodeDef?: { name: string; display_name: string }
|
||||
key?: string
|
||||
} = {}
|
||||
) {
|
||||
return {
|
||||
nodeDef: {
|
||||
name: 'CheckpointLoaderSimple',
|
||||
display_name: 'Load Checkpoint',
|
||||
...overrides.nodeDef
|
||||
},
|
||||
key: overrides.key ?? 'ckpt_name'
|
||||
}
|
||||
}
|
||||
|
||||
function mockProvider(
|
||||
provider: ReturnType<typeof createMockNodeProvider> | null
|
||||
) {
|
||||
mockGetNodeProvider.mockReturnValue(provider)
|
||||
}
|
||||
|
||||
describe('resolveModelNodeFromAsset', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
})
|
||||
|
||||
describe('valid assets', () => {
|
||||
it('resolves the provider for the asset category and the filename', () => {
|
||||
mockProvider(createMockNodeProvider())
|
||||
const result = resolveModelNodeFromAsset(createMockAsset())
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(mockGetNodeProvider).toHaveBeenCalledWith('checkpoints')
|
||||
expect(result.value.provider.nodeDef.name).toBe(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(result.value.filename).toBe(
|
||||
'models/checkpoints/test-model.safetensors'
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('falls back to metadata.filename when user_metadata.filename missing', () => {
|
||||
mockProvider(createMockNodeProvider())
|
||||
const result = resolveModelNodeFromAsset(
|
||||
createMockAsset({
|
||||
user_metadata: {},
|
||||
metadata: { filename: 'models/checkpoints/from-metadata.safetensors' }
|
||||
})
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.value.filename).toBe(
|
||||
'models/checkpoints/from-metadata.safetensors'
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('falls back to asset.name when both filename sources missing', () => {
|
||||
mockProvider(createMockNodeProvider())
|
||||
const result = resolveModelNodeFromAsset(
|
||||
createMockAsset({ user_metadata: {}, metadata: undefined })
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.value.filename).toBe('test-model.safetensors')
|
||||
}
|
||||
})
|
||||
|
||||
it('resolves a provider with an empty key (auto-load nodes)', () => {
|
||||
mockProvider(
|
||||
createMockNodeProvider({
|
||||
nodeDef: {
|
||||
name: 'FL_ChatterboxVC',
|
||||
display_name: 'FL Chatterbox VC'
|
||||
},
|
||||
key: ''
|
||||
})
|
||||
)
|
||||
const result = resolveModelNodeFromAsset(
|
||||
createMockAsset({
|
||||
tags: ['models', 'chatterbox/chatterbox_vc'],
|
||||
user_metadata: { filename: 'chatterbox_vc_model.pt' }
|
||||
})
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.value.provider.key).toBe('')
|
||||
expect(result.value.filename).toBe('chatterbox_vc_model.pt')
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('invalid assets', () => {
|
||||
it('fails when the asset does not match the schema', () => {
|
||||
const invalid = {
|
||||
id: 'asset-123',
|
||||
tags: ['models', 'checkpoints']
|
||||
} as unknown as AssetItem
|
||||
|
||||
const result = resolveModelNodeFromAsset(invalid)
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('INVALID_ASSET')
|
||||
expect(result.error.message).toBe('Asset schema validation failed')
|
||||
expect(result.error.assetId).toBe('asset-123')
|
||||
expect(result.error.details?.validationErrors).toBeTruthy()
|
||||
}
|
||||
})
|
||||
|
||||
it.for([
|
||||
{
|
||||
case: 'missing user_metadata with no fallback',
|
||||
overrides: {
|
||||
user_metadata: undefined,
|
||||
metadata: undefined,
|
||||
name: ''
|
||||
},
|
||||
errorPattern: /Invalid filename.*expected non-empty string/
|
||||
},
|
||||
{
|
||||
case: 'empty filename with no fallback',
|
||||
overrides: {
|
||||
user_metadata: { filename: '' },
|
||||
metadata: undefined,
|
||||
name: ''
|
||||
},
|
||||
errorPattern: /Invalid filename.*expected non-empty string/
|
||||
}
|
||||
])('fails when asset has $case', ({ overrides, errorPattern }) => {
|
||||
const result = resolveModelNodeFromAsset(createMockAsset(overrides))
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('INVALID_ASSET')
|
||||
expect(result.error.message).toMatch(errorPattern)
|
||||
expect(result.error.assetId).toBe('asset-123')
|
||||
}
|
||||
})
|
||||
|
||||
it.for([
|
||||
{
|
||||
case: 'no tags',
|
||||
overrides: { tags: undefined },
|
||||
message: 'Asset has no tags defined'
|
||||
},
|
||||
{
|
||||
case: 'only excluded tags',
|
||||
overrides: { tags: ['models', 'missing'] },
|
||||
message: 'Asset has no valid category tag'
|
||||
},
|
||||
{
|
||||
case: 'only the models tag',
|
||||
overrides: { tags: ['models'] },
|
||||
message: 'Asset has no valid category tag'
|
||||
}
|
||||
])('fails when asset has $case', ({ overrides, message }) => {
|
||||
const result = resolveModelNodeFromAsset(createMockAsset(overrides))
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('INVALID_ASSET')
|
||||
expect(result.error.message).toBe(message)
|
||||
}
|
||||
})
|
||||
|
||||
it('fails when no provider is registered for the category', () => {
|
||||
mockProvider(null)
|
||||
const result = resolveModelNodeFromAsset(createMockAsset())
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NO_PROVIDER')
|
||||
expect(result.error.message).toContain('checkpoints')
|
||||
expect(result.error.details?.category).toBe('checkpoints')
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
117
src/platform/assets/utils/resolveModelNodeFromAsset.ts
Normal file
117
src/platform/assets/utils/resolveModelNodeFromAsset.ts
Normal file
@@ -0,0 +1,117 @@
|
||||
import { assetItemSchema } from '@/platform/assets/schemas/assetSchema'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import {
|
||||
MISSING_TAG,
|
||||
MODELS_TAG
|
||||
} from '@/platform/assets/services/assetService'
|
||||
import { getAssetFilename } from '@/platform/assets/utils/assetMetadataUtils'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
import type { ModelNodeProvider } from '@/stores/modelToNodeStore'
|
||||
|
||||
type ResolveErrorCode = 'INVALID_ASSET' | 'NO_PROVIDER'
|
||||
|
||||
export interface ResolveModelNodeError {
|
||||
code: ResolveErrorCode
|
||||
message: string
|
||||
assetId: string
|
||||
details?: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface ResolvedModelNode {
|
||||
provider: ModelNodeProvider
|
||||
filename: string
|
||||
}
|
||||
|
||||
type Result<T, E> = { success: true; value: T } | { success: false; error: E }
|
||||
|
||||
/**
|
||||
* Resolves an asset item to the node provider and filename needed to add a
|
||||
* model loader node. Validation failures return error results rather than
|
||||
* throwing, so callers can degrade gracefully in UI contexts.
|
||||
*/
|
||||
export function resolveModelNodeFromAsset(
|
||||
asset: AssetItem
|
||||
): Result<ResolvedModelNode, ResolveModelNodeError> {
|
||||
const validatedAsset = assetItemSchema.safeParse(asset)
|
||||
|
||||
if (!validatedAsset.success) {
|
||||
const errorMessage = validatedAsset.error.errors
|
||||
.map((e) => `${e.path.join('.')}: ${e.message}`)
|
||||
.join(', ')
|
||||
console.error('Invalid asset item:', errorMessage)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset schema validation failed',
|
||||
assetId: asset.id,
|
||||
details: { validationErrors: errorMessage }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const validAsset = validatedAsset.data
|
||||
|
||||
const filename = getAssetFilename(validAsset)
|
||||
if (filename.length === 0) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has invalid user_metadata.filename (expected non-empty string, got ${typeof filename})`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: `Invalid filename (expected non-empty string, got ${typeof filename})`,
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (validAsset.tags.length === 0) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has no tags defined (expected at least one category tag)`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset has no tags defined',
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const category = validAsset.tags.find(
|
||||
(tag) => tag !== MODELS_TAG && tag !== MISSING_TAG
|
||||
)
|
||||
if (!category) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has no valid category tag. Available tags: ${validAsset.tags.join(', ')} (expected tag other than '${MODELS_TAG}' or '${MISSING_TAG}')`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset has no valid category tag',
|
||||
assetId: validAsset.id,
|
||||
details: { availableTags: validAsset.tags }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const provider = useModelToNodeStore().getNodeProvider(category)
|
||||
if (!provider) {
|
||||
console.error(`No node provider registered for category: ${category}`)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'NO_PROVIDER',
|
||||
message: `No node provider registered for category: ${category}`,
|
||||
assetId: validAsset.id,
|
||||
details: { category }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { success: true, value: { provider, filename } }
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { ComfyNodeDef as ComfyNodeDefV2 } from '@/schemas/nodeDef/nodeDefSchemaV2'
|
||||
import LGraphNodePreview from '@/renderer/extensions/vueNodes/components/LGraphNodePreview.vue'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
|
||||
vi.mock('@/stores/widgetStore', () => ({
|
||||
useWidgetStore: () => ({ inputIsWidget: () => true })
|
||||
}))
|
||||
|
||||
// Serializes the nodeData prop so tests can assert on the data contract
|
||||
// LGraphNodePreview hands to NodeWidgets. How that data renders is covered
|
||||
// by NodeWidgets.test.ts and browser_tests/tests/sidebar/modelLibrary.spec.ts.
|
||||
const NodeWidgetsProbe = {
|
||||
props: ['nodeData'],
|
||||
template: '<div data-testid="node-data">{{ JSON.stringify(nodeData) }}</div>'
|
||||
}
|
||||
|
||||
interface ProbedWidget {
|
||||
name: string
|
||||
options?: { values?: string[] }
|
||||
}
|
||||
|
||||
const nodeDef = fromPartial<ComfyNodeDefV2>({
|
||||
name: 'CheckpointLoaderSimple',
|
||||
display_name: 'Load Checkpoint',
|
||||
inputs: {
|
||||
ckpt_name: { type: 'COMBO', options: ['a.safetensors', 'b.safetensors'] }
|
||||
},
|
||||
outputs: []
|
||||
})
|
||||
|
||||
function renderedComboWidget(
|
||||
props: { widgetValues?: Record<string, string> } = {}
|
||||
) {
|
||||
render(LGraphNodePreview, {
|
||||
props: { nodeDef, ...props },
|
||||
global: {
|
||||
stubs: {
|
||||
NodeHeader: true,
|
||||
NodeSlots: true,
|
||||
NodeWidgets: NodeWidgetsProbe
|
||||
}
|
||||
}
|
||||
})
|
||||
const nodeData: { widgets?: ProbedWidget[] } = JSON.parse(
|
||||
screen.getByTestId('node-data').textContent ?? ''
|
||||
)
|
||||
return nodeData.widgets?.find((w) => w.name === 'ckpt_name')
|
||||
}
|
||||
|
||||
describe('LGraphNodePreview', () => {
|
||||
it('leads the combo options with the provided widget value', () => {
|
||||
const widget = renderedComboWidget({
|
||||
widgetValues: { ckpt_name: 'sd_xl_base_1.0.safetensors' }
|
||||
})
|
||||
|
||||
expect(widget?.options?.values).toEqual([
|
||||
'sd_xl_base_1.0.safetensors',
|
||||
'a.safetensors',
|
||||
'b.safetensors'
|
||||
])
|
||||
})
|
||||
|
||||
it('keeps the combo options untouched when no value is provided', () => {
|
||||
const widget = renderedComboWidget()
|
||||
|
||||
expect(widget?.options?.values).toEqual(['a.safetensors', 'b.safetensors'])
|
||||
})
|
||||
})
|
||||
@@ -45,9 +45,14 @@ import type { ComfyNodeDef as ComfyNodeDefV2 } from '@/schemas/nodeDef/nodeDefSc
|
||||
import { useWidgetStore } from '@/stores/widgetStore'
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
|
||||
const { nodeDef, position = 'absolute' } = defineProps<{
|
||||
const {
|
||||
nodeDef,
|
||||
position = 'absolute',
|
||||
widgetValues
|
||||
} = defineProps<{
|
||||
nodeDef: ComfyNodeDefV2
|
||||
position?: 'absolute' | 'relative'
|
||||
widgetValues?: Record<string, string>
|
||||
}>()
|
||||
|
||||
const widgetStore = useWidgetStore()
|
||||
@@ -56,27 +61,32 @@ const widgetStore = useWidgetStore()
|
||||
const nodeData = computed<VueNodeData>(() => {
|
||||
const widgets = Object.entries(nodeDef.inputs || {})
|
||||
.filter(([_, input]) => widgetStore.inputIsWidget(input))
|
||||
.map(([name, input]) => ({
|
||||
nodeId: '-1',
|
||||
name,
|
||||
type: input.widgetType || input.type,
|
||||
value:
|
||||
input.default !== undefined
|
||||
? input.default
|
||||
: input.type === 'COMBO' &&
|
||||
Array.isArray(input.options) &&
|
||||
input.options.length > 0
|
||||
? input.options[0]
|
||||
: '',
|
||||
options: {
|
||||
hidden: input.hidden,
|
||||
advanced: input.advanced,
|
||||
values:
|
||||
input.type === 'COMBO' && Array.isArray(input.options)
|
||||
? input.options
|
||||
: undefined
|
||||
} satisfies IWidgetOptions
|
||||
}))
|
||||
.map(([name, input]) => {
|
||||
const comboValues =
|
||||
input.type === 'COMBO' && Array.isArray(input.options)
|
||||
? input.options
|
||||
: undefined
|
||||
// Preview nodes have no widget-value store entry, so combo widgets
|
||||
// render their first option; lead with the requested value to show it.
|
||||
const leadValue = widgetValues?.[name]
|
||||
return {
|
||||
nodeId: '-1',
|
||||
name,
|
||||
type: input.widgetType || input.type,
|
||||
value:
|
||||
input.default !== undefined
|
||||
? input.default
|
||||
: (comboValues?.[0] ?? ''),
|
||||
options: {
|
||||
hidden: input.hidden,
|
||||
advanced: input.advanced,
|
||||
values:
|
||||
leadValue && comboValues
|
||||
? [leadValue, ...comboValues.filter((o) => o !== leadValue)]
|
||||
: comboValues
|
||||
} satisfies IWidgetOptions
|
||||
}
|
||||
})
|
||||
|
||||
const inputs: INodeInputSlot[] = Object.entries(nodeDef.inputs || {})
|
||||
.filter(([_, input]) => !widgetStore.inputIsWidget(input))
|
||||
|
||||
Reference in New Issue
Block a user