mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-06-11 00:38:37 +00:00
Compare commits
2 Commits
pysssss/mo
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
25205c0f55 | ||
|
|
598cf33ab7 |
@@ -1,61 +0,0 @@
|
||||
import { expect } from '@playwright/test'
|
||||
|
||||
import type { Asset } from '@comfyorg/ingest-types'
|
||||
import { createCloudAssetsFixture } from '@e2e/fixtures/assetApiFixture'
|
||||
import { STABLE_CHECKPOINT } from '@e2e/fixtures/data/assetFixtures'
|
||||
|
||||
const CLOUD_ASSETS: Asset[] = [STABLE_CHECKPOINT]
|
||||
|
||||
const test = createCloudAssetsFixture(CLOUD_ASSETS)
|
||||
|
||||
test.describe('Browse Model Assets - Use button', { tag: '@cloud' }, () => {
|
||||
test.beforeEach(async ({ comfyPage }) => {
|
||||
await comfyPage.settings.setSetting('Comfy.Assets.UseAssetAPI', true)
|
||||
await comfyPage.nodeOps.clearGraph()
|
||||
})
|
||||
|
||||
test.afterEach(async ({ comfyPage }) => {
|
||||
await comfyPage.nodeOps.clearGraph()
|
||||
})
|
||||
|
||||
test('Use button ghost-places a loader populated with the model', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
await comfyPage.command.executeCommand('Comfy.BrowseModelAssets')
|
||||
|
||||
const modal = comfyPage.page.locator(
|
||||
'[data-component-id="AssetBrowserModal"]'
|
||||
)
|
||||
await expect(modal).toBeVisible()
|
||||
|
||||
const card = comfyPage.page.locator(
|
||||
`[data-component-id="AssetCard"][data-asset-id="${STABLE_CHECKPOINT.id}"]`
|
||||
)
|
||||
await expect(card).toBeVisible()
|
||||
await card.getByRole('button', { name: 'Use' }).click()
|
||||
|
||||
// Dialog closes and the ghost is armed; the node is not placed until the
|
||||
// user clicks the canvas.
|
||||
await expect(modal).toBeHidden()
|
||||
await expect
|
||||
.poll(() => comfyPage.nodeOps.getGraphNodesCount(), { timeout: 1000 })
|
||||
.toBe(0)
|
||||
|
||||
const canvasBox = (await comfyPage.canvas.boundingBox())!
|
||||
await comfyPage.canvas.click({
|
||||
position: { x: canvasBox.width / 2, y: canvasBox.height / 2 }
|
||||
})
|
||||
|
||||
await expect.poll(() => comfyPage.nodeOps.getGraphNodesCount()).toBe(1)
|
||||
await expect
|
||||
.poll(() => comfyPage.nodeOps.getSelectedGraphNodesCount())
|
||||
.toBe(1)
|
||||
|
||||
const [loader] = await comfyPage.nodeOps.getNodeRefsByType(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(loader).toBeDefined()
|
||||
const widget = await loader.getWidgetByName('ckpt_name')
|
||||
expect(await widget.getValue()).toBe(STABLE_CHECKPOINT.name)
|
||||
})
|
||||
})
|
||||
@@ -233,64 +233,4 @@ test.describe('Model library sidebar - empty state', () => {
|
||||
await expect(tab.folderNodes).toHaveCount(0)
|
||||
await expect(tab.leafNodes).toHaveCount(0)
|
||||
})
|
||||
|
||||
test.describe('Model library sidebar - add node', () => {
|
||||
test.beforeEach(async ({ comfyPage }) => {
|
||||
await comfyPage.modelLibrary.mockFoldersWithFiles(MOCK_FOLDERS)
|
||||
await comfyPage.setup()
|
||||
await comfyPage.nodeOps.clearGraph()
|
||||
})
|
||||
|
||||
test.afterEach(async ({ comfyPage }) => {
|
||||
await comfyPage.modelLibrary.clearMocks()
|
||||
})
|
||||
|
||||
test('Clicking a model defers creation until placed on the canvas', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const tab = comfyPage.menu.modelLibraryTab
|
||||
await tab.open()
|
||||
await tab.getFolderByLabel('checkpoints').click()
|
||||
await expect(tab.getLeafByLabel('sd_xl_base_1.0')).toBeVisible()
|
||||
|
||||
await tab.getLeafByLabel('sd_xl_base_1.0').click()
|
||||
|
||||
await expect
|
||||
.poll(() => comfyPage.nodeOps.getGraphNodesCount(), { timeout: 1000 })
|
||||
.toBe(0)
|
||||
|
||||
const canvasBox = (await comfyPage.canvas.boundingBox())!
|
||||
await comfyPage.canvas.click({
|
||||
position: { x: canvasBox.width / 2, y: canvasBox.height / 2 }
|
||||
})
|
||||
|
||||
await expect.poll(() => comfyPage.nodeOps.getGraphNodesCount()).toBe(1)
|
||||
await expect
|
||||
.poll(() => comfyPage.nodeOps.getSelectedGraphNodesCount())
|
||||
.toBe(1)
|
||||
|
||||
const [loader] = await comfyPage.nodeOps.getNodeRefsByType(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(loader).toBeDefined()
|
||||
const widget = await loader.getWidgetByName('ckpt_name')
|
||||
expect(await widget.getValue()).toBe('sd_xl_base_1.0.safetensors')
|
||||
})
|
||||
|
||||
test('Ghost preview shows the model in the loader widget before placing', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const tab = comfyPage.menu.modelLibraryTab
|
||||
await tab.open()
|
||||
await tab.getFolderByLabel('checkpoints').click()
|
||||
await expect(tab.getLeafByLabel('sd_xl_base_1.0')).toBeVisible()
|
||||
|
||||
await tab.getLeafByLabel('sd_xl_base_1.0').click()
|
||||
|
||||
const ghost = comfyPage.page.locator(
|
||||
'[data-node-id="preview-CheckpointLoaderSimple"]'
|
||||
)
|
||||
await expect(ghost).toContainText('sd_xl_base_1.0.safetensors')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
101
browser_tests/tests/workspaceSwitcher.spec.ts
Normal file
101
browser_tests/tests/workspaceSwitcher.spec.ts
Normal file
@@ -0,0 +1,101 @@
|
||||
import { expect } from '@playwright/test'
|
||||
|
||||
import type { RemoteConfig } from '@/platform/remoteConfig/types'
|
||||
import type { WorkspaceWithRole } from '@/platform/workspace/api/workspaceApi'
|
||||
import type { WorkspaceTokenResponse } from '@/platform/workspace/stores/workspaceAuthStore'
|
||||
import { comfyPageFixture } from '@e2e/fixtures/ComfyPage'
|
||||
|
||||
const PERSONAL_WORKSPACE_NAME = 'Personal Workspace'
|
||||
const LONG_WORKSPACE_NAME =
|
||||
'Quantum Renaissance Collective for Hyperdimensional Latent Diffusion Research and Experimental Workflow Engineering'
|
||||
|
||||
// text-sm rows render a single 20px line; a wrapped name is 40px+.
|
||||
const SINGLE_LINE_MAX_HEIGHT_PX = 28
|
||||
|
||||
const mockRemoteConfig: RemoteConfig = { team_workspaces_enabled: true }
|
||||
|
||||
const mockListWorkspacesResponse: { workspaces: WorkspaceWithRole[] } = {
|
||||
workspaces: [
|
||||
{
|
||||
id: 'ws-personal',
|
||||
name: PERSONAL_WORKSPACE_NAME,
|
||||
type: 'personal',
|
||||
created_at: '2026-01-01T00:00:00Z',
|
||||
joined_at: '2026-01-01T00:00:00Z',
|
||||
role: 'owner'
|
||||
},
|
||||
{
|
||||
id: 'ws-team-long',
|
||||
name: LONG_WORKSPACE_NAME,
|
||||
type: 'team',
|
||||
created_at: '2026-01-02T00:00:00Z',
|
||||
joined_at: '2026-01-02T00:00:00Z',
|
||||
role: 'member'
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
const mockTokenResponse: WorkspaceTokenResponse = {
|
||||
token: 'mock-workspace-token',
|
||||
expires_at: new Date(Date.now() + 60 * 60 * 1000).toISOString(),
|
||||
workspace: {
|
||||
id: 'ws-personal',
|
||||
name: PERSONAL_WORKSPACE_NAME,
|
||||
type: 'personal'
|
||||
},
|
||||
role: 'owner',
|
||||
permissions: []
|
||||
}
|
||||
|
||||
const test = comfyPageFixture.extend({
|
||||
page: async ({ page }, use) => {
|
||||
await page.route('**/api/features', (route) =>
|
||||
route.fulfill({
|
||||
status: 200,
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify(mockRemoteConfig)
|
||||
})
|
||||
)
|
||||
|
||||
await page.route('**/api/workspaces', async (route) => {
|
||||
if (route.request().method() !== 'GET') {
|
||||
await route.fallback()
|
||||
return
|
||||
}
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify(mockListWorkspacesResponse)
|
||||
})
|
||||
})
|
||||
|
||||
await page.route('**/api/auth/token', (route) =>
|
||||
route.fulfill({
|
||||
status: 200,
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify(mockTokenResponse)
|
||||
})
|
||||
)
|
||||
|
||||
await use(page)
|
||||
}
|
||||
})
|
||||
|
||||
test.describe('Workspace switcher', { tag: '@cloud' }, () => {
|
||||
test('renders a long team workspace name on a single line', async ({
|
||||
comfyPage
|
||||
}) => {
|
||||
const page = comfyPage.page
|
||||
|
||||
await comfyPage.toast.closeToasts()
|
||||
await page.getByRole('button', { name: 'Current user' }).click()
|
||||
await page.getByText(PERSONAL_WORKSPACE_NAME).click()
|
||||
|
||||
const longName = page.getByText(LONG_WORKSPACE_NAME)
|
||||
await expect(longName).toBeVisible()
|
||||
|
||||
const box = await longName.boundingBox()
|
||||
expect(box).not.toBeNull()
|
||||
expect(box!.height).toBeLessThan(SINGLE_LINE_MAX_HEIGHT_PX)
|
||||
})
|
||||
})
|
||||
@@ -111,6 +111,7 @@ describe('formatUtil', () => {
|
||||
expect(getMediaTypeFromFilename('scene.fbx')).toBe('3D')
|
||||
expect(getMediaTypeFromFilename('asset.gltf')).toBe('3D')
|
||||
expect(getMediaTypeFromFilename('binary.glb')).toBe('3D')
|
||||
expect(getMediaTypeFromFilename('print.stl')).toBe('3D')
|
||||
expect(getMediaTypeFromFilename('apple.usdz')).toBe('3D')
|
||||
expect(getMediaTypeFromFilename('scan.ply')).toBe('3D')
|
||||
})
|
||||
|
||||
@@ -591,7 +591,15 @@ const IMAGE_EXTENSIONS = [
|
||||
] as const
|
||||
const VIDEO_EXTENSIONS = ['mp4', 'm4v', 'webm', 'mov', 'avi', 'mkv'] as const
|
||||
const AUDIO_EXTENSIONS = ['mp3', 'wav', 'ogg', 'flac'] as const
|
||||
const THREE_D_EXTENSIONS = ['obj', 'fbx', 'gltf', 'glb', 'usdz', 'ply'] as const
|
||||
const THREE_D_EXTENSIONS = [
|
||||
'obj',
|
||||
'fbx',
|
||||
'gltf',
|
||||
'glb',
|
||||
'stl',
|
||||
'usdz',
|
||||
'ply'
|
||||
] as const
|
||||
const TEXT_EXTENSIONS = [
|
||||
'txt',
|
||||
'md',
|
||||
|
||||
@@ -355,7 +355,7 @@ describe('TreeExplorerV2Node', () => {
|
||||
const nodeDiv = getTreeNode(container)
|
||||
await fireEvent.dragStart(nodeDiv)
|
||||
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockData, { mode: 'native' })
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockData, 'native')
|
||||
})
|
||||
|
||||
it('does not call startDrag for folder items on dragstart', async () => {
|
||||
|
||||
@@ -93,7 +93,6 @@
|
||||
|
||||
<NodeTooltip v-if="tooltipEnabled" />
|
||||
<NodeSearchboxPopover ref="nodeSearchboxPopoverRef" />
|
||||
<NodeDragPreview />
|
||||
<VueNodeSwitchPopup />
|
||||
|
||||
<!-- Initialize components after comfyApp is ready. useAbsolutePosition requires
|
||||
@@ -137,7 +136,6 @@ import GraphCanvasMenu from '@/components/graph/GraphCanvasMenu.vue'
|
||||
import LinkOverlayCanvas from '@/components/graph/LinkOverlayCanvas.vue'
|
||||
import NodeTooltip from '@/components/graph/NodeTooltip.vue'
|
||||
import NodeContextMenu from '@/components/graph/NodeContextMenu.vue'
|
||||
import NodeDragPreview from '@/components/graph/NodeDragPreview.vue'
|
||||
import SelectionToolbox from '@/components/graph/SelectionToolbox.vue'
|
||||
import TitleEditor from '@/components/graph/TitleEditor.vue'
|
||||
import NodePropertiesPanel from '@/components/rightSidePanel/RightSidePanel.vue'
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
<template>
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="showGhost && rafPosition"
|
||||
class="pointer-events-none fixed top-0 left-0 z-10000 will-change-transform"
|
||||
:style="{
|
||||
transform: `translate(${rafPosition.x + 12}px, ${rafPosition.y + 12}px)`
|
||||
}"
|
||||
>
|
||||
<div class="origin-top-left scale-50 opacity-80">
|
||||
<LGraphNodePreview
|
||||
:node-def="draggedNode!"
|
||||
:widget-values="pendingWidgetValues"
|
||||
position="relative"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Teleport>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { useMouse, useRafFn } from '@vueuse/core'
|
||||
import { computed, shallowRef, watch } from 'vue'
|
||||
|
||||
import { useNodeDragToCanvas } from '@/composables/node/useNodeDragToCanvas'
|
||||
import LGraphNodePreview from '@/renderer/extensions/vueNodes/components/LGraphNodePreview.vue'
|
||||
|
||||
const { isDragging, draggedNode, pendingWidgetValues } = useNodeDragToCanvas()
|
||||
|
||||
const { x, y, sourceType } = useMouse({ type: 'client' })
|
||||
|
||||
const showGhost = computed(() => Boolean(isDragging.value && draggedNode.value))
|
||||
const rafPosition = shallowRef<{ x: number; y: number }>()
|
||||
|
||||
const { pause, resume } = useRafFn(
|
||||
() => {
|
||||
if (sourceType.value === null) return
|
||||
const pos = rafPosition.value
|
||||
if (pos && pos.x === x.value && pos.y === y.value) return
|
||||
rafPosition.value = { x: x.value, y: y.value }
|
||||
},
|
||||
{ immediate: false }
|
||||
)
|
||||
|
||||
watch(
|
||||
showGhost,
|
||||
(show) => {
|
||||
if (show) {
|
||||
resume()
|
||||
} else {
|
||||
pause()
|
||||
rafPosition.value = undefined
|
||||
}
|
||||
},
|
||||
{ immediate: true }
|
||||
)
|
||||
</script>
|
||||
@@ -23,6 +23,8 @@
|
||||
:can-use-gizmo="canUseGizmo"
|
||||
:can-use-lighting="canUseLighting"
|
||||
:can-export="canExport"
|
||||
:can-use-hdri="canUseHdri"
|
||||
:can-use-background-image="canUseBackgroundImage"
|
||||
:material-modes="materialModes"
|
||||
:has-skeleton="hasSkeleton"
|
||||
@update-background-image="handleBackgroundImageUpdate"
|
||||
@@ -86,7 +88,7 @@
|
||||
/>
|
||||
|
||||
<RecordingControls
|
||||
v-if="!isPreview"
|
||||
v-if="canUseRecording && !isPreview"
|
||||
v-model:is-recording="isRecording"
|
||||
v-model:has-recording="hasRecording"
|
||||
v-model:recording-duration="recordingDuration"
|
||||
@@ -117,9 +119,18 @@ import { resolveNode } from '@/utils/litegraphUtil'
|
||||
import type { ComponentWidget } from '@/scripts/domWidget'
|
||||
import type { SimplifiedWidget } from '@/types/simplifiedWidget'
|
||||
|
||||
const props = defineProps<{
|
||||
const {
|
||||
widget,
|
||||
nodeId,
|
||||
canUseRecording = true,
|
||||
canUseHdri = true,
|
||||
canUseBackgroundImage = true
|
||||
} = defineProps<{
|
||||
widget: ComponentWidget<string[]> | SimplifiedWidget
|
||||
nodeId?: NodeId
|
||||
canUseRecording?: boolean
|
||||
canUseHdri?: boolean
|
||||
canUseBackgroundImage?: boolean
|
||||
}>()
|
||||
|
||||
function isComponentWidget(
|
||||
@@ -130,11 +141,11 @@ function isComponentWidget(
|
||||
|
||||
const node = ref<LGraphNode | null>(null)
|
||||
|
||||
if (isComponentWidget(props.widget)) {
|
||||
node.value = props.widget.node
|
||||
} else if (props.nodeId) {
|
||||
if (isComponentWidget(widget)) {
|
||||
node.value = widget.node
|
||||
} else if (nodeId) {
|
||||
onMounted(() => {
|
||||
node.value = resolveNode(props.nodeId!) ?? null
|
||||
node.value = resolveNode(nodeId) ?? null
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
47
src/components/load3d/Load3DAdvanced.test.ts
Normal file
47
src/components/load3d/Load3DAdvanced.test.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
import { render } from '@testing-library/vue'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { defineComponent, h, ref } from 'vue'
|
||||
|
||||
const lastProps = ref<Record<string, unknown> | null>(null)
|
||||
|
||||
vi.mock('@/components/load3d/Load3D.vue', () => ({
|
||||
default: defineComponent({
|
||||
name: 'Load3D',
|
||||
props: {
|
||||
widget: { type: null, required: false, default: undefined },
|
||||
nodeId: { type: null, required: false, default: undefined },
|
||||
canUseRecording: { type: Boolean, default: true },
|
||||
canUseHdri: { type: Boolean, default: true },
|
||||
canUseBackgroundImage: { type: Boolean, default: true }
|
||||
},
|
||||
setup(props: Record<string, unknown>) {
|
||||
lastProps.value = { ...props }
|
||||
return () => h('div', { 'data-testid': 'load3d-stub' })
|
||||
}
|
||||
})
|
||||
}))
|
||||
|
||||
import Load3DAdvanced from '@/components/load3d/Load3DAdvanced.vue'
|
||||
|
||||
describe('Load3DAdvanced', () => {
|
||||
it('renders the inner Load3D with all expressive features disabled', () => {
|
||||
const MOCK_NODE = { id: 'node', type: 'Load3DAdvanced' }
|
||||
render(Load3DAdvanced, {
|
||||
props: {
|
||||
widget: { node: MOCK_NODE } as never
|
||||
}
|
||||
})
|
||||
expect(lastProps.value).toMatchObject({
|
||||
canUseRecording: false,
|
||||
canUseHdri: false,
|
||||
canUseBackgroundImage: false
|
||||
})
|
||||
})
|
||||
|
||||
it('forwards widget and nodeId to the inner Load3D', () => {
|
||||
const widget = { node: { id: 'a', type: 'Load3DAdvanced' } }
|
||||
render(Load3DAdvanced, { props: { widget: widget as never, nodeId: 'a' } })
|
||||
expect(lastProps.value?.widget).toEqual(widget)
|
||||
expect(lastProps.value?.nodeId).toBe('a')
|
||||
})
|
||||
})
|
||||
21
src/components/load3d/Load3DAdvanced.vue
Normal file
21
src/components/load3d/Load3DAdvanced.vue
Normal file
@@ -0,0 +1,21 @@
|
||||
<template>
|
||||
<Load3D
|
||||
:widget="widget"
|
||||
:node-id="nodeId"
|
||||
:can-use-recording="false"
|
||||
:can-use-hdri="false"
|
||||
:can-use-background-image="false"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import Load3D from '@/components/load3d/Load3D.vue'
|
||||
import type { NodeId } from '@/platform/workflow/validation/schemas/workflowSchema'
|
||||
import type { ComponentWidget } from '@/scripts/domWidget'
|
||||
import type { SimplifiedWidget } from '@/types/simplifiedWidget'
|
||||
|
||||
defineProps<{
|
||||
widget: ComponentWidget<string[]> | SimplifiedWidget
|
||||
nodeId?: NodeId
|
||||
}>()
|
||||
</script>
|
||||
@@ -52,6 +52,7 @@
|
||||
v-model:background-image="sceneConfig!.backgroundImage"
|
||||
v-model:background-render-mode="sceneConfig!.backgroundRenderMode"
|
||||
v-model:fov="cameraConfig!.fov"
|
||||
:show-background-image="canUseBackgroundImage"
|
||||
:hdri-active="
|
||||
!!lightConfig?.hdri?.hdriPath && !!lightConfig?.hdri?.enabled
|
||||
"
|
||||
@@ -81,6 +82,7 @@
|
||||
/>
|
||||
|
||||
<HDRIControls
|
||||
v-if="canUseHdri"
|
||||
v-model:hdri-config="lightConfig!.hdri"
|
||||
:has-background-image="!!sceneConfig?.backgroundImage"
|
||||
@update-hdri-file="handleHDRIFileUpdate"
|
||||
@@ -129,12 +131,16 @@ const {
|
||||
canUseGizmo = true,
|
||||
canUseLighting = true,
|
||||
canExport = true,
|
||||
canUseHdri = true,
|
||||
canUseBackgroundImage = true,
|
||||
materialModes = ['original', 'normal', 'wireframe'],
|
||||
hasSkeleton = false
|
||||
} = defineProps<{
|
||||
canUseGizmo?: boolean
|
||||
canUseLighting?: boolean
|
||||
canExport?: boolean
|
||||
canUseHdri?: boolean
|
||||
canUseBackgroundImage?: boolean
|
||||
materialModes?: readonly MaterialMode[]
|
||||
hasSkeleton?: boolean
|
||||
}>()
|
||||
|
||||
@@ -37,7 +37,7 @@
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div v-if="!hasBackgroundImage">
|
||||
<div v-if="showBackgroundImage && !hasBackgroundImage">
|
||||
<Button
|
||||
v-tooltip.right="{
|
||||
value: $t('load3d.uploadBackgroundImage'),
|
||||
@@ -61,7 +61,7 @@
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div v-if="hasBackgroundImage">
|
||||
<div v-if="showBackgroundImage && hasBackgroundImage">
|
||||
<Button
|
||||
v-tooltip.right="{
|
||||
value: $t('load3d.panoramaMode'),
|
||||
@@ -83,12 +83,16 @@
|
||||
</div>
|
||||
|
||||
<PopupSlider
|
||||
v-if="hasBackgroundImage && backgroundRenderMode === 'panorama'"
|
||||
v-if="
|
||||
showBackgroundImage &&
|
||||
hasBackgroundImage &&
|
||||
backgroundRenderMode === 'panorama'
|
||||
"
|
||||
v-model="fov"
|
||||
:tooltip-text="$t('load3d.fov')"
|
||||
/>
|
||||
|
||||
<div v-if="hasBackgroundImage">
|
||||
<div v-if="showBackgroundImage && hasBackgroundImage">
|
||||
<Button
|
||||
v-tooltip.right="{
|
||||
value: $t('load3d.removeBackgroundImage'),
|
||||
@@ -114,8 +118,9 @@ import Button from '@/components/ui/button/Button.vue'
|
||||
import type { BackgroundRenderModeType } from '@/extensions/core/load3d/interfaces'
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
|
||||
const { hdriActive = false } = defineProps<{
|
||||
const { hdriActive = false, showBackgroundImage = true } = defineProps<{
|
||||
hdriActive?: boolean
|
||||
showBackgroundImage?: boolean
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
|
||||
@@ -14,7 +14,7 @@ const {
|
||||
captureRoot,
|
||||
getRoot,
|
||||
resetRoot,
|
||||
mockStartDrag,
|
||||
mockAddNodeOnGraph,
|
||||
mockGetNodeProvider,
|
||||
mockToggleNodeOnEvent,
|
||||
mockRefreshModelFolder,
|
||||
@@ -29,7 +29,7 @@ const {
|
||||
resetRoot: () => {
|
||||
capturedRoot = null
|
||||
},
|
||||
mockStartDrag: vi.fn(),
|
||||
mockAddNodeOnGraph: vi.fn(),
|
||||
mockGetNodeProvider: vi.fn(),
|
||||
mockToggleNodeOnEvent: vi.fn(),
|
||||
mockRefreshModelFolder: vi.fn().mockResolvedValue(undefined),
|
||||
@@ -37,8 +37,8 @@ const {
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/composables/node/useNodeDragToCanvas', () => ({
|
||||
useNodeDragToCanvas: () => ({ startDrag: mockStartDrag })
|
||||
vi.mock('@/services/litegraphService', () => ({
|
||||
useLitegraphService: () => ({ addNodeOnGraph: mockAddNodeOnGraph })
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
@@ -173,13 +173,16 @@ describe('ModelLibrarySidebarTab', () => {
|
||||
expect(screen.getByTestId('search-input')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('starts a ghost drag carrying the widget value to fill on placement', async () => {
|
||||
it('handles model click and adds node to graph', async () => {
|
||||
const mockNodeDef = { name: 'CheckpointLoaderSimple' }
|
||||
const mockWidget = { name: 'ckpt_name', value: '' }
|
||||
const mockGraphNode = { widgets: [mockWidget] }
|
||||
|
||||
mockGetNodeProvider.mockReturnValue({
|
||||
nodeDef: mockNodeDef,
|
||||
key: 'ckpt_name'
|
||||
})
|
||||
mockAddNodeOnGraph.mockReturnValue(mockGraphNode)
|
||||
|
||||
renderComponent()
|
||||
await nextTick()
|
||||
@@ -195,9 +198,8 @@ describe('ModelLibrarySidebarTab', () => {
|
||||
await modelLeaf?.handleClick?.(mockEvent)
|
||||
|
||||
expect(mockGetNodeProvider).toHaveBeenCalledWith('checkpoints')
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockNodeDef, {
|
||||
widgetValues: { ckpt_name: 'model.safetensors' }
|
||||
})
|
||||
expect(mockAddNodeOnGraph).toHaveBeenCalledWith(mockNodeDef)
|
||||
expect(mockWidget.value).toBe('model.safetensors')
|
||||
})
|
||||
|
||||
it('toggles folder expansion on click', async () => {
|
||||
|
||||
@@ -63,9 +63,10 @@ import SidebarTabTemplate from '@/components/sidebar/tabs/SidebarTabTemplate.vue
|
||||
import ElectronDownloadItems from '@/components/sidebar/tabs/modelLibrary/ElectronDownloadItems.vue'
|
||||
import ModelTreeLeaf from '@/components/sidebar/tabs/modelLibrary/ModelTreeLeaf.vue'
|
||||
import Button from '@/components/ui/button/Button.vue'
|
||||
import { startModelLoaderDrag } from '@/composables/node/startModelNodeDragFromAsset'
|
||||
import { useTreeExpansion } from '@/composables/useTreeExpansion'
|
||||
import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import { withNodeAddSource } from '@/platform/telemetry/nodeAdded/nodeAddSource'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import { useAssetDownloadStore } from '@/stores/assetDownloadStore'
|
||||
import type { ComfyModelDef, ModelFolder } from '@/stores/modelStore'
|
||||
import { ResourceState, useModelStore } from '@/stores/modelStore'
|
||||
@@ -155,7 +156,15 @@ const renderedRoot = computed<TreeExplorerNode<ModelOrFolder>>(() => {
|
||||
if (this.leaf && model) {
|
||||
const provider = modelToNodeStore.getNodeProvider(model.directory)
|
||||
if (provider) {
|
||||
startModelLoaderDrag(provider, model.file_name)
|
||||
const graphNode = withNodeAddSource('sidebar_drag', () =>
|
||||
useLitegraphService().addNodeOnGraph(provider.nodeDef)
|
||||
)
|
||||
const widget = graphNode?.widgets?.find(
|
||||
(widget) => widget.name === provider.key
|
||||
)
|
||||
if (widget) {
|
||||
widget.value = model.file_name
|
||||
}
|
||||
}
|
||||
} else {
|
||||
toggleNodeOnEvent(e, node)
|
||||
|
||||
@@ -31,8 +31,11 @@ vi.mock('@/composables/node/useNodeDragToCanvas', () => ({
|
||||
useNodeDragToCanvas: () => ({
|
||||
isDragging: { value: false },
|
||||
draggedNode: { value: null },
|
||||
cursorPosition: { value: { x: 0, y: 0 } },
|
||||
startDrag: vi.fn(),
|
||||
cancelDrag: vi.fn()
|
||||
cancelDrag: vi.fn(),
|
||||
setupGlobalListeners: vi.fn(),
|
||||
cleanupGlobalListeners: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
|
||||
@@ -115,6 +115,7 @@
|
||||
</div>
|
||||
</template>
|
||||
<template #body>
|
||||
<NodeDragPreview />
|
||||
<div class="flex h-full flex-col">
|
||||
<div
|
||||
v-if="hasNoMatches"
|
||||
@@ -214,6 +215,7 @@ import type {
|
||||
import AllNodesPanel from './nodeLibrary/AllNodesPanel.vue'
|
||||
import BlueprintsPanel from './nodeLibrary/BlueprintsPanel.vue'
|
||||
import EssentialNodesPanel from './nodeLibrary/EssentialNodesPanel.vue'
|
||||
import NodeDragPreview from './nodeLibrary/NodeDragPreview.vue'
|
||||
import SidebarTabTemplate from './SidebarTabTemplate.vue'
|
||||
|
||||
const { flags } = useFeatureFlags()
|
||||
|
||||
69
src/components/sidebar/tabs/nodeLibrary/NodeDragPreview.vue
Normal file
69
src/components/sidebar/tabs/nodeLibrary/NodeDragPreview.vue
Normal file
@@ -0,0 +1,69 @@
|
||||
<template>
|
||||
<Teleport to="body">
|
||||
<div
|
||||
v-if="isDragging && draggedNode && showPreview"
|
||||
class="pointer-events-none fixed z-10000"
|
||||
:style="{
|
||||
left: `${previewPosition.x + 12}px`,
|
||||
top: `${previewPosition.y + 12}px`
|
||||
}"
|
||||
>
|
||||
<div class="origin-top-left scale-50 opacity-80">
|
||||
<LGraphNodePreview :node-def="draggedNode" position="relative" />
|
||||
</div>
|
||||
</div>
|
||||
</Teleport>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, onMounted, onUnmounted, ref } from 'vue'
|
||||
|
||||
import { useNodeDragToCanvas } from '@/composables/node/useNodeDragToCanvas'
|
||||
import LGraphNodePreview from '@/renderer/extensions/vueNodes/components/LGraphNodePreview.vue'
|
||||
|
||||
const {
|
||||
isDragging,
|
||||
draggedNode,
|
||||
cursorPosition,
|
||||
dragMode,
|
||||
setupGlobalListeners,
|
||||
cleanupGlobalListeners
|
||||
} = useNodeDragToCanvas()
|
||||
|
||||
const nativeDragPosition = ref({ x: 0, y: 0 })
|
||||
|
||||
const previewPosition = computed(() => {
|
||||
if (dragMode.value === 'native') {
|
||||
return nativeDragPosition.value
|
||||
}
|
||||
return cursorPosition.value
|
||||
})
|
||||
|
||||
const showPreview = computed(() => {
|
||||
if (dragMode.value === 'native') {
|
||||
return nativeDragPosition.value.x > 0 || nativeDragPosition.value.y > 0
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
function handleDrag(e: DragEvent) {
|
||||
if (e.clientX === 0 && e.clientY === 0) return
|
||||
nativeDragPosition.value = { x: e.clientX, y: e.clientY }
|
||||
}
|
||||
|
||||
function handleDragEnd() {
|
||||
nativeDragPosition.value = { x: 0, y: 0 }
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
setupGlobalListeners()
|
||||
document.addEventListener('drag', handleDrag)
|
||||
document.addEventListener('dragend', handleDragEnd)
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
cleanupGlobalListeners()
|
||||
document.removeEventListener('drag', handleDrag)
|
||||
document.removeEventListener('dragend', handleDragEnd)
|
||||
})
|
||||
</script>
|
||||
@@ -1,73 +0,0 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { startModelNodeDragFromAsset } from '@/composables/node/startModelNodeDragFromAsset'
|
||||
|
||||
const { mockStartDrag, mockGetNodeProvider } = vi.hoisted(() => ({
|
||||
mockStartDrag: vi.fn(),
|
||||
mockGetNodeProvider: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/node/useNodeDragToCanvas', () => ({
|
||||
useNodeDragToCanvas: () => ({ startDrag: mockStartDrag })
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
useModelToNodeStore: () => ({ getNodeProvider: mockGetNodeProvider })
|
||||
}))
|
||||
|
||||
function createAsset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
return {
|
||||
id: 'asset-123',
|
||||
name: 'sd_xl_base_1.0.safetensors',
|
||||
size: 1024,
|
||||
created_at: '2025-10-01T00:00:00Z',
|
||||
tags: ['models', 'checkpoints'],
|
||||
user_metadata: { filename: 'sd_xl_base_1.0.safetensors' },
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
describe('startModelNodeDragFromAsset', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
})
|
||||
|
||||
it('starts a ghost drag for the resolved node carrying the widget value', () => {
|
||||
const nodeDef = { name: 'CheckpointLoaderSimple' }
|
||||
mockGetNodeProvider.mockReturnValue({ nodeDef, key: 'ckpt_name' })
|
||||
|
||||
const error = startModelNodeDragFromAsset(createAsset())
|
||||
|
||||
expect(error).toBeUndefined()
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(nodeDef, {
|
||||
widgetValues: { ckpt_name: 'sd_xl_base_1.0.safetensors' }
|
||||
})
|
||||
})
|
||||
|
||||
it('carries no widget value when the provider has no key', () => {
|
||||
const nodeDef = { name: 'FL_ChatterboxVC' }
|
||||
mockGetNodeProvider.mockReturnValue({ nodeDef, key: '' })
|
||||
|
||||
startModelNodeDragFromAsset(
|
||||
createAsset({
|
||||
tags: ['models', 'chatterbox/chatterbox_vc'],
|
||||
user_metadata: { filename: 'chatterbox_vc_model.pt' }
|
||||
})
|
||||
)
|
||||
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(nodeDef, {
|
||||
widgetValues: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('returns the resolution error and does not start a drag for an invalid asset', () => {
|
||||
mockGetNodeProvider.mockReturnValue(null)
|
||||
|
||||
const error = startModelNodeDragFromAsset(createAsset())
|
||||
|
||||
expect(error?.code).toBe('NO_PROVIDER')
|
||||
expect(mockStartDrag).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
@@ -1,35 +0,0 @@
|
||||
import { useNodeDragToCanvas } from '@/composables/node/useNodeDragToCanvas'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { resolveModelNodeFromAsset } from '@/platform/assets/utils/resolveModelNodeFromAsset'
|
||||
import type { ResolveModelNodeError } from '@/platform/assets/utils/resolveModelNodeFromAsset'
|
||||
import type { ModelNodeProvider } from '@/stores/modelToNodeStore'
|
||||
|
||||
/**
|
||||
* Arms a ghost drag for a model loader node. Providers with no widget key
|
||||
* (auto-load nodes) start the drag without widget values.
|
||||
*/
|
||||
export function startModelLoaderDrag(
|
||||
provider: ModelNodeProvider,
|
||||
filename: string
|
||||
) {
|
||||
const widgetValues = provider.key ? { [provider.key]: filename } : undefined
|
||||
useNodeDragToCanvas().startDrag(provider.nodeDef, { widgetValues })
|
||||
}
|
||||
|
||||
/**
|
||||
* Starts a ghost drag for the model loader node described by an asset. The
|
||||
* node is created where the user next clicks the canvas, with the asset's
|
||||
* filename written into the loader widget.
|
||||
*
|
||||
* @returns the resolution error when the asset cannot be mapped to a node,
|
||||
* otherwise `undefined`.
|
||||
*/
|
||||
export function startModelNodeDragFromAsset(
|
||||
asset: AssetItem
|
||||
): ResolveModelNodeError | undefined {
|
||||
const resolved = resolveModelNodeFromAsset(asset)
|
||||
if (!resolved.success) return resolved.error
|
||||
|
||||
const { provider, filename } = resolved.value
|
||||
startModelLoaderDrag(provider, filename)
|
||||
}
|
||||
@@ -7,8 +7,7 @@ const {
|
||||
mockAddNodeOnGraph,
|
||||
mockConvertEventToCanvasOffset,
|
||||
mockSelectItems,
|
||||
mockCanvas,
|
||||
mockToastAdd
|
||||
mockCanvas
|
||||
} = vi.hoisted(() => {
|
||||
const mockConvertEventToCanvasOffset = vi.fn()
|
||||
const mockSelectItems = vi.fn()
|
||||
@@ -16,7 +15,6 @@ const {
|
||||
mockAddNodeOnGraph: vi.fn(),
|
||||
mockConvertEventToCanvasOffset,
|
||||
mockSelectItems,
|
||||
mockToastAdd: vi.fn(),
|
||||
mockCanvas: {
|
||||
canvas: {
|
||||
getBoundingClientRect: vi.fn()
|
||||
@@ -39,12 +37,6 @@ vi.mock('@/services/litegraphService', () => ({
|
||||
}))
|
||||
}))
|
||||
|
||||
vi.mock('@/platform/updates/common/toastStore', () => ({
|
||||
useToastStore: vi.fn(() => ({ add: mockToastAdd }))
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n', () => ({ t: (key: string) => key }))
|
||||
|
||||
describe('useNodeDragToCanvas', () => {
|
||||
let useNodeDragToCanvas: typeof UseNodeDragToCanvasType
|
||||
|
||||
@@ -62,8 +54,8 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
const { cancelDrag } = useNodeDragToCanvas()
|
||||
cancelDrag()
|
||||
const { cleanupGlobalListeners } = useNodeDragToCanvas()
|
||||
cleanupGlobalListeners()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
@@ -79,6 +71,22 @@ describe('useNodeDragToCanvas', () => {
|
||||
expect(isDragging.value).toBe(true)
|
||||
expect(draggedNode.value).toBe(mockNodeDef)
|
||||
})
|
||||
|
||||
it('should set dragMode to click by default', () => {
|
||||
const { dragMode, startDrag } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(dragMode.value).toBe('click')
|
||||
})
|
||||
|
||||
it('should set dragMode to native when specified', () => {
|
||||
const { dragMode, startDrag } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, 'native')
|
||||
|
||||
expect(dragMode.value).toBe('native')
|
||||
})
|
||||
})
|
||||
|
||||
describe('cancelDrag', () => {
|
||||
@@ -94,78 +102,75 @@ describe('useNodeDragToCanvas', () => {
|
||||
expect(isDragging.value).toBe(false)
|
||||
expect(draggedNode.value).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('drag listener lifecycle', () => {
|
||||
it('should attach document listeners on startDrag', () => {
|
||||
const addEventListenerSpy = vi.spyOn(document, 'addEventListener')
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
it('should reset dragMode to click', () => {
|
||||
const { dragMode, startDrag, cancelDrag } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef)
|
||||
startDrag(mockNodeDef, 'native')
|
||||
expect(dragMode.value).toBe('native')
|
||||
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith(
|
||||
'pointerdown',
|
||||
expect.any(Function),
|
||||
true
|
||||
)
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith(
|
||||
'pointerup',
|
||||
expect.any(Function),
|
||||
true
|
||||
)
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith(
|
||||
'keydown',
|
||||
expect.any(Function)
|
||||
)
|
||||
})
|
||||
|
||||
it('should not attach drag listeners until a drag starts', () => {
|
||||
const addEventListenerSpy = vi.spyOn(document, 'addEventListener')
|
||||
useNodeDragToCanvas()
|
||||
|
||||
expect(addEventListenerSpy).not.toHaveBeenCalledWith(
|
||||
'pointerup',
|
||||
expect.any(Function),
|
||||
true
|
||||
)
|
||||
expect(addEventListenerSpy).not.toHaveBeenCalledWith(
|
||||
'keydown',
|
||||
expect.any(Function)
|
||||
)
|
||||
})
|
||||
|
||||
it('should detach document listeners on cancelDrag', () => {
|
||||
const removeEventListenerSpy = vi.spyOn(document, 'removeEventListener')
|
||||
const { startDrag, cancelDrag } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef)
|
||||
cancelDrag()
|
||||
|
||||
expect(removeEventListenerSpy).toHaveBeenCalledWith(
|
||||
expect(dragMode.value).toBe('click')
|
||||
})
|
||||
})
|
||||
|
||||
describe('setupGlobalListeners', () => {
|
||||
it('should add event listeners to document', () => {
|
||||
const addEventListenerSpy = vi.spyOn(document, 'addEventListener')
|
||||
const { setupGlobalListeners } = useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith(
|
||||
'pointermove',
|
||||
expect.any(Function)
|
||||
)
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith(
|
||||
'pointerdown',
|
||||
expect.any(Function),
|
||||
true
|
||||
)
|
||||
expect(removeEventListenerSpy).toHaveBeenCalledWith(
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith(
|
||||
'pointerup',
|
||||
expect.any(Function),
|
||||
true
|
||||
)
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith(
|
||||
'keydown',
|
||||
expect.any(Function)
|
||||
)
|
||||
})
|
||||
|
||||
it('should only attach listeners once across re-arms', () => {
|
||||
it('should only setup listeners once', () => {
|
||||
const addEventListenerSpy = vi.spyOn(document, 'addEventListener')
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
const { setupGlobalListeners } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef)
|
||||
setupGlobalListeners()
|
||||
const callCount = addEventListenerSpy.mock.calls.length
|
||||
|
||||
startDrag(mockNodeDef)
|
||||
setupGlobalListeners()
|
||||
|
||||
expect(addEventListenerSpy.mock.calls.length).toBe(callCount)
|
||||
})
|
||||
})
|
||||
|
||||
describe('cursorPosition', () => {
|
||||
it('should update on pointermove', () => {
|
||||
const { cursorPosition, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
|
||||
const pointerEvent = new PointerEvent('pointermove', {
|
||||
clientX: 100,
|
||||
clientY: 200
|
||||
})
|
||||
document.dispatchEvent(pointerEvent)
|
||||
|
||||
expect(cursorPosition.value).toEqual({ x: 100, y: 200 })
|
||||
})
|
||||
})
|
||||
|
||||
describe('endDrag behavior', () => {
|
||||
it('should add node when pointer is over canvas', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
@@ -176,7 +181,9 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
const pointerEvent = new PointerEvent('pointerup', {
|
||||
@@ -199,7 +206,10 @@ describe('useNodeDragToCanvas', () => {
|
||||
bottom: 500
|
||||
})
|
||||
|
||||
const { startDrag, isDragging } = useNodeDragToCanvas()
|
||||
const { startDrag, setupGlobalListeners, isDragging } =
|
||||
useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
const pointerEvent = new PointerEvent('pointerup', {
|
||||
@@ -214,7 +224,10 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should cancel drag on Escape key', () => {
|
||||
const { startDrag, isDragging } = useNodeDragToCanvas()
|
||||
const { startDrag, setupGlobalListeners, isDragging } =
|
||||
useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(isDragging.value).toBe(true)
|
||||
@@ -226,7 +239,10 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should not cancel drag on other keys', () => {
|
||||
const { startDrag, isDragging } = useNodeDragToCanvas()
|
||||
const { startDrag, setupGlobalListeners, isDragging } =
|
||||
useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
const keyEvent = new KeyboardEvent('keydown', { key: 'Enter' })
|
||||
@@ -246,7 +262,8 @@ describe('useNodeDragToCanvas', () => {
|
||||
const placedNode = { id: 1 }
|
||||
mockAddNodeOnGraph.mockReturnValue(placedNode)
|
||||
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
document.dispatchEvent(
|
||||
@@ -260,97 +277,6 @@ describe('useNodeDragToCanvas', () => {
|
||||
expect(mockSelectItems).toHaveBeenCalledWith([placedNode])
|
||||
})
|
||||
|
||||
it('should apply the requested widget values to the placed node', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
left: 0,
|
||||
right: 500,
|
||||
top: 0,
|
||||
bottom: 500
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
const widget = { name: 'ckpt_name', value: '' }
|
||||
mockAddNodeOnGraph.mockReturnValue({ id: 1, widgets: [widget] })
|
||||
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, {
|
||||
widgetValues: { ckpt_name: 'model.safetensors' }
|
||||
})
|
||||
|
||||
document.dispatchEvent(
|
||||
new PointerEvent('pointerup', {
|
||||
clientX: 250,
|
||||
clientY: 250,
|
||||
bubbles: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(widget.value).toBe('model.safetensors')
|
||||
})
|
||||
|
||||
it('should still place the node when a requested widget is missing', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
left: 0,
|
||||
right: 500,
|
||||
top: 0,
|
||||
bottom: 500
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
const placedNode = { id: 1, widgets: [] }
|
||||
mockAddNodeOnGraph.mockReturnValue(placedNode)
|
||||
const consoleErrorSpy = vi
|
||||
.spyOn(console, 'error')
|
||||
.mockImplementation(() => {})
|
||||
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, {
|
||||
widgetValues: { ckpt_name: 'model.safetensors' }
|
||||
})
|
||||
|
||||
document.dispatchEvent(
|
||||
new PointerEvent('pointerup', {
|
||||
clientX: 250,
|
||||
clientY: 250,
|
||||
bubbles: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(mockSelectItems).toHaveBeenCalledWith([placedNode])
|
||||
expect(mockToastAdd).not.toHaveBeenCalled()
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('ckpt_name')
|
||||
)
|
||||
})
|
||||
|
||||
it('should show an error toast when the graph fails to add the node', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
left: 0,
|
||||
right: 500,
|
||||
top: 0,
|
||||
bottom: 500
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
mockAddNodeOnGraph.mockReturnValue(null)
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
document.dispatchEvent(
|
||||
new PointerEvent('pointerup', {
|
||||
clientX: 250,
|
||||
clientY: 250,
|
||||
bubbles: true
|
||||
})
|
||||
)
|
||||
|
||||
expect(mockToastAdd).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
severity: 'error',
|
||||
detail: 'assetBrowser.failedToCreateNode'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should not call selectItems when graph returns no node', () => {
|
||||
mockCanvas.canvas.getBoundingClientRect.mockReturnValue({
|
||||
left: 0,
|
||||
@@ -360,9 +286,9 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
mockAddNodeOnGraph.mockReturnValue(null)
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
document.dispatchEvent(
|
||||
@@ -385,8 +311,11 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([150, 150])
|
||||
|
||||
const { startDrag, isDragging } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
const { startDrag, setupGlobalListeners, isDragging } =
|
||||
useNodeDragToCanvas()
|
||||
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
|
||||
const pointerEvent = new PointerEvent('pointerup', {
|
||||
clientX: 250,
|
||||
@@ -412,7 +341,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
startDrag(mockNodeDef, 'native')
|
||||
handleNativeDrop(250, 250)
|
||||
|
||||
expect(mockAddNodeOnGraph).toHaveBeenCalledWith(mockNodeDef, {
|
||||
@@ -430,7 +359,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
|
||||
const { startDrag, handleNativeDrop, isDragging } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
startDrag(mockNodeDef, 'native')
|
||||
handleNativeDrop(600, 250)
|
||||
|
||||
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
|
||||
@@ -448,7 +377,7 @@ describe('useNodeDragToCanvas', () => {
|
||||
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef)
|
||||
startDrag(mockNodeDef, 'click')
|
||||
handleNativeDrop(250, 250)
|
||||
|
||||
expect(mockAddNodeOnGraph).not.toHaveBeenCalled()
|
||||
@@ -463,12 +392,14 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
mockConvertEventToCanvasOffset.mockReturnValue([200, 200])
|
||||
|
||||
const { startDrag, handleNativeDrop, isDragging } = useNodeDragToCanvas()
|
||||
const { startDrag, handleNativeDrop, isDragging, dragMode } =
|
||||
useNodeDragToCanvas()
|
||||
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
startDrag(mockNodeDef, 'native')
|
||||
handleNativeDrop(250, 250)
|
||||
|
||||
expect(isDragging.value).toBe(false)
|
||||
expect(dragMode.value).toBe('click')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -495,29 +426,31 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should stop propagation when in click-drag mode over canvas', () => {
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(dispatchPointerDown(250, 250)).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not stop propagation once the drag is cancelled', () => {
|
||||
const { startDrag, cancelDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef)
|
||||
cancelDrag()
|
||||
it('should not stop propagation when not dragging', () => {
|
||||
const { setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
|
||||
expect(dispatchPointerDown(250, 250)).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not stop propagation in native drag mode', () => {
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
|
||||
expect(dispatchPointerDown(250, 250)).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should not stop propagation when pointer is outside canvas', () => {
|
||||
const { startDrag } = useNodeDragToCanvas()
|
||||
const { startDrag, setupGlobalListeners } = useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef)
|
||||
|
||||
expect(dispatchPointerDown(600, 250)).not.toHaveBeenCalled()
|
||||
@@ -544,8 +477,10 @@ describe('useNodeDragToCanvas', () => {
|
||||
}
|
||||
|
||||
it('should prefer tracked drag position over dragend coordinates', () => {
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
|
||||
fireDrag(250, 250)
|
||||
// dragend supplies a bad position (the Firefox bug); the tracked one
|
||||
@@ -559,8 +494,10 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should ignore drag events with (0, 0)', () => {
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
|
||||
fireDrag(250, 250)
|
||||
fireDrag(0, 0)
|
||||
@@ -573,8 +510,10 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
|
||||
it('should fall back to dragend coordinates when no drag fired', () => {
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
startDrag(mockNodeDef, 'native')
|
||||
|
||||
handleNativeDrop(250, 250)
|
||||
|
||||
@@ -584,14 +523,32 @@ describe('useNodeDragToCanvas', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should ignore dragover events fired before startDrag', () => {
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
|
||||
fireDrag(250, 250)
|
||||
startDrag(mockNodeDef, 'native')
|
||||
handleNativeDrop(300, 300)
|
||||
|
||||
expect(mockConvertEventToCanvasOffset).toHaveBeenCalledWith({
|
||||
clientX: 300,
|
||||
clientY: 300
|
||||
})
|
||||
})
|
||||
|
||||
it('should clear tracked position between drags', () => {
|
||||
const { startDrag, handleNativeDrop } = useNodeDragToCanvas()
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
const { startDrag, setupGlobalListeners, handleNativeDrop } =
|
||||
useNodeDragToCanvas()
|
||||
setupGlobalListeners()
|
||||
|
||||
startDrag(mockNodeDef, 'native')
|
||||
fireDrag(250, 250)
|
||||
handleNativeDrop(1505, 102)
|
||||
|
||||
// Second drag - no drag events, so we should fall back to args.
|
||||
startDrag(mockNodeDef, { mode: 'native' })
|
||||
startDrag(mockNodeDef, 'native')
|
||||
handleNativeDrop(300, 300)
|
||||
|
||||
expect(mockConvertEventToCanvasOffset).toHaveBeenLastCalledWith({
|
||||
|
||||
@@ -1,29 +1,23 @@
|
||||
import { ref, shallowRef } from 'vue'
|
||||
|
||||
import { t } from '@/i18n'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { withNodeAddSource } from '@/platform/telemetry/nodeAdded/nodeAddSource'
|
||||
import { useToastStore } from '@/platform/updates/common/toastStore'
|
||||
import { useCanvasStore } from '@/renderer/core/canvas/canvasStore'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import type { ComfyNodeDefImpl } from '@/stores/nodeDefStore'
|
||||
|
||||
type DragMode = 'click' | 'native'
|
||||
type WidgetValues = Record<string, string>
|
||||
type Position = { x: number; y: number }
|
||||
|
||||
interface StartDragOptions {
|
||||
mode?: DragMode
|
||||
widgetValues?: WidgetValues
|
||||
}
|
||||
|
||||
const isDragging = ref(false)
|
||||
const draggedNode = shallowRef<ComfyNodeDefImpl | null>(null)
|
||||
const cursorPosition = ref({ x: 0, y: 0 })
|
||||
const dragMode = ref<DragMode>('click')
|
||||
const lastNativeDragPosition = shallowRef<Position>()
|
||||
const pendingWidgetValues = shallowRef<WidgetValues>()
|
||||
const lastNativeDragPosition = shallowRef<{ x: number; y: number }>()
|
||||
let listenersSetup = false
|
||||
|
||||
function updatePosition(e: PointerEvent) {
|
||||
cursorPosition.value = { x: e.clientX, y: e.clientY }
|
||||
}
|
||||
|
||||
// Firefox dragend can report stale clientX/Y and `drag` can fire with
|
||||
// (0, 0). dragover on the target reliably reports real client coords.
|
||||
// https://bugzilla.mozilla.org/show_bug.cgi?id=1773886
|
||||
@@ -33,15 +27,11 @@ function trackNativeDragPosition(e: DragEvent) {
|
||||
lastNativeDragPosition.value = { x: e.clientX, y: e.clientY }
|
||||
}
|
||||
|
||||
function applyWidgetValues(node: LGraphNode, values: WidgetValues) {
|
||||
for (const [name, value] of Object.entries(values)) {
|
||||
const widget = node.widgets?.find((w) => w.name === name)
|
||||
if (!widget) {
|
||||
console.error(`Widget ${name} not found on node ${node.type}`)
|
||||
continue
|
||||
}
|
||||
widget.value = value
|
||||
}
|
||||
function cancelDrag() {
|
||||
isDragging.value = false
|
||||
draggedNode.value = null
|
||||
dragMode.value = 'click'
|
||||
lastNativeDragPosition.value = undefined
|
||||
}
|
||||
|
||||
function isOverCanvas(clientX: number, clientY: number): boolean {
|
||||
@@ -72,19 +62,7 @@ function addNodeAtPosition(clientX: number, clientY: number): boolean {
|
||||
const node = withNodeAddSource('sidebar_drag', () =>
|
||||
useLitegraphService().addNodeOnGraph(nodeDef, { pos })
|
||||
)
|
||||
if (!node) {
|
||||
console.error(`Failed to add node to graph: ${nodeDef.name}`)
|
||||
useToastStore().add({
|
||||
severity: 'error',
|
||||
summary: t('g.error'),
|
||||
detail: t('assetBrowser.failedToCreateNode')
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
if (pendingWidgetValues.value)
|
||||
applyWidgetValues(node, pendingWidgetValues.value)
|
||||
canvas.selectItems([node])
|
||||
if (node) canvas.selectItems([node])
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -114,6 +92,7 @@ function setupGlobalListeners() {
|
||||
if (listenersSetup) return
|
||||
listenersSetup = true
|
||||
|
||||
document.addEventListener('pointermove', updatePosition)
|
||||
document.addEventListener('pointerdown', blockCommitPointerDown, true)
|
||||
document.addEventListener('pointerup', endDrag, true)
|
||||
document.addEventListener('keydown', handleKeydown)
|
||||
@@ -124,31 +103,22 @@ function cleanupGlobalListeners() {
|
||||
if (!listenersSetup) return
|
||||
listenersSetup = false
|
||||
|
||||
document.removeEventListener('pointermove', updatePosition)
|
||||
document.removeEventListener('pointerdown', blockCommitPointerDown, true)
|
||||
document.removeEventListener('pointerup', endDrag, true)
|
||||
document.removeEventListener('keydown', handleKeydown)
|
||||
document.removeEventListener('dragover', trackNativeDragPosition)
|
||||
}
|
||||
|
||||
function cancelDrag() {
|
||||
isDragging.value = false
|
||||
draggedNode.value = null
|
||||
dragMode.value = 'click'
|
||||
lastNativeDragPosition.value = undefined
|
||||
pendingWidgetValues.value = undefined
|
||||
cleanupGlobalListeners()
|
||||
if (isDragging.value && dragMode.value === 'click') {
|
||||
cancelDrag()
|
||||
}
|
||||
}
|
||||
|
||||
export function useNodeDragToCanvas() {
|
||||
function startDrag(
|
||||
nodeDef: ComfyNodeDefImpl,
|
||||
{ mode = 'click', widgetValues }: StartDragOptions = {}
|
||||
) {
|
||||
function startDrag(nodeDef: ComfyNodeDefImpl, mode: DragMode = 'click') {
|
||||
isDragging.value = true
|
||||
draggedNode.value = nodeDef
|
||||
dragMode.value = mode
|
||||
pendingWidgetValues.value = widgetValues
|
||||
setupGlobalListeners()
|
||||
}
|
||||
|
||||
function handleNativeDrop(clientX: number, clientY: number) {
|
||||
@@ -164,9 +134,12 @@ export function useNodeDragToCanvas() {
|
||||
return {
|
||||
isDragging,
|
||||
draggedNode,
|
||||
pendingWidgetValues,
|
||||
cursorPosition,
|
||||
dragMode,
|
||||
startDrag,
|
||||
cancelDrag,
|
||||
handleNativeDrop
|
||||
handleNativeDrop,
|
||||
setupGlobalListeners,
|
||||
cleanupGlobalListeners
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,9 +124,7 @@ describe('useNodePreviewAndDrag', () => {
|
||||
|
||||
expect(result.isDragging.value).toBe(true)
|
||||
expect(result.isHovered.value).toBe(false)
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockNodeDef, {
|
||||
mode: 'native'
|
||||
})
|
||||
expect(mockStartDrag).toHaveBeenCalledWith(mockNodeDef, 'native')
|
||||
expect(mockDataTransfer.effectAllowed).toBe('copy')
|
||||
expect(mockDataTransfer.setData).toHaveBeenCalledWith(
|
||||
'application/x-comfy-node',
|
||||
|
||||
@@ -112,7 +112,7 @@ export function useNodePreviewAndDrag(
|
||||
isDragging.value = true
|
||||
isHovered.value = false
|
||||
|
||||
startDrag(nodeDef.value, { mode: 'native' })
|
||||
startDrag(nodeDef.value, 'native')
|
||||
|
||||
if (e.dataTransfer) {
|
||||
e.dataTransfer.effectAllowed = 'copy'
|
||||
|
||||
@@ -2,7 +2,6 @@ import { useCurrentUser } from '@/composables/auth/useCurrentUser'
|
||||
import { useAuthActions } from '@/composables/auth/useAuthActions'
|
||||
import { useSelectedLiteGraphItems } from '@/composables/canvas/useSelectedLiteGraphItems'
|
||||
import { useSubgraphOperations } from '@/composables/graph/useSubgraphOperations'
|
||||
import { startModelNodeDragFromAsset } from '@/composables/node/startModelNodeDragFromAsset'
|
||||
import { useExternalLink } from '@/composables/useExternalLink'
|
||||
import { useModelSelectorDialog } from '@/composables/useModelSelectorDialog'
|
||||
import {
|
||||
@@ -21,6 +20,7 @@ import {
|
||||
import type { Point } from '@/lib/litegraph/src/litegraph'
|
||||
import { useBillingContext } from '@/composables/billing/useBillingContext'
|
||||
import { useAssetBrowserDialog } from '@/platform/assets/composables/useAssetBrowserDialog'
|
||||
import { createModelNodeFromAsset } from '@/platform/assets/utils/createModelNodeFromAsset'
|
||||
import { useSettingStore } from '@/platform/settings/settingStore'
|
||||
import { buildSupportUrl } from '@/platform/support/config'
|
||||
import { useTelemetry } from '@/platform/telemetry'
|
||||
@@ -1305,14 +1305,14 @@ export function useCoreCommands(): ComfyCommand[] {
|
||||
assetType: 'models',
|
||||
title: t('sideToolbar.modelLibrary'),
|
||||
onAssetSelected: (asset) => {
|
||||
const error = startModelNodeDragFromAsset(asset)
|
||||
if (error) {
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
if (!result.success) {
|
||||
toastStore.add({
|
||||
severity: 'error',
|
||||
summary: t('g.error'),
|
||||
detail: t('assetBrowser.failedToCreateNode')
|
||||
})
|
||||
console.error('Node creation failed:', error)
|
||||
console.error('Node creation failed:', result.error)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -22,6 +22,7 @@ import {
|
||||
LOAD3D_NONE_MODEL,
|
||||
SUPPORTED_EXTENSIONS_ACCEPT
|
||||
} from '@/extensions/core/load3d/constants'
|
||||
import { snapshotLoad3dState } from '@/extensions/core/load3d/load3dSerialize'
|
||||
import Load3dUtils from '@/extensions/core/load3d/Load3dUtils'
|
||||
import { t } from '@/i18n'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
@@ -413,16 +414,10 @@ useExtensionService().registerExtension({
|
||||
if (cached) return cached
|
||||
}
|
||||
|
||||
const cameraConfig: CameraConfig = (node.properties[
|
||||
'Camera Config'
|
||||
] as CameraConfig | undefined) || {
|
||||
cameraType: currentLoad3d.getCurrentCameraType(),
|
||||
fov: currentLoad3d.cameraManager.perspectiveCamera.fov
|
||||
}
|
||||
cameraConfig.state = currentLoad3d.getCameraState()
|
||||
node.properties['Camera Config'] = cameraConfig
|
||||
|
||||
currentLoad3d.stopRecording()
|
||||
const { camera_info, model_3d_info } = snapshotLoad3dState(
|
||||
node,
|
||||
currentLoad3d
|
||||
)
|
||||
|
||||
const {
|
||||
scene: imageData,
|
||||
@@ -441,16 +436,11 @@ useExtensionService().registerExtension({
|
||||
|
||||
currentLoad3d.handleResize()
|
||||
|
||||
const modelInfo = currentLoad3d.getModelInfo()
|
||||
const model_3d_info: Model3DInfo = modelInfo ? [modelInfo] : []
|
||||
|
||||
const returnVal: Load3dCachedOutput = {
|
||||
image: `threed/${data.name} [temp]`,
|
||||
mask: `threed/${dataMask.name} [temp]`,
|
||||
normal: `threed/${dataNormal.name} [temp]`,
|
||||
camera_info:
|
||||
(node.properties['Camera Config'] as CameraConfig | undefined)
|
||||
?.state || null,
|
||||
camera_info,
|
||||
recording: '',
|
||||
model_3d_info
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ const mtlLoaderStub = {
|
||||
const objLoaderStub = {
|
||||
setWorkerUrl: vi.fn(),
|
||||
setMaterials: vi.fn(),
|
||||
setBaseObject3d: vi.fn(),
|
||||
loadAsync: vi.fn<(url: string) => Promise<THREE.Object3D>>()
|
||||
}
|
||||
|
||||
@@ -58,6 +59,7 @@ vi.mock('wwobjloader2', () => ({
|
||||
OBJLoader2Parallel: class {
|
||||
setWorkerUrl = objLoaderStub.setWorkerUrl
|
||||
setMaterials = objLoaderStub.setMaterials
|
||||
setBaseObject3d = objLoaderStub.setBaseObject3d
|
||||
loadAsync = objLoaderStub.loadAsync
|
||||
},
|
||||
MtlObjBridge: {
|
||||
@@ -247,6 +249,24 @@ describe('MeshModelAdapter', () => {
|
||||
|
||||
expect(ctx.registerOriginalMaterial).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('resets baseObject3d on every load so meshes do not accumulate across calls', async () => {
|
||||
objLoaderStub.loadAsync.mockResolvedValue(makeFbxLikeGroup())
|
||||
|
||||
const adapter = new MeshModelAdapter()
|
||||
const ctx = makeContext('wireframe')
|
||||
await adapter.load(ctx, '/api/view/', 'first.obj')
|
||||
await adapter.load(ctx, '/api/view/', 'second.obj')
|
||||
|
||||
expect(objLoaderStub.setBaseObject3d).toHaveBeenCalledTimes(2)
|
||||
const bases = objLoaderStub.setBaseObject3d.mock.calls.map(
|
||||
([base]) => base
|
||||
)
|
||||
expect(bases[0]).toBeInstanceOf(THREE.Object3D)
|
||||
expect(bases[1]).toBeInstanceOf(THREE.Object3D)
|
||||
// Each call should hand the loader a fresh container, not the same one.
|
||||
expect(bases[0]).not.toBe(bases[1])
|
||||
})
|
||||
})
|
||||
|
||||
describe('GLTF loader path', () => {
|
||||
|
||||
@@ -102,6 +102,8 @@ export class MeshModelAdapter implements ModelAdapter {
|
||||
path: string,
|
||||
filename: string
|
||||
): Promise<THREE.Object3D> {
|
||||
this.objLoader.setBaseObject3d(new THREE.Object3D())
|
||||
|
||||
if (ctx.materialMode === 'original') {
|
||||
try {
|
||||
this.mtlLoader.setPath(path)
|
||||
|
||||
87
src/extensions/core/load3d/load3dSerialize.test.ts
Normal file
87
src/extensions/core/load3d/load3dSerialize.test.ts
Normal file
@@ -0,0 +1,87 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type Load3d from '@/extensions/core/load3d/Load3d'
|
||||
import { snapshotLoad3dState } from '@/extensions/core/load3d/load3dSerialize'
|
||||
import type { CameraState } from '@/extensions/core/load3d/interfaces'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
|
||||
function makeNode(props: Record<string, unknown> = {}): LGraphNode {
|
||||
return { properties: { ...props } } as unknown as LGraphNode
|
||||
}
|
||||
|
||||
const baseCameraState: CameraState = {
|
||||
position: { x: 1, y: 2, z: 3 },
|
||||
target: { x: 0, y: 0, z: 0 },
|
||||
zoom: 1,
|
||||
cameraType: 'perspective'
|
||||
} as unknown as CameraState
|
||||
|
||||
function makeLoad3d({
|
||||
cameraType = 'perspective',
|
||||
fov = 35,
|
||||
modelInfo = { transform: { position: [0, 0, 0] } } as unknown
|
||||
}: {
|
||||
cameraType?: string
|
||||
fov?: number
|
||||
modelInfo?: unknown
|
||||
} = {}) {
|
||||
return {
|
||||
getCurrentCameraType: vi.fn(() => cameraType),
|
||||
cameraManager: { perspectiveCamera: { fov } },
|
||||
getCameraState: vi.fn(() => baseCameraState),
|
||||
stopRecording: vi.fn(),
|
||||
getModelInfo: vi.fn(() => modelInfo)
|
||||
} as unknown as Load3d
|
||||
}
|
||||
|
||||
describe('snapshotLoad3dState', () => {
|
||||
it('returns only camera_info and model_3d_info', () => {
|
||||
const result = snapshotLoad3dState(makeNode(), makeLoad3d())
|
||||
expect(Object.keys(result).sort()).toEqual(['camera_info', 'model_3d_info'])
|
||||
})
|
||||
|
||||
it('writes the camera state into properties["Camera Config"]', () => {
|
||||
const node = makeNode()
|
||||
snapshotLoad3dState(node, makeLoad3d({ fov: 42 }))
|
||||
const cfg = node.properties['Camera Config'] as Record<string, unknown>
|
||||
expect(cfg).toMatchObject({
|
||||
cameraType: 'perspective',
|
||||
fov: 42,
|
||||
state: baseCameraState
|
||||
})
|
||||
})
|
||||
|
||||
it('preserves an existing Camera Config object instead of replacing it', () => {
|
||||
const existing = { cameraType: 'orthographic', fov: 99 }
|
||||
const node = makeNode({ 'Camera Config': existing })
|
||||
snapshotLoad3dState(node, makeLoad3d())
|
||||
// Same object reference (mutated in place), with state attached.
|
||||
expect(node.properties['Camera Config']).toBe(existing)
|
||||
expect(
|
||||
(node.properties['Camera Config'] as Record<string, unknown>).state
|
||||
).toBe(baseCameraState)
|
||||
})
|
||||
|
||||
it('stops in-progress recording as a side effect', () => {
|
||||
const load3d = makeLoad3d()
|
||||
snapshotLoad3dState(makeNode(), load3d)
|
||||
expect(load3d.stopRecording).toHaveBeenCalledOnce()
|
||||
})
|
||||
|
||||
it('returns model_3d_info as a single-element list when a model is loaded', () => {
|
||||
const info = { transform: { position: [1, 2, 3] } }
|
||||
const result = snapshotLoad3dState(
|
||||
makeNode(),
|
||||
makeLoad3d({ modelInfo: info })
|
||||
)
|
||||
expect(result.model_3d_info).toEqual([info])
|
||||
})
|
||||
|
||||
it('returns an empty model_3d_info list when no model is loaded', () => {
|
||||
const result = snapshotLoad3dState(
|
||||
makeNode(),
|
||||
makeLoad3d({ modelInfo: null })
|
||||
)
|
||||
expect(result.model_3d_info).toEqual([])
|
||||
})
|
||||
})
|
||||
36
src/extensions/core/load3d/load3dSerialize.ts
Normal file
36
src/extensions/core/load3d/load3dSerialize.ts
Normal file
@@ -0,0 +1,36 @@
|
||||
import type Load3d from '@/extensions/core/load3d/Load3d'
|
||||
import type {
|
||||
CameraConfig,
|
||||
CameraState,
|
||||
Model3DInfo
|
||||
} from '@/extensions/core/load3d/interfaces'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
|
||||
export type Load3dSerializedBase = {
|
||||
camera_info: CameraState | null
|
||||
model_3d_info: Model3DInfo
|
||||
}
|
||||
|
||||
export function snapshotLoad3dState(
|
||||
node: LGraphNode,
|
||||
load3d: Load3d
|
||||
): Load3dSerializedBase {
|
||||
const cameraConfig: CameraConfig = (node.properties['Camera Config'] as
|
||||
| CameraConfig
|
||||
| undefined) || {
|
||||
cameraType: load3d.getCurrentCameraType(),
|
||||
fov: load3d.cameraManager.perspectiveCamera.fov
|
||||
}
|
||||
cameraConfig.state = load3d.getCameraState()
|
||||
node.properties['Camera Config'] = cameraConfig
|
||||
|
||||
load3d.stopRecording()
|
||||
|
||||
const modelInfo = load3d.getModelInfo()
|
||||
const model_3d_info: Model3DInfo = modelInfo ? [modelInfo] : []
|
||||
|
||||
return {
|
||||
camera_info: cameraConfig.state ?? null,
|
||||
model_3d_info
|
||||
}
|
||||
}
|
||||
@@ -9,7 +9,12 @@ const LOAD3D_PREVIEW_NODES = new Set([
|
||||
'PreviewPointCloud'
|
||||
])
|
||||
|
||||
const LOAD3D_ALL_NODES = new Set([...LOAD3D_PREVIEW_NODES, 'Load3D', 'SaveGLB'])
|
||||
const LOAD3D_ALL_NODES = new Set([
|
||||
...LOAD3D_PREVIEW_NODES,
|
||||
'Load3D',
|
||||
'Load3DAdvanced',
|
||||
'SaveGLB'
|
||||
])
|
||||
|
||||
export const isLoad3dPreviewNode = (nodeType: string): boolean =>
|
||||
LOAD3D_PREVIEW_NODES.has(nodeType)
|
||||
|
||||
103
src/extensions/core/load3dAdvanced.ts
Normal file
103
src/extensions/core/load3dAdvanced.ts
Normal file
@@ -0,0 +1,103 @@
|
||||
import { nextTick } from 'vue'
|
||||
|
||||
import Load3DAdvanced from '@/components/load3d/Load3DAdvanced.vue'
|
||||
import { nodeToLoad3dMap, useLoad3d } from '@/composables/useLoad3d'
|
||||
import { createExportMenuItems } from '@/extensions/core/load3d/exportMenuHelper'
|
||||
import type { CameraConfig } from '@/extensions/core/load3d/interfaces'
|
||||
import Load3DConfiguration from '@/extensions/core/load3d/Load3DConfiguration'
|
||||
import { snapshotLoad3dState } from '@/extensions/core/load3d/load3dSerialize'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
import type { IContextMenuValue } from '@/lib/litegraph/src/interfaces'
|
||||
import type { CustomInputSpec } from '@/schemas/nodeDef/nodeDefSchemaV2'
|
||||
import { ComponentWidgetImpl, addWidget } from '@/scripts/domWidget'
|
||||
import { useExtensionService } from '@/services/extensionService'
|
||||
import { useLoad3dService } from '@/services/load3dService'
|
||||
|
||||
const inputSpecLoad3DAdvanced: CustomInputSpec = {
|
||||
name: 'viewport_state',
|
||||
type: 'LOAD_3D_ADVANCED',
|
||||
isPreview: false
|
||||
}
|
||||
|
||||
useExtensionService().registerExtension({
|
||||
name: 'Comfy.Load3DAdvanced',
|
||||
|
||||
beforeRegisterNodeDef(_nodeType, nodeData) {
|
||||
if (nodeData.name !== 'Load3DAdvanced') return
|
||||
if (!nodeData.input?.required) return
|
||||
nodeData.input.required.viewport_state = ['LOAD_3D_ADVANCED', {}]
|
||||
},
|
||||
|
||||
getNodeMenuItems(node: LGraphNode): (IContextMenuValue | null)[] {
|
||||
if (node.constructor.comfyClass !== 'Load3DAdvanced') return []
|
||||
|
||||
const load3d = useLoad3dService().getLoad3d(node)
|
||||
if (!load3d) return []
|
||||
|
||||
return createExportMenuItems(load3d)
|
||||
},
|
||||
|
||||
getCustomWidgets() {
|
||||
return {
|
||||
LOAD_3D_ADVANCED(node) {
|
||||
const widget = new ComponentWidgetImpl({
|
||||
node,
|
||||
name: 'viewport_state',
|
||||
component: Load3DAdvanced,
|
||||
inputSpec: inputSpecLoad3DAdvanced,
|
||||
options: {}
|
||||
})
|
||||
|
||||
widget.type = 'load3DAdvanced'
|
||||
|
||||
addWidget(node, widget)
|
||||
|
||||
return { widget }
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
async nodeCreated(node: LGraphNode) {
|
||||
if (node.constructor.comfyClass !== 'Load3DAdvanced') return
|
||||
|
||||
const [oldWidth, oldHeight] = node.size
|
||||
node.setSize([Math.max(oldWidth, 300), Math.max(oldHeight, 600)])
|
||||
|
||||
await nextTick()
|
||||
|
||||
useLoad3d(node).onLoad3dReady((load3d) => {
|
||||
const modelWidget = node.widgets?.find((w) => w.name === 'model_file')
|
||||
const width = node.widgets?.find((w) => w.name === 'width')
|
||||
const height = node.widgets?.find((w) => w.name === 'height')
|
||||
if (!modelWidget || !width || !height) return
|
||||
|
||||
const cameraConfig = node.properties['Camera Config'] as
|
||||
| CameraConfig
|
||||
| undefined
|
||||
const cameraState = cameraConfig?.state
|
||||
|
||||
const config = new Load3DConfiguration(load3d, node.properties)
|
||||
config.configure({
|
||||
loadFolder: 'input',
|
||||
modelWidget,
|
||||
cameraState,
|
||||
width,
|
||||
height
|
||||
})
|
||||
})
|
||||
|
||||
useLoad3d(node).waitForLoad3d(() => {
|
||||
const sceneWidget = node.widgets?.find((w) => w.name === 'viewport_state')
|
||||
if (!sceneWidget) return
|
||||
|
||||
sceneWidget.serializeValue = async () => {
|
||||
const currentLoad3d = nodeToLoad3dMap.get(node)
|
||||
if (!currentLoad3d) {
|
||||
console.error('No load3d instance found for node')
|
||||
return null
|
||||
}
|
||||
return snapshotLoad3dState(node, currentLoad3d)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -37,6 +37,7 @@ async function loadLoad3dExtensions(): Promise<ComfyExtension[]> {
|
||||
// Import extensions - they self-register via useExtensionService()
|
||||
await Promise.all([
|
||||
import('./load3d'),
|
||||
import('./load3dAdvanced'),
|
||||
import('./load3dPreviewExtensions'),
|
||||
import('./saveMesh')
|
||||
])
|
||||
@@ -66,6 +67,12 @@ useExtensionService().registerExtension({
|
||||
modelFile[1].mesh_upload = true
|
||||
modelFile[1].upload_subfolder = '3d'
|
||||
}
|
||||
} else if (nodeData.name === 'Load3DAdvanced') {
|
||||
const modelFile = nodeData.input?.required?.model_file
|
||||
if (modelFile?.[1]) {
|
||||
modelFile[1].mesh_upload = true
|
||||
modelFile[1].upload_subfolder = ''
|
||||
}
|
||||
}
|
||||
|
||||
// Load the 3D extensions and replay their beforeRegisterNodeDef hooks,
|
||||
|
||||
439
src/platform/assets/utils/createModelNodeFromAsset.test.ts
Normal file
439
src/platform/assets/utils/createModelNodeFromAsset.test.ts
Normal file
@@ -0,0 +1,439 @@
|
||||
// oxlint-disable no-misused-spread
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { markRaw } from 'vue'
|
||||
import type { Raw } from 'vue'
|
||||
|
||||
import type { LGraphNode, Subgraph } from '@/lib/litegraph/src/litegraph'
|
||||
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type * as LitegraphModule from '@/lib/litegraph/src/litegraph'
|
||||
import type * as ModelToNodeStoreModule from '@/stores/modelToNodeStore'
|
||||
import type * as WorkflowStoreModule from '@/platform/workflow/management/stores/workflowStore'
|
||||
import type * as LitegraphServiceModule from '@/services/litegraphService'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { createModelNodeFromAsset } from '@/platform/assets/utils/createModelNodeFromAsset'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
import { app } from '@/scripts/app'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
apiURL: vi.fn((path: string) => `http://localhost:8188${path}`)
|
||||
}
|
||||
}))
|
||||
vi.mock('@/stores/modelToNodeStore', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof ModelToNodeStoreModule>()
|
||||
return {
|
||||
...actual,
|
||||
useModelToNodeStore: vi.fn()
|
||||
}
|
||||
})
|
||||
vi.mock(
|
||||
'@/platform/workflow/management/stores/workflowStore',
|
||||
async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof WorkflowStoreModule>()
|
||||
return {
|
||||
...actual,
|
||||
useWorkflowStore: vi.fn()
|
||||
}
|
||||
}
|
||||
)
|
||||
vi.mock('@/services/litegraphService', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof LitegraphServiceModule>()
|
||||
return {
|
||||
...actual,
|
||||
useLitegraphService: vi.fn()
|
||||
}
|
||||
})
|
||||
vi.mock('@/lib/litegraph/src/litegraph', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof LitegraphModule>()
|
||||
return {
|
||||
...actual,
|
||||
LiteGraph: {
|
||||
...actual.LiteGraph,
|
||||
createNode: vi.fn()
|
||||
}
|
||||
}
|
||||
})
|
||||
vi.mock('@/scripts/app', () => ({
|
||||
app: {
|
||||
canvas: {
|
||||
graph: {
|
||||
add: vi.fn()
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
function createMockAsset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
return {
|
||||
id: 'asset-123',
|
||||
name: 'test-model.safetensors',
|
||||
size: 1024,
|
||||
created_at: '2025-10-01T00:00:00Z',
|
||||
tags: ['models', 'checkpoints'],
|
||||
user_metadata: {
|
||||
filename: 'models/checkpoints/test-model.safetensors'
|
||||
},
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
async function createMockNode(overrides?: {
|
||||
widgetName?: string
|
||||
widgetValue?: string
|
||||
hasWidgets?: boolean
|
||||
}): Promise<LGraphNode> {
|
||||
const {
|
||||
widgetName = 'ckpt_name',
|
||||
widgetValue = '',
|
||||
hasWidgets = true
|
||||
} = overrides || {}
|
||||
|
||||
const { LGraphNode: ActualLGraphNode } = await vi.importActual<
|
||||
typeof LitegraphModule
|
||||
>('@/lib/litegraph/src/litegraph')
|
||||
|
||||
if (!hasWidgets) {
|
||||
return Object.create(ActualLGraphNode.prototype)
|
||||
}
|
||||
|
||||
type Widget = NonNullable<LGraphNode['widgets']>[number]
|
||||
const widget: Pick<Widget, 'name' | 'value' | 'type' | 'options' | 'y'> = {
|
||||
name: widgetName,
|
||||
value: widgetValue,
|
||||
type: 'string',
|
||||
options: {},
|
||||
y: 0
|
||||
}
|
||||
|
||||
return Object.create(ActualLGraphNode.prototype, {
|
||||
widgets: { value: [widget], writable: true }
|
||||
})
|
||||
}
|
||||
function createMockNodeProvider(
|
||||
overrides: {
|
||||
nodeDef?: { name: string; display_name: string }
|
||||
key?: string
|
||||
} = {}
|
||||
) {
|
||||
return {
|
||||
nodeDef: {
|
||||
name: 'CheckpointLoaderSimple',
|
||||
display_name: 'Load Checkpoint',
|
||||
...overrides.nodeDef
|
||||
},
|
||||
key: overrides.key ?? 'ckpt_name'
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Configures all mocked dependencies with sensible defaults.
|
||||
* Uses semantic parameters for clearer test intent.
|
||||
* For error paths or edge cases, pass null values or specific overrides.
|
||||
*/
|
||||
async function setupMocks(
|
||||
overrides: {
|
||||
nodeProvider?: ReturnType<typeof createMockNodeProvider> | null
|
||||
canvasCenter?: [number, number]
|
||||
activeSubgraph?: Raw<Subgraph>
|
||||
createdNode?: Awaited<ReturnType<typeof createMockNode>> | null
|
||||
} = {}
|
||||
) {
|
||||
const {
|
||||
nodeProvider = createMockNodeProvider(),
|
||||
canvasCenter = [100, 200],
|
||||
activeSubgraph = undefined,
|
||||
createdNode = await createMockNode()
|
||||
} = overrides
|
||||
|
||||
vi.mocked(useModelToNodeStore).mockReturnValue({
|
||||
...useModelToNodeStore(),
|
||||
getNodeProvider: vi.fn().mockReturnValue(nodeProvider)
|
||||
})
|
||||
|
||||
vi.mocked(useLitegraphService).mockReturnValue({
|
||||
...useLitegraphService(),
|
||||
getCanvasCenter: vi.fn().mockReturnValue(canvasCenter)
|
||||
})
|
||||
|
||||
vi.mocked(useWorkflowStore).mockReturnValue({
|
||||
...useWorkflowStore(),
|
||||
activeSubgraph,
|
||||
isSubgraphActive: !!activeSubgraph
|
||||
})
|
||||
vi.mocked(LiteGraph.createNode).mockReturnValue(createdNode)
|
||||
}
|
||||
describe('createModelNodeFromAsset', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
})
|
||||
describe('when creating nodes from valid assets', () => {
|
||||
it('should create the appropriate loader node for the asset category', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks()
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(
|
||||
vi.mocked(useModelToNodeStore)().getNodeProvider
|
||||
).toHaveBeenCalledWith('checkpoints')
|
||||
expect(LiteGraph.createNode).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple',
|
||||
'Load Checkpoint',
|
||||
{ pos: [100, 200] }
|
||||
)
|
||||
}
|
||||
})
|
||||
it('should place node at canvas center by default', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks({
|
||||
canvasCenter: [150, 250]
|
||||
})
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(
|
||||
vi.mocked(useLitegraphService)().getCanvasCenter
|
||||
).toHaveBeenCalled()
|
||||
expect(LiteGraph.createNode).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple',
|
||||
'Load Checkpoint',
|
||||
{ pos: [150, 250] }
|
||||
)
|
||||
})
|
||||
it('should place node at specified position when position is provided', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks()
|
||||
const result = createModelNodeFromAsset(asset, { position: [300, 400] })
|
||||
expect(result.success).toBe(true)
|
||||
expect(
|
||||
vi.mocked(useLitegraphService)().getCanvasCenter
|
||||
).not.toHaveBeenCalled()
|
||||
expect(LiteGraph.createNode).toHaveBeenCalledWith(
|
||||
'CheckpointLoaderSimple',
|
||||
'Load Checkpoint',
|
||||
{ pos: [300, 400] }
|
||||
)
|
||||
})
|
||||
it('should populate the loader widget with the asset file path', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(mockNode.widgets?.[0].value).toBe(
|
||||
'models/checkpoints/test-model.safetensors'
|
||||
)
|
||||
})
|
||||
it('should add node to root graph when no subgraph is active', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(vi.mocked(app).canvas.graph!.add).toHaveBeenCalledWith(mockNode)
|
||||
})
|
||||
it('should fallback to asset.metadata.filename when user_metadata.filename missing', async () => {
|
||||
const asset = createMockAsset({
|
||||
user_metadata: {},
|
||||
metadata: { filename: 'models/checkpoints/from-metadata.safetensors' }
|
||||
})
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(mockNode.widgets?.[0].value).toBe(
|
||||
'models/checkpoints/from-metadata.safetensors'
|
||||
)
|
||||
})
|
||||
it('should fallback to asset.name when both filename sources missing', async () => {
|
||||
const asset = createMockAsset({
|
||||
user_metadata: {},
|
||||
metadata: undefined
|
||||
})
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(mockNode.widgets?.[0].value).toBe('test-model.safetensors')
|
||||
})
|
||||
it('should add node to active subgraph when present', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
const { Subgraph } = await vi.importActual<typeof LitegraphModule>(
|
||||
'@/lib/litegraph/src/litegraph'
|
||||
)
|
||||
const mockSubgraph = markRaw(
|
||||
Object.create(Subgraph.prototype, {
|
||||
add: { value: vi.fn() }
|
||||
})
|
||||
)
|
||||
await setupMocks({
|
||||
createdNode: mockNode,
|
||||
activeSubgraph: mockSubgraph
|
||||
})
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(mockSubgraph.add).toHaveBeenCalledWith(mockNode)
|
||||
expect(vi.mocked(app).canvas.graph!.add).not.toHaveBeenCalled()
|
||||
})
|
||||
it('should succeed when provider has empty key (auto-load nodes)', async () => {
|
||||
const asset = createMockAsset({
|
||||
tags: ['models', 'chatterbox/chatterbox_vc'],
|
||||
user_metadata: { filename: 'chatterbox_vc_model.pt' }
|
||||
})
|
||||
const mockNode = await createMockNode({ hasWidgets: false })
|
||||
const nodeProvider = createMockNodeProvider({
|
||||
nodeDef: {
|
||||
name: 'FL_ChatterboxVC',
|
||||
display_name: 'FL Chatterbox VC'
|
||||
},
|
||||
key: ''
|
||||
})
|
||||
await setupMocks({ createdNode: mockNode, nodeProvider })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(true)
|
||||
expect(vi.mocked(app).canvas.graph!.add).toHaveBeenCalledWith(mockNode)
|
||||
})
|
||||
})
|
||||
describe('when asset data is incomplete or invalid', () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
})
|
||||
it.for([
|
||||
{
|
||||
case: 'missing user_metadata with no fallback',
|
||||
overrides: { user_metadata: undefined, metadata: undefined, name: '' },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorPattern: /Invalid filename.*expected non-empty string/
|
||||
},
|
||||
{
|
||||
case: 'empty filename with no fallback',
|
||||
overrides: {
|
||||
user_metadata: { filename: '' },
|
||||
metadata: undefined,
|
||||
name: ''
|
||||
},
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorPattern: /Invalid filename.*expected non-empty string/
|
||||
}
|
||||
])(
|
||||
'should fail when asset has $case',
|
||||
({ overrides, expectedCode, errorPattern }) => {
|
||||
const asset = createMockAsset(overrides)
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe(expectedCode)
|
||||
expect(result.error.message).toMatch(errorPattern)
|
||||
expect(result.error.assetId).toBe('asset-123')
|
||||
}
|
||||
}
|
||||
)
|
||||
it.for([
|
||||
{
|
||||
case: 'no tags',
|
||||
overrides: { tags: undefined },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorMessage: 'Asset has no tags defined'
|
||||
},
|
||||
{
|
||||
case: 'only excluded tags',
|
||||
overrides: { tags: ['models', 'missing'] },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorMessage: 'Asset has no valid category tag'
|
||||
},
|
||||
{
|
||||
case: 'only the models tag',
|
||||
overrides: { tags: ['models'] },
|
||||
expectedCode: 'INVALID_ASSET' as const,
|
||||
errorMessage: 'Asset has no valid category tag'
|
||||
}
|
||||
])(
|
||||
'should fail when asset has $case',
|
||||
({ overrides, expectedCode, errorMessage }) => {
|
||||
const asset = createMockAsset(overrides)
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe(expectedCode)
|
||||
expect(result.error.message).toBe(errorMessage)
|
||||
}
|
||||
}
|
||||
)
|
||||
})
|
||||
describe('when system resources are unavailable', () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
})
|
||||
it('should fail when no provider registered for category', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks({ nodeProvider: null })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NO_PROVIDER')
|
||||
expect(result.error.message).toContain('checkpoints')
|
||||
expect(result.error.details?.category).toBe('checkpoints')
|
||||
}
|
||||
})
|
||||
it('should fail when node creation fails', async () => {
|
||||
const asset = createMockAsset()
|
||||
await setupMocks()
|
||||
vi.mocked(LiteGraph.createNode).mockReturnValue(null)
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NODE_CREATION_FAILED')
|
||||
expect(result.error.message).toContain('CheckpointLoaderSimple')
|
||||
}
|
||||
})
|
||||
it('should fail when widget is missing from node', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode({ widgetName: 'wrong_widget' })
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('MISSING_WIDGET')
|
||||
expect(result.error.message).toContain('ckpt_name')
|
||||
expect(result.error.message).toContain('CheckpointLoaderSimple')
|
||||
expect(result.error.details?.widgetName).toBe('ckpt_name')
|
||||
}
|
||||
})
|
||||
it('should fail when node has no widgets array', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode({ hasWidgets: false })
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('MISSING_WIDGET')
|
||||
expect(result.error.message).toContain('ckpt_name not found')
|
||||
}
|
||||
})
|
||||
it('should not add node to graph when widget validation fails', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode({ hasWidgets: false })
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
createModelNodeFromAsset(asset)
|
||||
expect(vi.mocked(app).canvas.graph!.add).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
describe('when graph is null', () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
vi.mocked(app).canvas.graph = null
|
||||
})
|
||||
it('should fail when no graph is available', async () => {
|
||||
const asset = createMockAsset()
|
||||
const mockNode = await createMockNode()
|
||||
await setupMocks({ createdNode: mockNode })
|
||||
const result = createModelNodeFromAsset(asset)
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NO_GRAPH')
|
||||
expect(result.error.message).toBe('No active graph available')
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
198
src/platform/assets/utils/createModelNodeFromAsset.ts
Normal file
198
src/platform/assets/utils/createModelNodeFromAsset.ts
Normal file
@@ -0,0 +1,198 @@
|
||||
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { LGraphNode, Point } from '@/lib/litegraph/src/litegraph'
|
||||
import { assetItemSchema } from '@/platform/assets/schemas/assetSchema'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import {
|
||||
MISSING_TAG,
|
||||
MODELS_TAG
|
||||
} from '@/platform/assets/services/assetService'
|
||||
import { getAssetFilename } from '@/platform/assets/utils/assetMetadataUtils'
|
||||
import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore'
|
||||
import { app } from '@/scripts/app'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
|
||||
interface ModelNodeCreateOptions {
|
||||
position?: Point
|
||||
}
|
||||
|
||||
type NodeCreationErrorCode =
|
||||
| 'INVALID_ASSET'
|
||||
| 'NO_PROVIDER'
|
||||
| 'NODE_CREATION_FAILED'
|
||||
| 'MISSING_WIDGET'
|
||||
| 'NO_GRAPH'
|
||||
|
||||
interface NodeCreationError {
|
||||
code: NodeCreationErrorCode
|
||||
message: string
|
||||
assetId: string
|
||||
details?: Record<string, unknown>
|
||||
}
|
||||
|
||||
type Result<T, E> = { success: true; value: T } | { success: false; error: E }
|
||||
|
||||
/**
|
||||
* Creates a LiteGraph node from an asset item.
|
||||
*
|
||||
* **Boundary Function**: Bridges Vue reactive domain with LiteGraph canvas domain.
|
||||
*
|
||||
* @param asset - Asset item to create node from (Vue domain)
|
||||
* @param options - Optional position and configuration
|
||||
* @returns Result with LiteGraph node (Canvas domain) or error details
|
||||
*
|
||||
* @remarks
|
||||
* This function performs side effects on the canvas graph. Validation failures
|
||||
* return error results rather than throwing to allow graceful degradation in UI contexts.
|
||||
* Widget validation occurs before graph mutation to prevent orphaned nodes.
|
||||
*/
|
||||
export function createModelNodeFromAsset(
|
||||
asset: AssetItem,
|
||||
options?: ModelNodeCreateOptions
|
||||
): Result<LGraphNode, NodeCreationError> {
|
||||
const validatedAsset = assetItemSchema.safeParse(asset)
|
||||
|
||||
if (!validatedAsset.success) {
|
||||
const errorMessage = validatedAsset.error.errors
|
||||
.map((e) => `${e.path.join('.')}: ${e.message}`)
|
||||
.join(', ')
|
||||
console.error('Invalid asset item:', errorMessage)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset schema validation failed',
|
||||
assetId: asset.id,
|
||||
details: { validationErrors: errorMessage }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const validAsset = validatedAsset.data
|
||||
|
||||
const filename = getAssetFilename(validAsset)
|
||||
if (filename.length === 0) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has invalid user_metadata.filename (expected non-empty string, got ${typeof filename})`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: `Invalid filename (expected non-empty string, got ${typeof filename})`,
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (validAsset.tags.length === 0) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has no tags defined (expected at least one category tag)`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset has no tags defined',
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const category = validAsset.tags.find(
|
||||
(tag) => tag !== MODELS_TAG && tag !== MISSING_TAG
|
||||
)
|
||||
if (!category) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has no valid category tag. Available tags: ${validAsset.tags.join(', ')} (expected tag other than '${MODELS_TAG}' or '${MISSING_TAG}')`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset has no valid category tag',
|
||||
assetId: validAsset.id,
|
||||
details: { availableTags: validAsset.tags }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const modelToNodeStore = useModelToNodeStore()
|
||||
const provider = modelToNodeStore.getNodeProvider(category)
|
||||
if (!provider) {
|
||||
console.error(`No node provider registered for category: ${category}`)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'NO_PROVIDER',
|
||||
message: `No node provider registered for category: ${category}`,
|
||||
assetId: validAsset.id,
|
||||
details: { category }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const litegraphService = useLitegraphService()
|
||||
const pos = options?.position ?? litegraphService.getCanvasCenter()
|
||||
|
||||
const node = LiteGraph.createNode(
|
||||
provider.nodeDef.name,
|
||||
provider.nodeDef.display_name,
|
||||
{ pos }
|
||||
)
|
||||
|
||||
if (!node) {
|
||||
console.error(`Failed to create node for type: ${provider.nodeDef.name}`)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'NODE_CREATION_FAILED',
|
||||
message: `Failed to create node for type: ${provider.nodeDef.name}`,
|
||||
assetId: validAsset.id,
|
||||
details: { nodeType: provider.nodeDef.name }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const workflowStore = useWorkflowStore()
|
||||
const targetGraph = workflowStore.isSubgraphActive
|
||||
? workflowStore.activeSubgraph
|
||||
: app.canvas.graph
|
||||
|
||||
if (!targetGraph) {
|
||||
console.error('No active graph available')
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'NO_GRAPH',
|
||||
message: 'No active graph available',
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set widget value if provider specifies a key (some nodes auto-load models without a widget)
|
||||
if (provider.key) {
|
||||
const widget = node.widgets?.find((w) => w.name === provider.key)
|
||||
if (!widget) {
|
||||
console.error(
|
||||
`Widget ${provider.key} not found on node ${provider.nodeDef.name}`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'MISSING_WIDGET',
|
||||
message: `Widget ${provider.key} not found on node ${provider.nodeDef.name}`,
|
||||
assetId: validAsset.id,
|
||||
details: { widgetName: provider.key, nodeType: provider.nodeDef.name }
|
||||
}
|
||||
}
|
||||
}
|
||||
widget.value = filename
|
||||
}
|
||||
|
||||
// Add the node to the graph
|
||||
targetGraph.add(node)
|
||||
|
||||
return { success: true, value: node }
|
||||
}
|
||||
@@ -1,208 +0,0 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import { resolveModelNodeFromAsset } from '@/platform/assets/utils/resolveModelNodeFromAsset'
|
||||
|
||||
const mockGetNodeProvider = vi.hoisted(() => vi.fn())
|
||||
|
||||
vi.mock('@/stores/modelToNodeStore', () => ({
|
||||
useModelToNodeStore: () => ({ getNodeProvider: mockGetNodeProvider })
|
||||
}))
|
||||
|
||||
function createMockAsset(overrides: Partial<AssetItem> = {}): AssetItem {
|
||||
return {
|
||||
id: 'asset-123',
|
||||
name: 'test-model.safetensors',
|
||||
size: 1024,
|
||||
created_at: '2025-10-01T00:00:00Z',
|
||||
tags: ['models', 'checkpoints'],
|
||||
user_metadata: {
|
||||
filename: 'models/checkpoints/test-model.safetensors'
|
||||
},
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function createMockNodeProvider(
|
||||
overrides: {
|
||||
nodeDef?: { name: string; display_name: string }
|
||||
key?: string
|
||||
} = {}
|
||||
) {
|
||||
return {
|
||||
nodeDef: {
|
||||
name: 'CheckpointLoaderSimple',
|
||||
display_name: 'Load Checkpoint',
|
||||
...overrides.nodeDef
|
||||
},
|
||||
key: overrides.key ?? 'ckpt_name'
|
||||
}
|
||||
}
|
||||
|
||||
function mockProvider(
|
||||
provider: ReturnType<typeof createMockNodeProvider> | null
|
||||
) {
|
||||
mockGetNodeProvider.mockReturnValue(provider)
|
||||
}
|
||||
|
||||
describe('resolveModelNodeFromAsset', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
})
|
||||
|
||||
describe('valid assets', () => {
|
||||
it('resolves the provider for the asset category and the filename', () => {
|
||||
mockProvider(createMockNodeProvider())
|
||||
const result = resolveModelNodeFromAsset(createMockAsset())
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(mockGetNodeProvider).toHaveBeenCalledWith('checkpoints')
|
||||
expect(result.value.provider.nodeDef.name).toBe(
|
||||
'CheckpointLoaderSimple'
|
||||
)
|
||||
expect(result.value.filename).toBe(
|
||||
'models/checkpoints/test-model.safetensors'
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('falls back to metadata.filename when user_metadata.filename missing', () => {
|
||||
mockProvider(createMockNodeProvider())
|
||||
const result = resolveModelNodeFromAsset(
|
||||
createMockAsset({
|
||||
user_metadata: {},
|
||||
metadata: { filename: 'models/checkpoints/from-metadata.safetensors' }
|
||||
})
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.value.filename).toBe(
|
||||
'models/checkpoints/from-metadata.safetensors'
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('falls back to asset.name when both filename sources missing', () => {
|
||||
mockProvider(createMockNodeProvider())
|
||||
const result = resolveModelNodeFromAsset(
|
||||
createMockAsset({ user_metadata: {}, metadata: undefined })
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.value.filename).toBe('test-model.safetensors')
|
||||
}
|
||||
})
|
||||
|
||||
it('resolves a provider with an empty key (auto-load nodes)', () => {
|
||||
mockProvider(
|
||||
createMockNodeProvider({
|
||||
nodeDef: {
|
||||
name: 'FL_ChatterboxVC',
|
||||
display_name: 'FL Chatterbox VC'
|
||||
},
|
||||
key: ''
|
||||
})
|
||||
)
|
||||
const result = resolveModelNodeFromAsset(
|
||||
createMockAsset({
|
||||
tags: ['models', 'chatterbox/chatterbox_vc'],
|
||||
user_metadata: { filename: 'chatterbox_vc_model.pt' }
|
||||
})
|
||||
)
|
||||
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.value.provider.key).toBe('')
|
||||
expect(result.value.filename).toBe('chatterbox_vc_model.pt')
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('invalid assets', () => {
|
||||
it('fails when the asset does not match the schema', () => {
|
||||
const invalid = {
|
||||
id: 'asset-123',
|
||||
tags: ['models', 'checkpoints']
|
||||
} as unknown as AssetItem
|
||||
|
||||
const result = resolveModelNodeFromAsset(invalid)
|
||||
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('INVALID_ASSET')
|
||||
expect(result.error.message).toBe('Asset schema validation failed')
|
||||
expect(result.error.assetId).toBe('asset-123')
|
||||
expect(result.error.details?.validationErrors).toBeTruthy()
|
||||
}
|
||||
})
|
||||
|
||||
it.for([
|
||||
{
|
||||
case: 'missing user_metadata with no fallback',
|
||||
overrides: {
|
||||
user_metadata: undefined,
|
||||
metadata: undefined,
|
||||
name: ''
|
||||
},
|
||||
errorPattern: /Invalid filename.*expected non-empty string/
|
||||
},
|
||||
{
|
||||
case: 'empty filename with no fallback',
|
||||
overrides: {
|
||||
user_metadata: { filename: '' },
|
||||
metadata: undefined,
|
||||
name: ''
|
||||
},
|
||||
errorPattern: /Invalid filename.*expected non-empty string/
|
||||
}
|
||||
])('fails when asset has $case', ({ overrides, errorPattern }) => {
|
||||
const result = resolveModelNodeFromAsset(createMockAsset(overrides))
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('INVALID_ASSET')
|
||||
expect(result.error.message).toMatch(errorPattern)
|
||||
expect(result.error.assetId).toBe('asset-123')
|
||||
}
|
||||
})
|
||||
|
||||
it.for([
|
||||
{
|
||||
case: 'no tags',
|
||||
overrides: { tags: undefined },
|
||||
message: 'Asset has no tags defined'
|
||||
},
|
||||
{
|
||||
case: 'only excluded tags',
|
||||
overrides: { tags: ['models', 'missing'] },
|
||||
message: 'Asset has no valid category tag'
|
||||
},
|
||||
{
|
||||
case: 'only the models tag',
|
||||
overrides: { tags: ['models'] },
|
||||
message: 'Asset has no valid category tag'
|
||||
}
|
||||
])('fails when asset has $case', ({ overrides, message }) => {
|
||||
const result = resolveModelNodeFromAsset(createMockAsset(overrides))
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('INVALID_ASSET')
|
||||
expect(result.error.message).toBe(message)
|
||||
}
|
||||
})
|
||||
|
||||
it('fails when no provider is registered for the category', () => {
|
||||
mockProvider(null)
|
||||
const result = resolveModelNodeFromAsset(createMockAsset())
|
||||
expect(result.success).toBe(false)
|
||||
if (!result.success) {
|
||||
expect(result.error.code).toBe('NO_PROVIDER')
|
||||
expect(result.error.message).toContain('checkpoints')
|
||||
expect(result.error.details?.category).toBe('checkpoints')
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,117 +0,0 @@
|
||||
import { assetItemSchema } from '@/platform/assets/schemas/assetSchema'
|
||||
import type { AssetItem } from '@/platform/assets/schemas/assetSchema'
|
||||
import {
|
||||
MISSING_TAG,
|
||||
MODELS_TAG
|
||||
} from '@/platform/assets/services/assetService'
|
||||
import { getAssetFilename } from '@/platform/assets/utils/assetMetadataUtils'
|
||||
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
|
||||
import type { ModelNodeProvider } from '@/stores/modelToNodeStore'
|
||||
|
||||
type ResolveErrorCode = 'INVALID_ASSET' | 'NO_PROVIDER'
|
||||
|
||||
export interface ResolveModelNodeError {
|
||||
code: ResolveErrorCode
|
||||
message: string
|
||||
assetId: string
|
||||
details?: Record<string, unknown>
|
||||
}
|
||||
|
||||
export interface ResolvedModelNode {
|
||||
provider: ModelNodeProvider
|
||||
filename: string
|
||||
}
|
||||
|
||||
type Result<T, E> = { success: true; value: T } | { success: false; error: E }
|
||||
|
||||
/**
|
||||
* Resolves an asset item to the node provider and filename needed to add a
|
||||
* model loader node. Validation failures return error results rather than
|
||||
* throwing, so callers can degrade gracefully in UI contexts.
|
||||
*/
|
||||
export function resolveModelNodeFromAsset(
|
||||
asset: AssetItem
|
||||
): Result<ResolvedModelNode, ResolveModelNodeError> {
|
||||
const validatedAsset = assetItemSchema.safeParse(asset)
|
||||
|
||||
if (!validatedAsset.success) {
|
||||
const errorMessage = validatedAsset.error.errors
|
||||
.map((e) => `${e.path.join('.')}: ${e.message}`)
|
||||
.join(', ')
|
||||
console.error('Invalid asset item:', errorMessage)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset schema validation failed',
|
||||
assetId: asset.id,
|
||||
details: { validationErrors: errorMessage }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const validAsset = validatedAsset.data
|
||||
|
||||
const filename = getAssetFilename(validAsset)
|
||||
if (filename.length === 0) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has invalid user_metadata.filename (expected non-empty string, got ${typeof filename})`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: `Invalid filename (expected non-empty string, got ${typeof filename})`,
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (validAsset.tags.length === 0) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has no tags defined (expected at least one category tag)`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset has no tags defined',
|
||||
assetId: validAsset.id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const category = validAsset.tags.find(
|
||||
(tag) => tag !== MODELS_TAG && tag !== MISSING_TAG
|
||||
)
|
||||
if (!category) {
|
||||
console.error(
|
||||
`Asset ${validAsset.id} has no valid category tag. Available tags: ${validAsset.tags.join(', ')} (expected tag other than '${MODELS_TAG}' or '${MISSING_TAG}')`
|
||||
)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'INVALID_ASSET',
|
||||
message: 'Asset has no valid category tag',
|
||||
assetId: validAsset.id,
|
||||
details: { availableTags: validAsset.tags }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const provider = useModelToNodeStore().getNodeProvider(category)
|
||||
if (!provider) {
|
||||
console.error(`No node provider registered for category: ${category}`)
|
||||
return {
|
||||
success: false,
|
||||
error: {
|
||||
code: 'NO_PROVIDER',
|
||||
message: `No node provider registered for category: ${category}`,
|
||||
assetId: validAsset.id,
|
||||
details: { category }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { success: true, value: { provider, filename } }
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { ref } from 'vue'
|
||||
import { createI18n } from 'vue-i18n'
|
||||
|
||||
import WorkspaceSwitcherPopover from './WorkspaceSwitcherPopover.vue'
|
||||
|
||||
vi.mock('@/platform/workspace/composables/useWorkspaceSwitch', () => ({
|
||||
useWorkspaceSwitch: () => ({ switchWorkspace: vi.fn() })
|
||||
}))
|
||||
|
||||
vi.mock('@/composables/billing/useBillingContext', () => ({
|
||||
useBillingContext: () => ({ subscription: ref(null) })
|
||||
}))
|
||||
|
||||
const LONG_WORKSPACE_NAME =
|
||||
'Quantum Renaissance Collective for Hyperdimensional Latent Diffusion Research and Experimental Workflow Engineering'
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
locale: 'en',
|
||||
messages: {
|
||||
en: {
|
||||
workspaceSwitcher: {
|
||||
personal: 'Personal',
|
||||
roleOwner: 'Owner',
|
||||
roleMember: 'Member',
|
||||
createWorkspace: 'Create new workspace',
|
||||
maxWorkspacesReached:
|
||||
'You can only own 10 workspaces. Delete one to create a new one.'
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
function createWorkspaceState(overrides: Record<string, unknown>) {
|
||||
return {
|
||||
created_at: '2026-01-01T00:00:00Z',
|
||||
joined_at: '2026-01-01T00:00:00Z',
|
||||
isSubscribed: false,
|
||||
subscriptionPlan: null,
|
||||
subscriptionTier: null,
|
||||
members: [],
|
||||
pendingInvites: [],
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
function renderComponent() {
|
||||
return render(WorkspaceSwitcherPopover, {
|
||||
global: {
|
||||
plugins: [
|
||||
createTestingPinia({
|
||||
createSpy: vi.fn,
|
||||
initialState: {
|
||||
teamWorkspace: {
|
||||
activeWorkspaceId: 'ws-personal',
|
||||
isFetchingWorkspaces: false,
|
||||
workspaces: [
|
||||
createWorkspaceState({
|
||||
id: 'ws-personal',
|
||||
name: 'Personal Workspace',
|
||||
type: 'personal',
|
||||
role: 'owner'
|
||||
}),
|
||||
createWorkspaceState({
|
||||
id: 'ws-team-long',
|
||||
name: LONG_WORKSPACE_NAME,
|
||||
type: 'team',
|
||||
role: 'member'
|
||||
})
|
||||
]
|
||||
}
|
||||
}
|
||||
}),
|
||||
i18n
|
||||
],
|
||||
stubs: {
|
||||
WorkspaceProfilePic: true
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
describe('WorkspaceSwitcherPopover', () => {
|
||||
it('exposes the full team workspace name as a tooltip on the row', () => {
|
||||
renderComponent()
|
||||
|
||||
const name = screen.getByText(LONG_WORKSPACE_NAME)
|
||||
|
||||
expect(name).toHaveAttribute('title', LONG_WORKSPACE_NAME)
|
||||
})
|
||||
})
|
||||
@@ -34,21 +34,20 @@
|
||||
@click="handleSelectWorkspace(workspace)"
|
||||
>
|
||||
<WorkspaceProfilePic
|
||||
class="size-8 text-sm"
|
||||
class="size-8 shrink-0 text-sm"
|
||||
:workspace-name="workspace.name"
|
||||
/>
|
||||
<div class="flex min-w-0 flex-1 flex-col items-start gap-1">
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-sm text-base-foreground">
|
||||
{{
|
||||
workspace.type === 'personal'
|
||||
? $t('workspaceSwitcher.personal')
|
||||
: workspace.name
|
||||
}}
|
||||
<div class="flex max-w-full items-center gap-1.5">
|
||||
<span
|
||||
:title="getDisplayName(workspace)"
|
||||
class="truncate text-sm text-base-foreground"
|
||||
>
|
||||
{{ getDisplayName(workspace) }}
|
||||
</span>
|
||||
<span
|
||||
v-if="resolveTierLabel(workspace)"
|
||||
class="rounded-full bg-base-foreground px-1 py-0.5 text-2xs font-bold text-base-background uppercase"
|
||||
class="shrink-0 rounded-full bg-base-foreground px-1 py-0.5 text-2xs font-bold text-base-background uppercase"
|
||||
>
|
||||
{{ resolveTierLabel(workspace) }}
|
||||
</span>
|
||||
@@ -59,7 +58,7 @@
|
||||
</div>
|
||||
<i
|
||||
v-if="isCurrentWorkspace(workspace)"
|
||||
class="pi pi-check text-sm text-base-foreground"
|
||||
class="pi pi-check shrink-0 text-sm text-base-foreground"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
@@ -171,6 +170,12 @@ function isCurrentWorkspace(workspace: AvailableWorkspace): boolean {
|
||||
return workspace.id === workspaceId.value
|
||||
}
|
||||
|
||||
function getDisplayName(workspace: AvailableWorkspace): string {
|
||||
return workspace.type === 'personal'
|
||||
? t('workspaceSwitcher.personal')
|
||||
: workspace.name
|
||||
}
|
||||
|
||||
function getRoleLabel(role: AvailableWorkspace['role']): string {
|
||||
if (role === 'owner') return t('workspaceSwitcher.roleOwner')
|
||||
if (role === 'member') return t('workspaceSwitcher.roleMember')
|
||||
|
||||
@@ -33,6 +33,10 @@ const WorkspaceTokenResponseSchema = z.object({
|
||||
permissions: z.array(z.string())
|
||||
})
|
||||
|
||||
export type WorkspaceTokenResponse = z.infer<
|
||||
typeof WorkspaceTokenResponseSchema
|
||||
>
|
||||
|
||||
export class WorkspaceAuthError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
import { render, screen } from '@testing-library/vue'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import type { ComfyNodeDef as ComfyNodeDefV2 } from '@/schemas/nodeDef/nodeDefSchemaV2'
|
||||
import LGraphNodePreview from '@/renderer/extensions/vueNodes/components/LGraphNodePreview.vue'
|
||||
import { fromPartial } from '@total-typescript/shoehorn'
|
||||
|
||||
vi.mock('@/stores/widgetStore', () => ({
|
||||
useWidgetStore: () => ({ inputIsWidget: () => true })
|
||||
}))
|
||||
|
||||
// Serializes the nodeData prop so tests can assert on the data contract
|
||||
// LGraphNodePreview hands to NodeWidgets. How that data renders is covered
|
||||
// by NodeWidgets.test.ts and browser_tests/tests/sidebar/modelLibrary.spec.ts.
|
||||
const NodeWidgetsProbe = {
|
||||
props: ['nodeData'],
|
||||
template: '<div data-testid="node-data">{{ JSON.stringify(nodeData) }}</div>'
|
||||
}
|
||||
|
||||
interface ProbedWidget {
|
||||
name: string
|
||||
options?: { values?: string[] }
|
||||
}
|
||||
|
||||
const nodeDef = fromPartial<ComfyNodeDefV2>({
|
||||
name: 'CheckpointLoaderSimple',
|
||||
display_name: 'Load Checkpoint',
|
||||
inputs: {
|
||||
ckpt_name: { type: 'COMBO', options: ['a.safetensors', 'b.safetensors'] }
|
||||
},
|
||||
outputs: []
|
||||
})
|
||||
|
||||
function renderedComboWidget(
|
||||
props: { widgetValues?: Record<string, string> } = {}
|
||||
) {
|
||||
render(LGraphNodePreview, {
|
||||
props: { nodeDef, ...props },
|
||||
global: {
|
||||
stubs: {
|
||||
NodeHeader: true,
|
||||
NodeSlots: true,
|
||||
NodeWidgets: NodeWidgetsProbe
|
||||
}
|
||||
}
|
||||
})
|
||||
const nodeData: { widgets?: ProbedWidget[] } = JSON.parse(
|
||||
screen.getByTestId('node-data').textContent ?? ''
|
||||
)
|
||||
return nodeData.widgets?.find((w) => w.name === 'ckpt_name')
|
||||
}
|
||||
|
||||
describe('LGraphNodePreview', () => {
|
||||
it('leads the combo options with the provided widget value', () => {
|
||||
const widget = renderedComboWidget({
|
||||
widgetValues: { ckpt_name: 'sd_xl_base_1.0.safetensors' }
|
||||
})
|
||||
|
||||
expect(widget?.options?.values).toEqual([
|
||||
'sd_xl_base_1.0.safetensors',
|
||||
'a.safetensors',
|
||||
'b.safetensors'
|
||||
])
|
||||
})
|
||||
|
||||
it('keeps the combo options untouched when no value is provided', () => {
|
||||
const widget = renderedComboWidget()
|
||||
|
||||
expect(widget?.options?.values).toEqual(['a.safetensors', 'b.safetensors'])
|
||||
})
|
||||
})
|
||||
@@ -45,14 +45,9 @@ import type { ComfyNodeDef as ComfyNodeDefV2 } from '@/schemas/nodeDef/nodeDefSc
|
||||
import { useWidgetStore } from '@/stores/widgetStore'
|
||||
import { cn } from '@comfyorg/tailwind-utils'
|
||||
|
||||
const {
|
||||
nodeDef,
|
||||
position = 'absolute',
|
||||
widgetValues
|
||||
} = defineProps<{
|
||||
const { nodeDef, position = 'absolute' } = defineProps<{
|
||||
nodeDef: ComfyNodeDefV2
|
||||
position?: 'absolute' | 'relative'
|
||||
widgetValues?: Record<string, string>
|
||||
}>()
|
||||
|
||||
const widgetStore = useWidgetStore()
|
||||
@@ -61,32 +56,27 @@ const widgetStore = useWidgetStore()
|
||||
const nodeData = computed<VueNodeData>(() => {
|
||||
const widgets = Object.entries(nodeDef.inputs || {})
|
||||
.filter(([_, input]) => widgetStore.inputIsWidget(input))
|
||||
.map(([name, input]) => {
|
||||
const comboValues =
|
||||
input.type === 'COMBO' && Array.isArray(input.options)
|
||||
? input.options
|
||||
: undefined
|
||||
// Preview nodes have no widget-value store entry, so combo widgets
|
||||
// render their first option; lead with the requested value to show it.
|
||||
const leadValue = widgetValues?.[name]
|
||||
return {
|
||||
nodeId: '-1',
|
||||
name,
|
||||
type: input.widgetType || input.type,
|
||||
value:
|
||||
input.default !== undefined
|
||||
? input.default
|
||||
: (comboValues?.[0] ?? ''),
|
||||
options: {
|
||||
hidden: input.hidden,
|
||||
advanced: input.advanced,
|
||||
values:
|
||||
leadValue && comboValues
|
||||
? [leadValue, ...comboValues.filter((o) => o !== leadValue)]
|
||||
: comboValues
|
||||
} satisfies IWidgetOptions
|
||||
}
|
||||
})
|
||||
.map(([name, input]) => ({
|
||||
nodeId: '-1',
|
||||
name,
|
||||
type: input.widgetType || input.type,
|
||||
value:
|
||||
input.default !== undefined
|
||||
? input.default
|
||||
: input.type === 'COMBO' &&
|
||||
Array.isArray(input.options) &&
|
||||
input.options.length > 0
|
||||
? input.options[0]
|
||||
: '',
|
||||
options: {
|
||||
hidden: input.hidden,
|
||||
advanced: input.advanced,
|
||||
values:
|
||||
input.type === 'COMBO' && Array.isArray(input.options)
|
||||
? input.options
|
||||
: undefined
|
||||
} satisfies IWidgetOptions
|
||||
}))
|
||||
|
||||
const inputs: INodeInputSlot[] = Object.entries(nodeDef.inputs || {})
|
||||
.filter(([_, input]) => !widgetStore.inputIsWidget(input))
|
||||
|
||||
@@ -51,6 +51,9 @@ const AudioPreviewPlayer = defineAsyncComponent(
|
||||
const Load3D = defineAsyncComponent(
|
||||
() => import('@/components/load3d/Load3D.vue')
|
||||
)
|
||||
const Load3DAdvanced = defineAsyncComponent(
|
||||
() => import('@/components/load3d/Load3DAdvanced.vue')
|
||||
)
|
||||
const WidgetImageCrop = defineAsyncComponent(
|
||||
() => import('@/components/imagecrop/WidgetImageCrop.vue')
|
||||
)
|
||||
@@ -169,6 +172,14 @@ const coreWidgetDefinitions: Array<[string, WidgetDefinition]> = [
|
||||
}
|
||||
],
|
||||
['load3D', { component: Load3D, aliases: ['LOAD_3D'], essential: false }],
|
||||
[
|
||||
'load3DAdvanced',
|
||||
{
|
||||
component: Load3DAdvanced,
|
||||
aliases: ['LOAD_3D_ADVANCED'],
|
||||
essential: false
|
||||
}
|
||||
],
|
||||
[
|
||||
'imagecrop',
|
||||
{
|
||||
@@ -243,6 +254,7 @@ const EXPANDING_TYPES = [
|
||||
'textarea',
|
||||
'markdown',
|
||||
'load3D',
|
||||
'load3DAdvanced',
|
||||
'curve',
|
||||
'painter',
|
||||
'imagecompare',
|
||||
|
||||
Reference in New Issue
Block a user