diff --git a/.gitignore b/.gitignore index 5473190ea..32e1b6624 100644 --- a/.gitignore +++ b/.gitignore @@ -44,6 +44,7 @@ components.d.ts tests-ui/data/* tests-ui/ComfyUI_examples tests-ui/workflows/examples +coverage/ # Browser tests /test-results/ diff --git a/CODEOWNERS b/CODEOWNERS index 8d4e4a90f..cd1b4e508 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,17 +1,61 @@ -# Admins -* @Comfy-Org/comfy_frontend_devs +# Desktop/Electron +/src/types/desktop/ @webfiltered +/src/constants/desktopDialogs.ts @webfiltered +/src/constants/desktopMaintenanceTasks.ts @webfiltered +/src/stores/electronDownloadStore.ts @webfiltered +/src/extensions/core/electronAdapter.ts @webfiltered +/src/views/DesktopDialogView.vue @webfiltered +/src/components/install/ @webfiltered +/src/components/maintenance/ @webfiltered +/vite.electron.config.mts @webfiltered -# Maintainers -*.md @Comfy-Org/comfy_maintainer -/tests-ui/ @Comfy-Org/comfy_maintainer -/browser_tests/ @Comfy-Org/comfy_maintainer -/.env_example @Comfy-Org/comfy_maintainer +# Common UI Components +/src/components/chip/ @viva-jinyi +/src/components/card/ @viva-jinyi +/src/components/button/ @viva-jinyi +/src/components/input/ @viva-jinyi -# Translations (AIGODLIKE team + shinshin86) -/src/locales/ @Yorha4D @KarryCharon @DorotaLuna @shinshin86 @Comfy-Org/comfy_maintainer +# Topbar +/src/components/topbar/ @pythongosssss -# Load 3D extension -/src/extensions/core/load3d.ts @jtydhr88 @Comfy-Org/comfy_frontend_devs +# Thumbnail +/src/renderer/core/thumbnail/ @pythongosssss -# Mask Editor extension -/src/extensions/core/maskeditor.ts @brucew4yn3rp @trsommer @Comfy-Org/comfy_frontend_devs +# Legacy UI +/scripts/ui/ @pythongosssss + +# Link rendering +/src/renderer/core/canvas/links/ @benceruleanlu + +# Node help system +/src/utils/nodeHelpUtil.ts @benceruleanlu +/src/stores/workspace/nodeHelpStore.ts @benceruleanlu +/src/services/nodeHelpService.ts @benceruleanlu + +# Selection toolbox +/src/components/graph/selectionToolbox/ @Myestery + +# Minimap +/src/renderer/extensions/minimap/ @jtydhr88 + +# Assets +/src/platform/assets/ @arjansingh + +# Workflow Templates +/src/platform/workflow/templates/ @Myestery @christian-byrne @comfyui-wiki +/src/components/templates/ @Myestery @christian-byrne @comfyui-wiki + +# Mask Editor +/src/extensions/core/maskeditor.ts @trsommer @brucew4yn3rp +/src/extensions/core/maskEditorLayerFilenames.ts @trsommer @brucew4yn3rp +/src/extensions/core/maskEditorOld.ts @trsommer @brucew4yn3rp + +# 3D +/src/extensions/core/load3d.ts @jtydhr88 +/src/components/load3d/ @jtydhr88 + +# Manager +/src/workbench/extensions/manager/ @viva-jinyi @christian-byrne @ltdrdata + +# Translations +/src/locales/ @Yorha4D @KarryCharon @shinshin86 @Comfy-Org/comfy_maintainer diff --git a/browser_tests/tests/vueNodes/nodeInteractions/selectionState.spec.ts b/browser_tests/tests/vueNodes/nodeInteractions/selectionState.spec.ts new file mode 100644 index 000000000..ff8b6f951 --- /dev/null +++ b/browser_tests/tests/vueNodes/nodeInteractions/selectionState.spec.ts @@ -0,0 +1,47 @@ +import { + comfyExpect as expect, + comfyPageFixture as test +} from '../../../fixtures/ComfyPage' + +test.describe('Vue Node Selection', () => { + test.beforeEach(async ({ comfyPage }) => { + await comfyPage.setSetting('Comfy.VueNodes.Enabled', true) + await comfyPage.vueNodes.waitForNodes() + }) + + const modifiers = [ + { key: 'Control', name: 'ctrl' }, + { key: 'Shift', name: 'shift' } + ] as const + + for (const { key: modifier, name } of modifiers) { + test(`should allow selecting multiple nodes with ${name}+click`, async ({ + comfyPage + }) => { + await comfyPage.page.getByText('Load Checkpoint').click() + expect(await comfyPage.vueNodes.getSelectedNodeCount()).toBe(1) + + await comfyPage.page.getByText('Empty Latent Image').click({ + modifiers: [modifier] + }) + expect(await comfyPage.vueNodes.getSelectedNodeCount()).toBe(2) + + await comfyPage.page.getByText('KSampler').click({ + modifiers: [modifier] + }) + expect(await comfyPage.vueNodes.getSelectedNodeCount()).toBe(3) + }) + + test(`should allow de-selecting nodes with ${name}+click`, async ({ + comfyPage + }) => { + await comfyPage.page.getByText('Load Checkpoint').click() + expect(await comfyPage.vueNodes.getSelectedNodeCount()).toBe(1) + + await comfyPage.page.getByText('Load Checkpoint').click({ + modifiers: [modifier] + }) + expect(await comfyPage.vueNodes.getSelectedNodeCount()).toBe(0) + }) + } +}) diff --git a/browser_tests/tests/vueNodes/nodeStates/bypass.spec.ts b/browser_tests/tests/vueNodes/nodeStates/bypass.spec.ts new file mode 100644 index 000000000..c80a86503 --- /dev/null +++ b/browser_tests/tests/vueNodes/nodeStates/bypass.spec.ts @@ -0,0 +1,49 @@ +import { + comfyExpect as expect, + comfyPageFixture as test +} from '../../../fixtures/ComfyPage' + +const BYPASS_HOTKEY = 'Control+b' +const BYPASS_CLASS = /before:bg-bypass\/60/ + +test.describe('Vue Node Bypass', () => { + test.beforeEach(async ({ comfyPage }) => { + await comfyPage.setSetting('Comfy.VueNodes.Enabled', true) + await comfyPage.vueNodes.waitForNodes() + }) + + test('should allow toggling bypass on a selected node with hotkey', async ({ + comfyPage + }) => { + const checkpointNode = comfyPage.page.locator('[data-node-id]').filter({ + hasText: 'Load Checkpoint' + }) + await checkpointNode.getByText('Load Checkpoint').click() + await comfyPage.page.keyboard.press(BYPASS_HOTKEY) + await expect(checkpointNode).toHaveClass(BYPASS_CLASS) + + await comfyPage.page.keyboard.press(BYPASS_HOTKEY) + await expect(checkpointNode).not.toHaveClass(BYPASS_CLASS) + }) + + test('should allow toggling bypass on multiple selected nodes with hotkey', async ({ + comfyPage + }) => { + const checkpointNode = comfyPage.page.locator('[data-node-id]').filter({ + hasText: 'Load Checkpoint' + }) + const ksamplerNode = comfyPage.page.locator('[data-node-id]').filter({ + hasText: 'KSampler' + }) + + await checkpointNode.getByText('Load Checkpoint').click() + await ksamplerNode.getByText('KSampler').click({ modifiers: ['Control'] }) + await comfyPage.page.keyboard.press(BYPASS_HOTKEY) + await expect(checkpointNode).toHaveClass(BYPASS_CLASS) + await expect(ksamplerNode).toHaveClass(BYPASS_CLASS) + + await comfyPage.page.keyboard.press(BYPASS_HOTKEY) + await expect(checkpointNode).not.toHaveClass(BYPASS_CLASS) + await expect(ksamplerNode).not.toHaveClass(BYPASS_CLASS) + }) +}) diff --git a/eslint.config.ts b/eslint.config.ts index 94f8bb5f2..04f4b2578 100644 --- a/eslint.config.ts +++ b/eslint.config.ts @@ -77,12 +77,25 @@ export default defineConfig([ '@typescript-eslint/prefer-as-const': 'off', '@typescript-eslint/consistent-type-imports': 'error', '@typescript-eslint/no-import-type-side-effects': 'error', + '@typescript-eslint/no-empty-object-type': [ + 'error', + { + allowInterfaces: 'always' + } + ], 'unused-imports/no-unused-imports': 'error', 'vue/no-v-html': 'off', // Enforce dark-theme: instead of dark: prefix 'vue/no-restricted-class': ['error', '/^dark:/'], 'vue/multi-word-component-names': 'off', // TODO: fix 'vue/no-template-shadow': 'off', // TODO: fix + /* Toggle on to do additional until we can clean up existing violations. + 'vue/no-unused-emit-declarations': 'error', + 'vue/no-unused-properties': 'error', + 'vue/no-unused-refs': 'error', + 'vue/no-use-v-else-with-v-for': 'error', + 'vue/no-useless-v-bind': 'error', + // */ 'vue/one-component-per-file': 'off', // TODO: fix 'vue/require-default-prop': 'off', // TODO: fix -- this one is very worthwhile // Restrict deprecated PrimeVue components diff --git a/index.html b/index.html index de7710c63..8684af476 100644 --- a/index.html +++ b/index.html @@ -8,8 +8,8 @@ - - + + diff --git a/package.json b/package.json index 770ef7e04..923f04b7e 100644 --- a/package.json +++ b/package.json @@ -27,6 +27,8 @@ "preview": "nx preview", "lint": "eslint src --cache", "lint:fix": "eslint src --cache --fix", + "lint:unstaged": "git diff --name-only HEAD | grep -E '\\.(js|ts|vue|mts)$' | xargs -r eslint --cache", + "lint:unstaged:fix": "git diff --name-only HEAD | grep -E '\\.(js|ts|vue|mts)$' | xargs -r eslint --cache --fix", "lint:no-cache": "eslint src", "lint:fix:no-cache": "eslint src --fix", "knip": "knip --cache", @@ -94,6 +96,7 @@ "vite-plugin-html": "^3.2.2", "vite-plugin-vue-devtools": "^7.7.6", "vitest": "^3.2.4", + "vue-component-type-helpers": "^3.0.7", "vue-eslint-parser": "^10.2.0", "vue-tsc": "^3.0.7", "zip-dir": "^2.0.0", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 75fc42327..6ce1f4d34 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -339,6 +339,9 @@ importers: vitest: specifier: ^3.2.4 version: 3.2.4(@types/debug@4.1.12)(@types/node@20.14.10)(@vitest/ui@3.2.4)(happy-dom@15.11.0)(jsdom@26.1.0)(lightningcss@1.30.1)(terser@5.39.2) + vue-component-type-helpers: + specifier: ^3.0.7 + version: 3.0.7 vue-eslint-parser: specifier: ^10.2.0 version: 10.2.0(eslint@9.35.0(jiti@2.4.2)) diff --git a/src/base/common/async.ts b/src/base/common/async.ts new file mode 100644 index 000000000..a97f6f1bd --- /dev/null +++ b/src/base/common/async.ts @@ -0,0 +1,98 @@ +/** + * Cross-browser async utilities for scheduling tasks during browser idle time + * with proper fallbacks for browsers that don't support requestIdleCallback. + * + * Implementation based on: + * https://github.com/microsoft/vscode/blob/main/src/vs/base/common/async.ts + */ + +interface IdleDeadline { + didTimeout: boolean + timeRemaining(): number +} + +interface IDisposable { + dispose(): void +} + +/** + * Internal implementation function that handles the actual scheduling logic. + * Uses feature detection to determine whether to use native requestIdleCallback + * or fall back to setTimeout-based implementation. + */ +let _runWhenIdle: ( + targetWindow: any, + callback: (idle: IdleDeadline) => void, + timeout?: number +) => IDisposable + +/** + * Execute the callback during the next browser idle period. + * Falls back to setTimeout-based scheduling in browsers without native support. + */ +export let runWhenGlobalIdle: ( + callback: (idle: IdleDeadline) => void, + timeout?: number + ) => IDisposable + + // Self-invoking function to set up the idle callback implementation +;(function () { + const safeGlobal: any = globalThis + + if ( + typeof safeGlobal.requestIdleCallback !== 'function' || + typeof safeGlobal.cancelIdleCallback !== 'function' + ) { + // Fallback implementation for browsers without native support (e.g., Safari) + _runWhenIdle = (_targetWindow, runner, _timeout?) => { + setTimeout(() => { + if (disposed) { + return + } + + // Simulate IdleDeadline - give 15ms window (one frame at ~64fps) + const end = Date.now() + 15 + const deadline: IdleDeadline = { + didTimeout: true, + timeRemaining() { + return Math.max(0, end - Date.now()) + } + } + + runner(Object.freeze(deadline)) + }) + + let disposed = false + return { + dispose() { + if (disposed) { + return + } + disposed = true + } + } + } + } else { + // Native requestIdleCallback implementation + _runWhenIdle = (targetWindow: typeof safeGlobal, runner, timeout?) => { + const handle: number = targetWindow.requestIdleCallback( + runner, + typeof timeout === 'number' ? { timeout } : undefined + ) + + let disposed = false + return { + dispose() { + if (disposed) { + return + } + disposed = true + targetWindow.cancelIdleCallback(handle) + } + } + } + } + + runWhenGlobalIdle = (runner, timeout) => + _runWhenIdle(globalThis, runner, timeout) +})() diff --git a/src/components/graph/GraphCanvas.vue b/src/components/graph/GraphCanvas.vue index 2bbb45a87..50f7a4b3d 100644 --- a/src/components/graph/GraphCanvas.vue +++ b/src/components/graph/GraphCanvas.vue @@ -29,7 +29,7 @@ /> @@ -75,7 +72,6 @@ import { nextTick, onMounted, onUnmounted, - provide, ref, shallowRef, watch, @@ -113,14 +109,11 @@ import { useWorkflowStore } from '@/platform/workflow/management/stores/workflow import { useWorkflowAutoSave } from '@/platform/workflow/persistence/composables/useWorkflowAutoSave' import { useWorkflowPersistence } from '@/platform/workflow/persistence/composables/useWorkflowPersistence' import { useCanvasStore } from '@/renderer/core/canvas/canvasStore' -import { SelectedNodeIdsKey } from '@/renderer/core/canvas/injectionKeys' import { attachSlotLinkPreviewRenderer } from '@/renderer/core/canvas/links/slotLinkPreviewRenderer' import { useCanvasInteractions } from '@/renderer/core/canvas/useCanvasInteractions' import TransformPane from '@/renderer/core/layout/transform/TransformPane.vue' import MiniMap from '@/renderer/extensions/minimap/MiniMap.vue' import VueGraphNode from '@/renderer/extensions/vueNodes/components/LGraphNode.vue' -import { useNodeEventHandlers } from '@/renderer/extensions/vueNodes/composables/useNodeEventHandlers' -import { useExecutionStateProvider } from '@/renderer/extensions/vueNodes/execution/useExecutionStateProvider' import { UnauthorizedError, api } from '@/scripts/api' import { app as comfyApp } from '@/scripts/app' import { ChangeTracker } from '@/scripts/changeTracker' @@ -167,20 +160,13 @@ const minimapEnabled = computed(() => settingStore.get('Comfy.Minimap.Visible')) // Feature flags const { shouldRenderVueNodes } = useVueFeatureFlags() -const isVueNodesEnabled = computed(() => shouldRenderVueNodes.value) // Vue node system -const vueNodeLifecycle = useVueNodeLifecycle(isVueNodesEnabled) -const viewportCulling = useViewportCulling( - isVueNodesEnabled, - vueNodeLifecycle.vueNodeData, - vueNodeLifecycle.nodeDataTrigger, - vueNodeLifecycle.nodeManager -) -const nodeEventHandlers = useNodeEventHandlers(vueNodeLifecycle.nodeManager) +const vueNodeLifecycle = useVueNodeLifecycle() +const viewportCulling = useViewportCulling() const handleVueNodeLifecycleReset = async () => { - if (isVueNodesEnabled.value) { + if (shouldRenderVueNodes.value) { vueNodeLifecycle.disposeNodeManagerAndSyncs() await nextTick() vueNodeLifecycle.initializeNodeManager() @@ -208,23 +194,6 @@ const handleTransformUpdate = () => { // TODO: Fix paste position sync in separate PR vueNodeLifecycle.detectChangesInRAF.value() } -const handleNodeSelect = nodeEventHandlers.handleNodeSelect -const handleNodeCollapse = nodeEventHandlers.handleNodeCollapse -const handleNodeTitleUpdate = nodeEventHandlers.handleNodeTitleUpdate - -// Provide selection state to all Vue nodes -const selectedNodeIds = computed( - () => - new Set( - canvasStore.selectedItems - .filter((item) => item.id !== undefined) - .map((item) => String(item.id)) - ) -) -provide(SelectedNodeIdsKey, selectedNodeIds) - -// Provide execution state to all Vue nodes -useExecutionStateProvider() watchEffect(() => { nodeDefStore.showDeprecated = settingStore.get('Comfy.Node.ShowDeprecated') diff --git a/src/components/graph/SelectionToolbox.vue b/src/components/graph/SelectionToolbox.vue index 067b04346..f0a18ed3e 100644 --- a/src/components/graph/SelectionToolbox.vue +++ b/src/components/graph/SelectionToolbox.vue @@ -11,7 +11,7 @@ :style="`backgroundColor: ${containerStyles.backgroundColor};`" :pt="{ header: 'hidden', - content: 'px-1 py-1 h-10 px-1 flex flex-row gap-1' + content: 'p-1 h-10 flex flex-row gap-1' }" @wheel="canvasInteractions.handleWheel" > diff --git a/src/composables/auth/useCurrentUser.ts b/src/composables/auth/useCurrentUser.ts index 2c70be227..37b9e4866 100644 --- a/src/composables/auth/useCurrentUser.ts +++ b/src/composables/auth/useCurrentUser.ts @@ -1,4 +1,5 @@ -import { computed, watch } from 'vue' +import { whenever } from '@vueuse/core' +import { computed } from 'vue' import { useFirebaseAuthActions } from '@/composables/auth/useFirebaseAuthActions' import { t } from '@/i18n' @@ -33,19 +34,8 @@ export const useCurrentUser = () => { return null }) - const onUserResolved = (callback: (user: AuthUserInfo) => void) => { - if (resolvedUserInfo.value) { - callback(resolvedUserInfo.value) - } - - const stop = watch(resolvedUserInfo, (value) => { - if (value) { - callback(value) - } - }) - - return () => stop() - } + const onUserResolved = (callback: (user: AuthUserInfo) => void) => + whenever(resolvedUserInfo, callback, { immediate: true }) const userDisplayName = computed(() => { if (isApiKeyLogin.value) { diff --git a/src/composables/graph/useGraphNodeManager.ts b/src/composables/graph/useGraphNodeManager.ts index ae430987a..618b3087a 100644 --- a/src/composables/graph/useGraphNodeManager.ts +++ b/src/composables/graph/useGraphNodeManager.ts @@ -68,7 +68,7 @@ interface SpatialMetrics { nodesInIndex: number } -interface GraphNodeManager { +export interface GraphNodeManager { // Reactive state - safe data extracted from LiteGraph nodes vueNodeData: ReadonlyMap nodeState: ReadonlyMap diff --git a/src/composables/graph/useViewportCulling.ts b/src/composables/graph/useViewportCulling.ts index 6fc835e7e..f311af01c 100644 --- a/src/composables/graph/useViewportCulling.ts +++ b/src/composables/graph/useViewportCulling.ts @@ -6,26 +6,20 @@ * 2. Set display none on element to avoid cascade resolution overhead * 3. Only run when transform changes (event driven) */ -import { type Ref, computed } from 'vue' +import { computed } from 'vue' -import type { VueNodeData } from '@/composables/graph/useGraphNodeManager' +import { useVueNodeLifecycle } from '@/composables/graph/useVueNodeLifecycle' +import { useVueFeatureFlags } from '@/composables/useVueFeatureFlags' import { useCanvasStore } from '@/renderer/core/canvas/canvasStore' import { app as comfyApp } from '@/scripts/app' -interface NodeManager { - getNode: (id: string) => any -} - -export function useViewportCulling( - isVueNodesEnabled: Ref, - vueNodeData: Ref>, - nodeDataTrigger: Ref, - nodeManager: Ref -) { +export function useViewportCulling() { const canvasStore = useCanvasStore() + const { shouldRenderVueNodes } = useVueFeatureFlags() + const { vueNodeData, nodeDataTrigger, nodeManager } = useVueNodeLifecycle() const allNodes = computed(() => { - if (!isVueNodesEnabled.value) return [] + if (!shouldRenderVueNodes.value) return [] void nodeDataTrigger.value // Force re-evaluation when nodeManager initializes return Array.from(vueNodeData.value.values()) }) @@ -84,7 +78,7 @@ export function useViewportCulling( * Uses RAF to batch updates for smooth performance */ const handleTransformUpdate = () => { - if (!isVueNodesEnabled.value) return + if (!shouldRenderVueNodes.value) return // Cancel previous RAF if still pending if (rafId !== null) { diff --git a/src/composables/graph/useVueNodeLifecycle.ts b/src/composables/graph/useVueNodeLifecycle.ts index b481439dd..d2c1bcfcd 100644 --- a/src/composables/graph/useVueNodeLifecycle.ts +++ b/src/composables/graph/useVueNodeLifecycle.ts @@ -8,13 +8,16 @@ * - Reactive state management for node data, positions, and sizes * - Memory management and proper cleanup */ -import { type Ref, computed, readonly, ref, shallowRef, watch } from 'vue' +import { createSharedComposable } from '@vueuse/core' +import { computed, readonly, ref, shallowRef, watch } from 'vue' import { useGraphNodeManager } from '@/composables/graph/useGraphNodeManager' import type { + GraphNodeManager, NodeState, VueNodeData } from '@/composables/graph/useGraphNodeManager' +import { useVueFeatureFlags } from '@/composables/useVueFeatureFlags' import type { LGraphCanvas, LGraphNode } from '@/lib/litegraph/src/litegraph' import { useCanvasStore } from '@/renderer/core/canvas/canvasStore' import { useLayoutMutations } from '@/renderer/core/layout/operations/layoutMutations' @@ -24,13 +27,12 @@ import { useLinkLayoutSync } from '@/renderer/core/layout/sync/useLinkLayoutSync import { useSlotLayoutSync } from '@/renderer/core/layout/sync/useSlotLayoutSync' import { app as comfyApp } from '@/scripts/app' -export function useVueNodeLifecycle(isVueNodesEnabled: Ref) { +function useVueNodeLifecycleIndividual() { const canvasStore = useCanvasStore() const layoutMutations = useLayoutMutations() + const { shouldRenderVueNodes } = useVueFeatureFlags() - const nodeManager = shallowRef | null>( - null - ) + const nodeManager = shallowRef(null) const cleanupNodeManager = shallowRef<(() => void) | null>(null) // Sync management @@ -145,7 +147,7 @@ export function useVueNodeLifecycle(isVueNodesEnabled: Ref) { // Watch for Vue nodes enabled state changes watch( () => - isVueNodesEnabled.value && + shouldRenderVueNodes.value && Boolean(comfyApp.canvas?.graph || comfyApp.graph), (enabled) => { if (enabled) { @@ -159,7 +161,7 @@ export function useVueNodeLifecycle(isVueNodesEnabled: Ref) { // Consolidated watch for slot layout sync management watch( - [() => canvasStore.canvas, () => isVueNodesEnabled.value], + [() => canvasStore.canvas, () => shouldRenderVueNodes.value], ([canvas, vueMode], [, oldVueMode]) => { const modeChanged = vueMode !== oldVueMode @@ -191,7 +193,7 @@ export function useVueNodeLifecycle(isVueNodesEnabled: Ref) { // Handle case where Vue nodes are enabled but graph starts empty const setupEmptyGraphListener = () => { if ( - isVueNodesEnabled.value && + shouldRenderVueNodes.value && comfyApp.graph && !nodeManager.value && comfyApp.graph._nodes.length === 0 @@ -202,7 +204,7 @@ export function useVueNodeLifecycle(isVueNodesEnabled: Ref) { comfyApp.graph.onNodeAdded = originalOnNodeAdded // Initialize node manager if needed - if (isVueNodesEnabled.value && !nodeManager.value) { + if (shouldRenderVueNodes.value && !nodeManager.value) { initializeNodeManager() } @@ -248,3 +250,7 @@ export function useVueNodeLifecycle(isVueNodesEnabled: Ref) { cleanup } } + +export const useVueNodeLifecycle = createSharedComposable( + useVueNodeLifecycleIndividual +) diff --git a/src/composables/node/useNodePricing.ts b/src/composables/node/useNodePricing.ts index e85e6adb6..91f957463 100644 --- a/src/composables/node/useNodePricing.ts +++ b/src/composables/node/useNodePricing.ts @@ -1548,6 +1548,71 @@ const apiNodeCosts: Record = }, ByteDanceImageReferenceNode: { displayPrice: byteDanceVideoPricingCalculator + }, + WanTextToVideoApi: { + displayPrice: (node: LGraphNode): string => { + const durationWidget = node.widgets?.find( + (w) => w.name === 'duration' + ) as IComboWidget + const resolutionWidget = node.widgets?.find( + (w) => w.name === 'size' + ) as IComboWidget + + if (!durationWidget || !resolutionWidget) return '$0.05-0.15/second' + + const seconds = parseFloat(String(durationWidget.value)) + const resolutionStr = String(resolutionWidget.value).toLowerCase() + + const resKey = resolutionStr.includes('1080') + ? '1080p' + : resolutionStr.includes('720') + ? '720p' + : resolutionStr.includes('480') + ? '480p' + : resolutionStr.match(/^\s*(\d{3,4}p)/)?.[1] ?? '' + + const pricePerSecond: Record = { + '480p': 0.05, + '720p': 0.1, + '1080p': 0.15 + } + + const pps = pricePerSecond[resKey] + if (isNaN(seconds) || !pps) return '$0.05-0.15/second' + + const cost = (pps * seconds).toFixed(2) + return `$${cost}/Run` + } + }, + WanImageToVideoApi: { + displayPrice: (node: LGraphNode): string => { + const durationWidget = node.widgets?.find( + (w) => w.name === 'duration' + ) as IComboWidget + const resolutionWidget = node.widgets?.find( + (w) => w.name === 'resolution' + ) as IComboWidget + + if (!durationWidget || !resolutionWidget) return '$0.05-0.15/second' + + const seconds = parseFloat(String(durationWidget.value)) + const resolution = String(resolutionWidget.value).trim().toLowerCase() + + const pricePerSecond: Record = { + '480p': 0.05, + '720p': 0.1, + '1080p': 0.15 + } + + const pps = pricePerSecond[resolution] + if (isNaN(seconds) || !pps) return '$0.05-0.15/second' + + const cost = (pps * seconds).toFixed(2) + return `$${cost}/Run` + } + }, + WanTextToImageApi: { + displayPrice: '$0.03/Run' } } @@ -1647,7 +1712,9 @@ export const useNodePricing = () => { ByteDanceTextToVideoNode: ['model', 'duration', 'resolution'], ByteDanceImageToVideoNode: ['model', 'duration', 'resolution'], ByteDanceFirstLastFrameNode: ['model', 'duration', 'resolution'], - ByteDanceImageReferenceNode: ['model', 'duration', 'resolution'] + ByteDanceImageReferenceNode: ['model', 'duration', 'resolution'], + WanTextToVideoApi: ['duration', 'size'], + WanImageToVideoApi: ['duration', 'resolution'] } return widgetMap[nodeType] || [] } diff --git a/src/composables/useVueFeatureFlags.ts b/src/composables/useVueFeatureFlags.ts index 87836c221..f863fc187 100644 --- a/src/composables/useVueFeatureFlags.ts +++ b/src/composables/useVueFeatureFlags.ts @@ -2,16 +2,17 @@ * Vue-related feature flags composable * Manages local settings-driven flags and LiteGraph integration */ +import { createSharedComposable } from '@vueuse/core' import { computed, watch } from 'vue' import { useSettingStore } from '@/platform/settings/settingStore' import { LiteGraph } from '../lib/litegraph/src/litegraph' -export const useVueFeatureFlags = () => { +function useVueFeatureFlagsIndividual() { const settingStore = useSettingStore() - const isVueNodesEnabled = computed(() => { + const shouldRenderVueNodes = computed(() => { try { return settingStore.get('Comfy.VueNodes.Enabled') ?? false } catch { @@ -19,20 +20,20 @@ export const useVueFeatureFlags = () => { } }) - // Whether Vue nodes should render - const shouldRenderVueNodes = computed(() => isVueNodesEnabled.value) - - // Sync the Vue nodes flag with LiteGraph global settings - const syncVueNodesFlag = () => { - LiteGraph.vueNodesMode = isVueNodesEnabled.value - } - // Watch for changes and update LiteGraph immediately - watch(isVueNodesEnabled, syncVueNodesFlag, { immediate: true }) + watch( + shouldRenderVueNodes, + () => { + LiteGraph.vueNodesMode = shouldRenderVueNodes.value + }, + { immediate: true } + ) return { - isVueNodesEnabled, - shouldRenderVueNodes, - syncVueNodesFlag + shouldRenderVueNodes } } + +export const useVueFeatureFlags = createSharedComposable( + useVueFeatureFlagsIndividual +) diff --git a/src/constants/coreKeybindings.ts b/src/constants/coreKeybindings.ts index b4245f789..fe2bde835 100644 --- a/src/constants/coreKeybindings.ts +++ b/src/constants/coreKeybindings.ts @@ -122,14 +122,14 @@ export const CORE_KEYBINDINGS: Keybinding[] = [ key: '.' }, commandId: 'Comfy.Canvas.FitView', - targetElementId: 'graph-canvas' + targetElementId: 'graph-canvas-container' }, { combo: { key: 'p' }, commandId: 'Comfy.Canvas.ToggleSelected.Pin', - targetElementId: 'graph-canvas' + targetElementId: 'graph-canvas-container' }, { combo: { @@ -137,7 +137,7 @@ export const CORE_KEYBINDINGS: Keybinding[] = [ alt: true }, commandId: 'Comfy.Canvas.ToggleSelectedNodes.Collapse', - targetElementId: 'graph-canvas' + targetElementId: 'graph-canvas-container' }, { combo: { @@ -145,7 +145,7 @@ export const CORE_KEYBINDINGS: Keybinding[] = [ ctrl: true }, commandId: 'Comfy.Canvas.ToggleSelectedNodes.Bypass', - targetElementId: 'graph-canvas' + targetElementId: 'graph-canvas-container' }, { combo: { @@ -153,7 +153,7 @@ export const CORE_KEYBINDINGS: Keybinding[] = [ ctrl: true }, commandId: 'Comfy.Canvas.ToggleSelectedNodes.Mute', - targetElementId: 'graph-canvas' + targetElementId: 'graph-canvas-container' }, { combo: { diff --git a/src/i18n.ts b/src/i18n.ts index 102ac2600..38a8dfe95 100644 --- a/src/i18n.ts +++ b/src/i18n.ts @@ -76,7 +76,7 @@ export const i18n = createI18n({ }) /** Convenience shorthand: i18n.global */ -export const { t, te } = i18n.global +export const { t, te, d } = i18n.global /** * Safe translation function that returns the fallback message if the key is not found. diff --git a/src/lib/litegraph/src/LLink.ts b/src/lib/litegraph/src/LLink.ts index 58ae4e090..71b41f23d 100644 --- a/src/lib/litegraph/src/LLink.ts +++ b/src/lib/litegraph/src/LLink.ts @@ -205,7 +205,7 @@ export class LLink implements LinkSegment, Serialisable { network: Pick, linkSegment: LinkSegment ): Reroute[] { - if (!linkSegment.parentId) return [] + if (linkSegment.parentId === undefined) return [] return network.reroutes.get(linkSegment.parentId)?.getReroutes() ?? [] } @@ -229,7 +229,7 @@ export class LLink implements LinkSegment, Serialisable { linkSegment: LinkSegment, rerouteId: RerouteId ): Reroute | null | undefined { - if (!linkSegment.parentId) return + if (linkSegment.parentId === undefined) return return network.reroutes .get(linkSegment.parentId) ?.findNextReroute(rerouteId) @@ -498,7 +498,7 @@ export class LLink implements LinkSegment, Serialisable { target_slot: this.target_slot, type: this.type } - if (this.parentId) copy.parentId = this.parentId + if (this.parentId !== undefined) copy.parentId = this.parentId return copy } } diff --git a/src/lib/litegraph/src/litegraph.ts b/src/lib/litegraph/src/litegraph.ts index 46202a219..098c30e7a 100644 --- a/src/lib/litegraph/src/litegraph.ts +++ b/src/lib/litegraph/src/litegraph.ts @@ -48,7 +48,6 @@ export interface LinkReleaseContextExtended { links: ConnectingLink[] } -// eslint-disable-next-line @typescript-eslint/no-empty-object-type export interface LiteGraphCanvasEvent extends CustomEvent {} export interface LGraphNodeConstructor { @@ -140,7 +139,7 @@ export { BaseWidget } from './widgets/BaseWidget' export { LegacyWidget } from './widgets/LegacyWidget' -export { isComboWidget } from './widgets/widgetMap' +export { isComboWidget, isAssetWidget } from './widgets/widgetMap' // Additional test-specific exports export { LGraphButton } from './LGraphButton' export { MovingOutputLink } from './canvas/MovingOutputLink' diff --git a/src/lib/litegraph/src/widgets/AssetWidget.ts b/src/lib/litegraph/src/widgets/AssetWidget.ts index f8a8e1209..1a5047beb 100644 --- a/src/lib/litegraph/src/widgets/AssetWidget.ts +++ b/src/lib/litegraph/src/widgets/AssetWidget.ts @@ -13,6 +13,22 @@ export class AssetWidget this.value = widget.value?.toString() ?? '' } + override set value(value: IAssetWidget['value']) { + const oldValue = this.value + super.value = value + + // Force canvas redraw when value changes to show update immediately + if (oldValue !== value && this.node.graph?.list_of_graphcanvas) { + for (const canvas of this.node.graph.list_of_graphcanvas) { + canvas.setDirty(true) + } + } + } + + override get value(): IAssetWidget['value'] { + return super.value + } + override get _displayValue(): string { return String(this.value) //FIXME: Resolve asset name } diff --git a/src/lib/litegraph/src/widgets/widgetMap.ts b/src/lib/litegraph/src/widgets/widgetMap.ts index 02cdb5597..0e6a34fe5 100644 --- a/src/lib/litegraph/src/widgets/widgetMap.ts +++ b/src/lib/litegraph/src/widgets/widgetMap.ts @@ -1,5 +1,6 @@ import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode' import type { + IAssetWidget, IBaseWidget, IComboWidget, IWidget, @@ -132,4 +133,9 @@ export function isComboWidget(widget: IBaseWidget): widget is IComboWidget { return widget.type === 'combo' } +/** Type guard: Narrow **from {@link IBaseWidget}** to {@link IAssetWidget}. */ +export function isAssetWidget(widget: IBaseWidget): widget is IAssetWidget { + return widget.type === 'asset' +} + // #endregion Type Guards diff --git a/src/locales/en/main.json b/src/locales/en/main.json index 3abd85335..74f0352ae 100644 --- a/src/locales/en/main.json +++ b/src/locales/en/main.json @@ -1873,6 +1873,13 @@ "noModelsInFolder": "No {type} available in this folder", "searchAssetsPlaceholder": "Search assets...", "allModels": "All Models", - "unknown": "Unknown" + "unknown": "Unknown", + "fileFormats": "File formats", + "baseModels": "Base models", + "sortBy": "Sort by", + "sortAZ": "A-Z", + "sortZA": "Z-A", + "sortRecent": "Recent", + "sortPopular": "Popular" } } diff --git a/src/platform/assets/components/AssetBrowserModal.stories.ts b/src/platform/assets/components/AssetBrowserModal.stories.ts index acc93181d..9d2321d57 100644 --- a/src/platform/assets/components/AssetBrowserModal.stories.ts +++ b/src/platform/assets/components/AssetBrowserModal.stories.ts @@ -1,6 +1,7 @@ import type { Meta, StoryObj } from '@storybook/vue3-vite' import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vue' +import type { AssetDisplayItem } from '@/platform/assets/composables/useAssetBrowser' import { createMockAssets, mockAssets @@ -56,7 +57,7 @@ export const Default: Story = { render: (args) => ({ components: { AssetBrowserModal }, setup() { - const onAssetSelect = (asset: any) => { + const onAssetSelect = (asset: AssetDisplayItem) => { console.log('Selected asset:', asset) } const onClose = () => { @@ -96,7 +97,7 @@ export const SingleAssetType: Story = { render: (args) => ({ components: { AssetBrowserModal }, setup() { - const onAssetSelect = (asset: any) => { + const onAssetSelect = (asset: AssetDisplayItem) => { console.log('Selected asset:', asset) } const onClose = () => { @@ -145,7 +146,7 @@ export const NoLeftPanel: Story = { render: (args) => ({ components: { AssetBrowserModal }, setup() { - const onAssetSelect = (asset: any) => { + const onAssetSelect = (asset: AssetDisplayItem) => { console.log('Selected asset:', asset) } const onClose = () => { diff --git a/src/platform/assets/components/AssetBrowserModal.vue b/src/platform/assets/components/AssetBrowserModal.vue index de05f437d..cb45f38ba 100644 --- a/src/platform/assets/components/AssetBrowserModal.vue +++ b/src/platform/assets/components/AssetBrowserModal.vue @@ -12,7 +12,7 @@ :nav-items="availableCategories" > @@ -37,7 +37,7 @@ diff --git a/src/platform/assets/components/AssetCard.vue b/src/platform/assets/components/AssetCard.vue index e379099c1..be7c45ca5 100644 --- a/src/platform/assets/components/AssetCard.vue +++ b/src/platform/assets/components/AssetCard.vue @@ -14,7 +14,7 @@ 'bg-ivory-100 border border-gray-300 dark-theme:bg-charcoal-400 dark-theme:border-charcoal-600', 'hover:transform hover:-translate-y-0.5 hover:shadow-lg hover:shadow-black/10 hover:border-gray-400', 'dark-theme:hover:shadow-lg dark-theme:hover:shadow-black/30 dark-theme:hover:border-charcoal-700', - 'focus:outline-none focus:ring-2 focus:ring-blue-500 dark-theme:focus:ring-blue-400' + 'focus:outline-none focus:transform focus:-translate-y-0.5 focus:shadow-lg focus:shadow-black/10 dark-theme:focus:shadow-black/30' ], // Div-specific styles !interactive && [ diff --git a/src/platform/assets/components/AssetFilterBar.vue b/src/platform/assets/components/AssetFilterBar.vue index 1f3295b43..904ce3e82 100644 --- a/src/platform/assets/components/AssetFilterBar.vue +++ b/src/platform/assets/components/AssetFilterBar.vue @@ -3,7 +3,7 @@
void + ): Promise { if (import.meta.env.DEV) { - console.log('Asset selected:', asset.id, asset.name) + console.debug('Asset selected:', assetId) + } + + if (!onSelect) { + return + } + + try { + const detailAsset = await assetService.getAssetDetails(assetId) + const filename = detailAsset.user_metadata?.filename + const validatedFilename = assetFilenameSchema.safeParse(filename) + if (!validatedFilename.success) { + console.error( + 'Invalid asset filename:', + validatedFilename.error.errors, + 'for asset:', + assetId + ) + return + } + + onSelect(validatedFilename.data) + } catch (error) { + console.error(`Failed to fetch asset details for ${assetId}:`, error) } - return asset.id } return { @@ -182,7 +212,6 @@ export function useAssetBrowser(assets: AssetItem[] = []) { filteredAssets, // Actions - selectAsset, - transformAssetForDisplay + selectAssetWithCallback } } diff --git a/src/platform/assets/composables/useAssetBrowserDialog.ts b/src/platform/assets/composables/useAssetBrowserDialog.ts index e5f63eead..31f75c353 100644 --- a/src/platform/assets/composables/useAssetBrowserDialog.ts +++ b/src/platform/assets/composables/useAssetBrowserDialog.ts @@ -1,5 +1,7 @@ import AssetBrowserModal from '@/platform/assets/components/AssetBrowserModal.vue' -import { useDialogStore } from '@/stores/dialogStore' +import type { AssetItem } from '@/platform/assets/schemas/assetSchema' +import { assetService } from '@/platform/assets/services/assetService' +import { type DialogComponentProps, useDialogStore } from '@/stores/dialogStore' interface AssetBrowserDialogProps { /** ComfyUI node type for context (e.g., 'CheckpointLoaderSimple') */ @@ -8,36 +10,29 @@ interface AssetBrowserDialogProps { inputName: string /** Current selected asset value */ currentValue?: string - /** Callback for when an asset is selected */ - onAssetSelected?: (assetPath: string) => void + /** + * Callback for when an asset is selected + * @param {string} filename - The validated filename from user_metadata.filename + */ + onAssetSelected?: (filename: string) => void } export const useAssetBrowserDialog = () => { const dialogStore = useDialogStore() const dialogKey = 'global-asset-browser' - function hide() { - dialogStore.closeDialog({ key: dialogKey }) - } - - function show(props: AssetBrowserDialogProps) { - const handleAssetSelected = (assetPath: string) => { - props.onAssetSelected?.(assetPath) - hide() // Auto-close on selection + async function show(props: AssetBrowserDialogProps) { + const handleAssetSelected = (filename: string) => { + props.onAssetSelected?.(filename) + dialogStore.closeDialog({ key: dialogKey }) } - - const handleClose = () => { - hide() - } - - // Default dialog configuration for AssetBrowserModal - const dialogComponentProps = { + const dialogComponentProps: DialogComponentProps = { headless: true, modal: true, - closable: false, + closable: true, pt: { root: { - class: 'rounded-2xl overflow-hidden' + class: 'rounded-2xl overflow-hidden asset-browser-dialog' }, header: { class: 'p-0 hidden' @@ -48,6 +43,17 @@ export const useAssetBrowserDialog = () => { } } + const assets: AssetItem[] = await assetService + .getAssetsForNodeType(props.nodeType) + .catch((error) => { + console.error( + 'Failed to fetch assets for node type:', + props.nodeType, + error + ) + return [] + }) + dialogStore.showDialog({ key: dialogKey, component: AssetBrowserModal, @@ -55,12 +61,13 @@ export const useAssetBrowserDialog = () => { nodeType: props.nodeType, inputName: props.inputName, currentValue: props.currentValue, + assets, onSelect: handleAssetSelected, - onClose: handleClose + onClose: () => dialogStore.closeDialog({ key: dialogKey }) }, dialogComponentProps }) } - return { show, hide } + return { show } } diff --git a/src/platform/assets/schemas/assetSchema.ts b/src/platform/assets/schemas/assetSchema.ts index fab41649a..2c051a30d 100644 --- a/src/platform/assets/schemas/assetSchema.ts +++ b/src/platform/assets/schemas/assetSchema.ts @@ -4,13 +4,13 @@ import { z } from 'zod' const zAsset = z.object({ id: z.string(), name: z.string(), - asset_hash: z.string(), + asset_hash: z.string().nullable(), size: z.number(), - mime_type: z.string(), + mime_type: z.string().nullable(), tags: z.array(z.string()), preview_url: z.string().optional(), created_at: z.string(), - updated_at: z.string(), + updated_at: z.string().optional(), last_access_time: z.string(), user_metadata: z.record(z.unknown()).optional(), // API allows arbitrary key-value pairs preview_id: z.string().nullable().optional() @@ -33,6 +33,14 @@ const zModelFile = z.object({ pathIndex: z.number() }) +// Filename validation schema +export const assetFilenameSchema = z + .string() + .min(1, 'Filename cannot be empty') + .regex(/^[^\\:*?"<>|]+$/, 'Invalid filename characters') // Allow forward slashes, block backslashes and other unsafe chars + .regex(/^(?!\/|.*\.\.)/, 'Path must not start with / or contain ..') // Prevent absolute paths and directory traversal + .trim() + // Export schemas following repository patterns export const assetResponseSchema = zAssetResponse diff --git a/src/platform/assets/services/assetService.ts b/src/platform/assets/services/assetService.ts index 74b20a753..7d0f82cbb 100644 --- a/src/platform/assets/services/assetService.ts +++ b/src/platform/assets/services/assetService.ts @@ -1,6 +1,7 @@ import { fromZodError } from 'zod-validation-error' import { + type AssetItem, type AssetResponse, type ModelFile, type ModelFolder, @@ -127,10 +128,75 @@ function createAssetService() { ) } + /** + * Gets assets for a specific node type by finding the matching category + * and fetching all assets with that category tag + * + * @param nodeType - The ComfyUI node type (e.g., 'CheckpointLoaderSimple') + * @returns Promise - Full asset objects with preserved metadata + */ + async function getAssetsForNodeType(nodeType: string): Promise { + if (!nodeType || typeof nodeType !== 'string') { + return [] + } + + // Find the category for this node type using efficient O(1) lookup + const modelToNodeStore = useModelToNodeStore() + const category = modelToNodeStore.getCategoryForNodeType(nodeType) + + if (!category) { + return [] + } + + // Fetch assets for this category using same API pattern as getAssetModels + const data = await handleAssetRequest( + `${ASSETS_ENDPOINT}?include_tags=${MODELS_TAG},${category}`, + `assets for ${nodeType}` + ) + + // Return full AssetItem[] objects (don't strip like getAssetModels does) + return ( + data?.assets?.filter( + (asset) => + !asset.tags.includes(MISSING_TAG) && asset.tags.includes(category) + ) ?? [] + ) + } + + /** + * Gets complete details for a specific asset by ID + * Calls the detail endpoint which includes user_metadata and all fields + * + * @param id - The asset ID + * @returns Promise - Complete asset object with user_metadata + */ + async function getAssetDetails(id: string): Promise { + const res = await api.fetchApi(`${ASSETS_ENDPOINT}/${id}`) + if (!res.ok) { + throw new Error( + `Unable to load asset details for ${id}: Server returned ${res.status}. Please try again.` + ) + } + const data = await res.json() + + // Validate the single asset response against our schema + const result = assetResponseSchema.safeParse({ assets: [data] }) + if (result.success && result.data.assets?.[0]) { + return result.data.assets[0] + } + + const error = result.error + ? fromZodError(result.error) + : 'Unknown validation error' + throw new Error(`Invalid asset response against zod schema:\n${error}`) + } + return { getAssetModelFolders, getAssetModels, - isAssetBrowserEligible + isAssetBrowserEligible, + getAssetsForNodeType, + getAssetDetails } } diff --git a/src/platform/settings/constants/coreSettings.ts b/src/platform/settings/constants/coreSettings.ts index d592a92f0..4adf2db9d 100644 --- a/src/platform/settings/constants/coreSettings.ts +++ b/src/platform/settings/constants/coreSettings.ts @@ -595,7 +595,7 @@ export const CORE_SETTINGS: SettingParams[] = [ migrateDeprecatedValue: (value: any[]) => { return value.map((keybinding) => { if (keybinding['targetSelector'] === '#graph-canvas') { - keybinding['targetElementId'] = 'graph-canvas' + keybinding['targetElementId'] = 'graph-canvas-container' } return keybinding }) diff --git a/src/renderer/core/canvas/canvasStore.ts b/src/renderer/core/canvas/canvasStore.ts index 371035c08..ec38940fe 100644 --- a/src/renderer/core/canvas/canvasStore.ts +++ b/src/renderer/core/canvas/canvasStore.ts @@ -99,6 +99,16 @@ export const useCanvasStore = defineStore('canvas', () => { const currentGraph = shallowRef(null) const isInSubgraph = ref(false) + // Provide selection state to all Vue nodes + const selectedNodeIds = computed( + () => + new Set( + selectedItems.value + .filter((item) => item.id !== undefined) + .map((item) => String(item.id)) + ) + ) + whenever( () => canvas.value, (newCanvas) => { @@ -122,6 +132,7 @@ export const useCanvasStore = defineStore('canvas', () => { return { canvas, selectedItems, + selectedNodeIds, nodeSelected, groupSelected, rerouteSelected, diff --git a/src/renderer/core/canvas/injectionKeys.ts b/src/renderer/core/canvas/injectionKeys.ts deleted file mode 100644 index 5c850c100..000000000 --- a/src/renderer/core/canvas/injectionKeys.ts +++ /dev/null @@ -1,25 +0,0 @@ -import type { InjectionKey, Ref } from 'vue' - -import type { NodeProgressState } from '@/schemas/apiSchema' - -/** - * Injection key for providing selected node IDs to Vue node components. - * Contains a reactive Set of selected node IDs (as strings). - */ -export const SelectedNodeIdsKey: InjectionKey>> = - Symbol('selectedNodeIds') - -/** - * Injection key for providing executing node IDs to Vue node components. - * Contains a reactive Set of currently executing node IDs (as strings). - */ -export const ExecutingNodeIdsKey: InjectionKey>> = - Symbol('executingNodeIds') - -/** - * Injection key for providing node progress states to Vue node components. - * Contains a reactive Record of node IDs to their current progress state. - */ -export const NodeProgressStatesKey: InjectionKey< - Ref> -> = Symbol('nodeProgressStates') diff --git a/src/renderer/core/layout/injectionKeys.ts b/src/renderer/core/layout/injectionKeys.ts index dd6efda21..8e0e0e1d6 100644 --- a/src/renderer/core/layout/injectionKeys.ts +++ b/src/renderer/core/layout/injectionKeys.ts @@ -1,6 +1,6 @@ import type { InjectionKey } from 'vue' -import type { Point } from '@/renderer/core/layout/types' +import type { useTransformState } from '@/renderer/core/layout/transform/useTransformState' /** * Lightweight, injectable transform state used by layout-aware components. @@ -21,29 +21,11 @@ import type { Point } from '@/renderer/core/layout/types' * const state = inject(TransformStateKey)! * const screen = state.canvasToScreen({ x: 100, y: 50 }) */ -interface TransformState { - /** Convert a screen-space point (CSS pixels) to canvas space. */ - screenToCanvas: (p: Point) => Point - /** Convert a canvas-space point to screen space (CSS pixels). */ - canvasToScreen: (p: Point) => Point - /** Current pan/zoom; `x`/`y` are offsets, `z` is scale. */ - camera?: { x: number; y: number; z: number } - /** - * Test whether a node's rectangle intersects the (expanded) viewport. - * Handy for viewport culling and lazy work. - * - * @param nodePos Top-left in canvas space `[x, y]` - * @param nodeSize Size in canvas units `[width, height]` - * @param viewport Screen-space viewport `{ width, height }` - * @param margin Optional fractional margin (e.g. `0.2` = 20%) - */ - isNodeInViewport?: ( - nodePos: ArrayLike, - nodeSize: ArrayLike, - viewport: { width: number; height: number }, - margin?: number - ) => boolean -} +interface TransformState + extends Pick< + ReturnType, + 'screenToCanvas' | 'canvasToScreen' | 'camera' | 'isNodeInViewport' + > {} export const TransformStateKey: InjectionKey = Symbol('transformState') diff --git a/src/renderer/core/layout/sync/useSlotLayoutSync.ts b/src/renderer/core/layout/sync/useSlotLayoutSync.ts index 618d1857f..281199e8b 100644 --- a/src/renderer/core/layout/sync/useSlotLayoutSync.ts +++ b/src/renderer/core/layout/sync/useSlotLayoutSync.ts @@ -134,7 +134,11 @@ export function useSlotLayoutSync() { restoreHandlers = () => { graph.onNodeAdded = origNodeAdded || undefined graph.onNodeRemoved = origNodeRemoved || undefined - graph.onTrigger = origTrigger || undefined + // Only restore onTrigger if Vue nodes are not active + // Vue node manager sets its own onTrigger handler + if (!LiteGraph.vueNodesMode) { + graph.onTrigger = origTrigger || undefined + } graph.onAfterChange = origAfterChange || undefined } diff --git a/src/renderer/extensions/vueNodes/components/LGraphNode.vue b/src/renderer/extensions/vueNodes/components/LGraphNode.vue index bd6480c38..56a0984fb 100644 --- a/src/renderer/extensions/vueNodes/components/LGraphNode.vue +++ b/src/renderer/extensions/vueNodes/components/LGraphNode.vue @@ -54,7 +54,7 @@ :lod-level="lodLevel" :collapsed="isCollapsed" @collapse="handleCollapse" - @update:title="handleTitleUpdate" + @update:title="handleHeaderTitleUpdate" @enter-subgraph="handleEnterSubgraph" />
@@ -101,7 +101,6 @@ :node-data="nodeData" :readonly="readonly" :lod-level="lodLevel" - @slot-click="handleSlotClick" /> @@ -139,29 +138,22 @@ diff --git a/src/renderer/extensions/vueNodes/components/NodeHeader.vue b/src/renderer/extensions/vueNodes/components/NodeHeader.vue index 3bd75024c..40b8a7fe0 100644 --- a/src/renderer/extensions/vueNodes/components/NodeHeader.vue +++ b/src/renderer/extensions/vueNodes/components/NodeHeader.vue @@ -63,7 +63,6 @@ import EditableText from '@/components/common/EditableText.vue' import type { VueNodeData } from '@/composables/graph/useGraphNodeManager' import { useErrorHandling } from '@/composables/useErrorHandling' import { useNodeTooltips } from '@/renderer/extensions/vueNodes/composables/useNodeTooltips' -import type { LODLevel } from '@/renderer/extensions/vueNodes/lod/useLOD' import { app } from '@/scripts/app' import { getLocatorIdFromNodeData, @@ -73,7 +72,6 @@ import { interface NodeHeaderProps { nodeData?: VueNodeData readonly?: boolean - lodLevel?: LODLevel collapsed?: boolean } diff --git a/src/renderer/extensions/vueNodes/components/NodeSlots.vue b/src/renderer/extensions/vueNodes/components/NodeSlots.vue index 931cacb4d..26187899d 100644 --- a/src/renderer/extensions/vueNodes/components/NodeSlots.vue +++ b/src/renderer/extensions/vueNodes/components/NodeSlots.vue @@ -35,7 +35,6 @@ import { computed, onErrorCaptured, ref } from 'vue' import type { VueNodeData } from '@/composables/graph/useGraphNodeManager' import { useErrorHandling } from '@/composables/useErrorHandling' import type { INodeSlot } from '@/lib/litegraph/src/litegraph' -import type { LODLevel } from '@/renderer/extensions/vueNodes/lod/useLOD' import { isSlotObject } from '@/utils/typeGuardUtil' import InputSlot from './InputSlot.vue' @@ -44,7 +43,6 @@ import OutputSlot from './OutputSlot.vue' interface NodeSlotsProps { nodeData?: VueNodeData readonly?: boolean - lodLevel?: LODLevel } const { nodeData = null, readonly } = defineProps() diff --git a/src/renderer/extensions/vueNodes/components/NodeWidgets.vue b/src/renderer/extensions/vueNodes/components/NodeWidgets.vue index 57461ce09..4645429da 100644 --- a/src/renderer/extensions/vueNodes/components/NodeWidgets.vue +++ b/src/renderer/extensions/vueNodes/components/NodeWidgets.vue @@ -12,9 +12,9 @@ : 'pointer-events-none' ) " - @pointerdown="handleWidgetPointerEvent" - @pointermove="handleWidgetPointerEvent" - @pointerup="handleWidgetPointerEvent" + @pointerdown.stop="handleWidgetPointerEvent" + @pointermove.stop="handleWidgetPointerEvent" + @pointerup.stop="handleWidgetPointerEvent" >
any -} - -export function useNodeEventHandlers(nodeManager: Ref) { +function useNodeEventHandlersIndividual() { const canvasStore = useCanvasStore() + const { nodeManager } = useVueNodeLifecycle() const { bringNodeToFront } = useNodeZIndex() const { shouldHandleNodePointerEvents } = useCanvasInteractions() @@ -40,7 +38,7 @@ export function useNodeEventHandlers(nodeManager: Ref) { const node = nodeManager.value.getNode(nodeData.id) if (!node) return - const isMultiSelect = event.ctrlKey || event.metaKey + const isMultiSelect = event.ctrlKey || event.metaKey || event.shiftKey if (isMultiSelect) { // Ctrl/Cmd+click -> toggle selection @@ -84,6 +82,7 @@ export function useNodeEventHandlers(nodeManager: Ref) { const currentCollapsed = node.flags?.collapsed ?? false if (currentCollapsed !== collapsed) { node.collapse() + nodeManager.value.scheduleUpdate(nodeId, 'critical') } } @@ -237,3 +236,7 @@ export function useNodeEventHandlers(nodeManager: Ref) { deselectNodes } } + +export const useNodeEventHandlers = createSharedComposable( + useNodeEventHandlersIndividual +) diff --git a/src/renderer/extensions/vueNodes/composables/useNodePointerInteractions.ts b/src/renderer/extensions/vueNodes/composables/useNodePointerInteractions.ts new file mode 100644 index 000000000..f5ba08374 --- /dev/null +++ b/src/renderer/extensions/vueNodes/composables/useNodePointerInteractions.ts @@ -0,0 +1,93 @@ +import { type MaybeRefOrGetter, computed, ref, toValue } from 'vue' + +import type { VueNodeData } from '@/composables/graph/useGraphNodeManager' +import { useCanvasInteractions } from '@/renderer/core/canvas/useCanvasInteractions' +import { layoutStore } from '@/renderer/core/layout/store/layoutStore' +import { useNodeLayout } from '@/renderer/extensions/vueNodes/layout/useNodeLayout' + +// Treat tiny pointer jitter as a click, not a drag +const DRAG_THRESHOLD_PX = 4 + +export function useNodePointerInteractions( + nodeDataMaybe: MaybeRefOrGetter, + onPointerUp: ( + event: PointerEvent, + nodeData: VueNodeData, + wasDragging: boolean + ) => void +) { + const nodeData = toValue(nodeDataMaybe) + + const { startDrag, endDrag, handleDrag } = useNodeLayout(nodeData.id) + // Use canvas interactions for proper wheel event handling and pointer event capture control + const { forwardEventToCanvas, shouldHandleNodePointerEvents } = + useCanvasInteractions() + + // Drag state for styling + const isDragging = ref(false) + const dragStyle = computed(() => ({ + cursor: isDragging.value ? 'grabbing' : 'grab' + })) + const lastX = ref(0) + const lastY = ref(0) + + const handlePointerDown = (event: PointerEvent) => { + if (!nodeData) { + console.warn( + 'LGraphNode: nodeData is null/undefined in handlePointerDown' + ) + return + } + + // Don't handle pointer events when canvas is in panning mode - forward to canvas instead + if (!shouldHandleNodePointerEvents.value) { + forwardEventToCanvas(event) + return + } + + // Start drag using layout system + isDragging.value = true + + // Set Vue node dragging state for selection toolbox + layoutStore.isDraggingVueNodes.value = true + + startDrag(event) + lastY.value = event.clientY + lastX.value = event.clientX + } + + const handlePointerMove = (event: PointerEvent) => { + if (isDragging.value) { + void handleDrag(event) + } + } + + const handlePointerUp = (event: PointerEvent) => { + if (isDragging.value) { + isDragging.value = false + void endDrag(event) + + // Clear Vue node dragging state for selection toolbox + layoutStore.isDraggingVueNodes.value = false + } + + // Don't emit node-click when canvas is in panning mode - forward to canvas instead + if (!shouldHandleNodePointerEvents.value) { + forwardEventToCanvas(event) + return + } + + // Emit node-click for selection handling in GraphCanvas + const dx = event.clientX - lastX.value + const dy = event.clientY - lastY.value + const wasDragging = Math.hypot(dx, dy) > DRAG_THRESHOLD_PX + onPointerUp(event, nodeData, wasDragging) + } + return { + isDragging, + dragStyle, + handlePointerMove, + handlePointerDown, + handlePointerUp + } +} diff --git a/src/renderer/extensions/vueNodes/composables/useNodeTooltips.ts b/src/renderer/extensions/vueNodes/composables/useNodeTooltips.ts index 0dd9922ef..034047471 100644 --- a/src/renderer/extensions/vueNodes/composables/useNodeTooltips.ts +++ b/src/renderer/extensions/vueNodes/composables/useNodeTooltips.ts @@ -93,10 +93,10 @@ export function useNodeTooltips( pt: { text: { class: - 'bg-charcoal-100 border border-slate-300 rounded-md px-4 py-2 text-white text-sm font-normal leading-tight max-w-75 shadow-none' + 'bg-charcoal-800 border border-slate-300 rounded-md px-4 py-2 text-white text-sm font-normal leading-tight max-w-75 shadow-none' }, arrow: { - class: 'before:border-charcoal-100' + class: 'before:border-slate-300' } } } diff --git a/src/renderer/extensions/vueNodes/composables/useVueNodeResizeTracking.ts b/src/renderer/extensions/vueNodes/composables/useVueNodeResizeTracking.ts index 4ce9f8e62..e8c38164d 100644 --- a/src/renderer/extensions/vueNodes/composables/useVueNodeResizeTracking.ts +++ b/src/renderer/extensions/vueNodes/composables/useVueNodeResizeTracking.ts @@ -8,7 +8,13 @@ * Supports different element types (nodes, slots, widgets, etc.) with * customizable data attributes and update handlers. */ -import { getCurrentInstance, onMounted, onUnmounted } from 'vue' +import { + type MaybeRefOrGetter, + getCurrentInstance, + onMounted, + onUnmounted, + toValue +} from 'vue' import { useSharedCanvasPositionConversion } from '@/composables/element/useCanvasPositionConversion' import { LiteGraph } from '@/lib/litegraph/src/litegraph' @@ -154,9 +160,10 @@ const resizeObserver = new ResizeObserver((entries) => { * ``` */ export function useVueElementTracking( - appIdentifier: string, + appIdentifierMaybe: MaybeRefOrGetter, trackingType: string ) { + const appIdentifier = toValue(appIdentifierMaybe) onMounted(() => { const element = getCurrentInstance()?.proxy?.$el if (!(element instanceof HTMLElement) || !appIdentifier) return diff --git a/src/renderer/extensions/vueNodes/execution/useExecutionStateProvider.ts b/src/renderer/extensions/vueNodes/execution/useExecutionStateProvider.ts deleted file mode 100644 index aae08298a..000000000 --- a/src/renderer/extensions/vueNodes/execution/useExecutionStateProvider.ts +++ /dev/null @@ -1,36 +0,0 @@ -import { storeToRefs } from 'pinia' -import { computed, provide } from 'vue' - -import { - ExecutingNodeIdsKey, - NodeProgressStatesKey -} from '@/renderer/core/canvas/injectionKeys' -import { useExecutionStore } from '@/stores/executionStore' - -/** - * Composable for providing execution state to Vue node children - * - * This composable sets up the execution state providers that can be injected - * by child Vue nodes using useNodeExecutionState. - * - * Should be used in the parent component that manages Vue nodes (e.g., GraphCanvas). - */ -export const useExecutionStateProvider = () => { - const executionStore = useExecutionStore() - const { executingNodeIds: storeExecutingNodeIds, nodeProgressStates } = - storeToRefs(executionStore) - - // Convert execution store data to the format expected by Vue nodes - const executingNodeIds = computed( - () => new Set(storeExecutingNodeIds.value.map(String)) - ) - - // Provide the execution state to all child Vue nodes - provide(ExecutingNodeIdsKey, executingNodeIds) - provide(NodeProgressStatesKey, nodeProgressStates) - - return { - executingNodeIds, - nodeProgressStates - } -} diff --git a/src/renderer/extensions/vueNodes/execution/useNodeExecutionState.ts b/src/renderer/extensions/vueNodes/execution/useNodeExecutionState.ts index 8f03e29e1..aa4867db9 100644 --- a/src/renderer/extensions/vueNodes/execution/useNodeExecutionState.ts +++ b/src/renderer/extensions/vueNodes/execution/useNodeExecutionState.ts @@ -1,10 +1,7 @@ -import { computed, inject, ref } from 'vue' +import { storeToRefs } from 'pinia' +import { type MaybeRefOrGetter, computed, toValue } from 'vue' -import { - ExecutingNodeIdsKey, - NodeProgressStatesKey -} from '@/renderer/core/canvas/injectionKeys' -import type { NodeProgressState } from '@/schemas/apiSchema' +import { useExecutionStore } from '@/stores/executionStore' /** * Composable for managing execution state of Vue-based nodes @@ -12,18 +9,18 @@ import type { NodeProgressState } from '@/schemas/apiSchema' * Provides reactive access to execution state and progress for a specific node * by injecting execution data from the parent GraphCanvas provider. * - * @param nodeId - The ID of the node to track execution state for + * @param nodeIdMaybe - The ID of the node to track execution state for * @returns Object containing reactive execution state and progress */ -export const useNodeExecutionState = (nodeId: string) => { - const executingNodeIds = inject(ExecutingNodeIdsKey, ref(new Set())) - const nodeProgressStates = inject( - NodeProgressStatesKey, - ref>({}) - ) +export const useNodeExecutionState = ( + nodeIdMaybe: MaybeRefOrGetter +) => { + const nodeId = toValue(nodeIdMaybe) + const { uniqueExecutingNodeIdStrings, nodeProgressStates } = + storeToRefs(useExecutionStore()) const executing = computed(() => { - return executingNodeIds.value.has(nodeId) + return uniqueExecutingNodeIdStrings.value.has(nodeId) }) const progress = computed(() => { diff --git a/src/renderer/extensions/vueNodes/layout/useNodeLayout.ts b/src/renderer/extensions/vueNodes/layout/useNodeLayout.ts index 18a085641..89718eb8d 100644 --- a/src/renderer/extensions/vueNodes/layout/useNodeLayout.ts +++ b/src/renderer/extensions/vueNodes/layout/useNodeLayout.ts @@ -1,12 +1,7 @@ -/** - * Composable for individual Vue node components - * - * Uses customRef for shared write access with Canvas renderer. - * Provides dragging functionality and reactive layout state. - */ -import { computed, inject } from 'vue' +import { storeToRefs } from 'pinia' +import { type MaybeRefOrGetter, computed, inject, toValue } from 'vue' -import { SelectedNodeIdsKey } from '@/renderer/core/canvas/injectionKeys' +import { useCanvasStore } from '@/renderer/core/canvas/canvasStore' import { TransformStateKey } from '@/renderer/core/layout/injectionKeys' import { useLayoutMutations } from '@/renderer/core/layout/operations/layoutMutations' import { layoutStore } from '@/renderer/core/layout/store/layoutStore' @@ -16,15 +11,16 @@ import { LayoutSource, type Point } from '@/renderer/core/layout/types' * Composable for individual Vue node components * Uses customRef for shared write access with Canvas renderer */ -export function useNodeLayout(nodeId: string) { - const store = layoutStore +export function useNodeLayout(nodeIdMaybe: MaybeRefOrGetter) { + const nodeId = toValue(nodeIdMaybe) const mutations = useLayoutMutations() + const { selectedNodeIds } = storeToRefs(useCanvasStore()) // Get transform utilities from TransformPane if available const transformState = inject(TransformStateKey) // Get the customRef for this node (shared write access) - const layoutRef = store.getNodeLayoutRef(nodeId) + const layoutRef = layoutStore.getNodeLayoutRef(nodeId) // Computed properties for easy access const position = computed(() => { @@ -53,8 +49,6 @@ export function useNodeLayout(nodeId: string) { let dragStartMouse: Point | null = null let otherSelectedNodesStartPositions: Map | null = null - const selectedNodeIds = inject(SelectedNodeIdsKey, null) - /** * Start dragging the node */ diff --git a/src/renderer/extensions/vueNodes/lod/useLOD.ts b/src/renderer/extensions/vueNodes/lod/useLOD.ts index 584e21f9a..87c1bb865 100644 --- a/src/renderer/extensions/vueNodes/lod/useLOD.ts +++ b/src/renderer/extensions/vueNodes/lod/useLOD.ts @@ -27,7 +27,7 @@ * * ``` */ -import { type Ref, computed, readonly } from 'vue' +import { type MaybeRefOrGetter, computed, readonly, toRef } from 'vue' export enum LODLevel { MINIMAL = 'minimal', // zoom <= 0.4 @@ -78,7 +78,8 @@ const LOD_CONFIGS: Record = { * @param zoomRef - Reactive reference to current zoom level (camera.z) * @returns LOD state and configuration */ -export function useLOD(zoomRef: Ref) { +export function useLOD(zoomRefMaybe: MaybeRefOrGetter) { + const zoomRef = toRef(zoomRefMaybe) // Continuous LOD score (0-1) for smooth transitions const lodScore = computed(() => { const zoom = zoomRef.value diff --git a/src/renderer/extensions/vueNodes/preview/useNodePreviewState.ts b/src/renderer/extensions/vueNodes/preview/useNodePreviewState.ts index 427cf20c0..8fc82147a 100644 --- a/src/renderer/extensions/vueNodes/preview/useNodePreviewState.ts +++ b/src/renderer/extensions/vueNodes/preview/useNodePreviewState.ts @@ -1,16 +1,17 @@ import { storeToRefs } from 'pinia' -import { type Ref, computed } from 'vue' +import { type MaybeRefOrGetter, type Ref, computed, toValue } from 'vue' import { useWorkflowStore } from '@/platform/workflow/management/stores/workflowStore' import { useNodeOutputStore } from '@/stores/imagePreviewStore' export const useNodePreviewState = ( - nodeId: string, + nodeIdMaybe: MaybeRefOrGetter, options?: { isMinimalLOD?: Ref isCollapsed?: Ref } ) => { + const nodeId = toValue(nodeIdMaybe) const workflowStore = useWorkflowStore() const { nodePreviewImages } = storeToRefs(useNodeOutputStore()) diff --git a/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts b/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts index 2387fc59c..59705458c 100644 --- a/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts +++ b/src/renderer/extensions/vueNodes/widgets/composables/useComboWidget.ts @@ -3,10 +3,9 @@ import { ref } from 'vue' import MultiSelectWidget from '@/components/graph/widgets/MultiSelectWidget.vue' import { t } from '@/i18n' import type { LGraphNode } from '@/lib/litegraph/src/litegraph' -import type { - IBaseWidget, - IComboWidget -} from '@/lib/litegraph/src/types/widgets' +import { isAssetWidget, isComboWidget } from '@/lib/litegraph/src/litegraph' +import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets' +import { useAssetBrowserDialog } from '@/platform/assets/composables/useAssetBrowserDialog' import { assetService } from '@/platform/assets/services/assetService' import { useSettingStore } from '@/platform/settings/settingStore' import { transformInputSpecV2ToV1 } from '@/schemas/nodeDef/migration' @@ -73,11 +72,29 @@ const addComboWidget = ( const currentValue = getDefaultValue(inputSpec) const displayLabel = currentValue ?? t('widgets.selectModel') - const widget = node.addWidget('asset', inputSpec.name, displayLabel, () => { - console.log( - `Asset Browser would open here for:\nNode: ${node.type}\nWidget: ${inputSpec.name}\nCurrent Value:${currentValue}` - ) - }) + const assetBrowserDialog = useAssetBrowserDialog() + + const widget = node.addWidget( + 'asset', + inputSpec.name, + displayLabel, + async () => { + if (!isAssetWidget(widget)) { + throw new Error(`Expected asset widget but received ${widget.type}`) + } + await assetBrowserDialog.show({ + nodeType: node.comfyClass || '', + inputName: inputSpec.name, + currentValue: widget.value, + onAssetSelected: (filename: string) => { + const oldValue = widget.value + widget.value = filename + // Using onWidgetChanged prevents a callback race where asset selection could reopen the dialog + node.onWidgetChanged?.(widget.name, filename, oldValue, widget) + } + }) + } + ) return widget } @@ -96,11 +113,14 @@ const addComboWidget = ( ) if (inputSpec.remote) { + if (!isComboWidget(widget)) { + throw new Error(`Expected combo widget but received ${widget.type}`) + } const remoteWidget = useRemoteWidget({ remoteConfig: inputSpec.remote, defaultValue, node, - widget: widget as IComboWidget + widget }) if (inputSpec.remote.refresh_button) remoteWidget.addRefreshButton() @@ -116,16 +136,19 @@ const addComboWidget = ( } if (inputSpec.control_after_generate) { + if (!isComboWidget(widget)) { + throw new Error(`Expected combo widget but received ${widget.type}`) + } widget.linkedWidgets = addValueControlWidgets( node, - widget as IComboWidget, + widget, undefined, undefined, transformInputSpecV2ToV1(inputSpec) ) } - return widget as IBaseWidget + return widget } export const useComboWidget = () => { diff --git a/src/scripts/app.ts b/src/scripts/app.ts index 43722f9b4..cd085eead 100644 --- a/src/scripts/app.ts +++ b/src/scripts/app.ts @@ -596,7 +596,10 @@ export class ComfyApp { const keybindingStore = useKeybindingStore() const keybinding = keybindingStore.getKeybinding(keyCombo) - if (keybinding && keybinding.targetElementId === 'graph-canvas') { + if ( + keybinding && + keybinding.targetElementId === 'graph-canvas-container' + ) { useCommandStore().execute(keybinding.commandId) this.graph.change() diff --git a/src/stores/executionStore.ts b/src/stores/executionStore.ts index cfc7a2dd4..8791ab4e1 100644 --- a/src/stores/executionStore.ts +++ b/src/stores/executionStore.ts @@ -43,6 +43,57 @@ interface QueuedPrompt { workflow?: ComfyWorkflow } +const subgraphNodeIdToSubgraph = (id: string, graph: LGraph | Subgraph) => { + const node = graph.getNodeById(id) + if (node?.isSubgraphNode()) return node.subgraph +} + +/** + * Recursively get the subgraph objects for the given subgraph instance IDs + * @param currentGraph The current graph + * @param subgraphNodeIds The instance IDs + * @param subgraphs The subgraphs + * @returns The subgraphs that correspond to each of the instance IDs. + */ +function getSubgraphsFromInstanceIds( + currentGraph: LGraph | Subgraph, + subgraphNodeIds: string[], + subgraphs: Subgraph[] = [] +): Subgraph[] { + // Last segment is the node portion; nothing to do. + if (subgraphNodeIds.length === 1) return subgraphs + + const currentPart = subgraphNodeIds.shift() + if (currentPart === undefined) return subgraphs + + const subgraph = subgraphNodeIdToSubgraph(currentPart, currentGraph) + if (!subgraph) throw new Error(`Subgraph not found: ${currentPart}`) + + subgraphs.push(subgraph) + return getSubgraphsFromInstanceIds(subgraph, subgraphNodeIds, subgraphs) +} + +/** + * Convert execution context node IDs to NodeLocatorIds + * @param nodeId The node ID from execution context (could be execution ID) + * @returns The NodeLocatorId + */ +function executionIdToNodeLocatorId(nodeId: string | number): NodeLocatorId { + const nodeIdStr = String(nodeId) + + if (!nodeIdStr.includes(':')) { + // It's a top-level node ID + return nodeIdStr + } + + // It's an execution node ID + const parts = nodeIdStr.split(':') + const localNodeId = parts[parts.length - 1] + const subgraphs = getSubgraphsFromInstanceIds(app.graph, parts) + const nodeLocatorId = createNodeLocatorId(subgraphs.at(-1)!.id, localNodeId) + return nodeLocatorId +} + export const useExecutionStore = defineStore('execution', () => { const workflowStore = useWorkflowStore() const canvasStore = useCanvasStore() @@ -55,29 +106,6 @@ export const useExecutionStore = defineStore('execution', () => { // This is the progress of all nodes in the currently executing workflow const nodeProgressStates = ref>({}) - /** - * Convert execution context node IDs to NodeLocatorIds - * @param nodeId The node ID from execution context (could be execution ID) - * @returns The NodeLocatorId - */ - const executionIdToNodeLocatorId = ( - nodeId: string | number - ): NodeLocatorId => { - const nodeIdStr = String(nodeId) - - if (!nodeIdStr.includes(':')) { - // It's a top-level node ID - return nodeIdStr - } - - // It's an execution node ID - const parts = nodeIdStr.split(':') - const localNodeId = parts[parts.length - 1] - const subgraphs = getSubgraphsFromInstanceIds(app.graph, parts) - const nodeLocatorId = createNodeLocatorId(subgraphs.at(-1)!.id, localNodeId) - return nodeLocatorId - } - const mergeExecutionProgressStates = ( currentState: NodeProgressState | undefined, newState: NodeProgressState @@ -139,9 +167,13 @@ export const useExecutionStore = defineStore('execution', () => { // @deprecated For backward compatibility - stores the primary executing node ID const executingNodeId = computed(() => { - return executingNodeIds.value.length > 0 ? executingNodeIds.value[0] : null + return executingNodeIds.value[0] ?? null }) + const uniqueExecutingNodeIdStrings = computed( + () => new Set(executingNodeIds.value.map(String)) + ) + // For backward compatibility - returns the primary executing node const executingNode = computed(() => { if (!executingNodeId.value) return null @@ -159,36 +191,6 @@ export const useExecutionStore = defineStore('execution', () => { ) }) - const subgraphNodeIdToSubgraph = (id: string, graph: LGraph | Subgraph) => { - const node = graph.getNodeById(id) - if (node?.isSubgraphNode()) return node.subgraph - } - - /** - * Recursively get the subgraph objects for the given subgraph instance IDs - * @param currentGraph The current graph - * @param subgraphNodeIds The instance IDs - * @param subgraphs The subgraphs - * @returns The subgraphs that correspond to each of the instance IDs. - */ - const getSubgraphsFromInstanceIds = ( - currentGraph: LGraph | Subgraph, - subgraphNodeIds: string[], - subgraphs: Subgraph[] = [] - ): Subgraph[] => { - // Last segment is the node portion; nothing to do. - if (subgraphNodeIds.length === 1) return subgraphs - - const currentPart = subgraphNodeIds.shift() - if (currentPart === undefined) return subgraphs - - const subgraph = subgraphNodeIdToSubgraph(currentPart, currentGraph) - if (!subgraph) throw new Error(`Subgraph not found: ${currentPart}`) - - subgraphs.push(subgraph) - return getSubgraphsFromInstanceIds(subgraph, subgraphNodeIds, subgraphs) - } - // This is the progress of the currently executing node (for backward compatibility) const _executingNodeProgress = ref(null) const executingNodeProgress = computed(() => @@ -423,66 +425,25 @@ export const useExecutionStore = defineStore('execution', () => { return { isIdle, clientId, - /** - * The id of the prompt that is currently being executed - */ activePromptId, - /** - * The queued prompts - */ queuedPrompts, - /** - * The node errors from the previous execution. - */ lastNodeErrors, - /** - * The error from the previous execution. - */ lastExecutionError, - /** - * Local node ID for the most recent execution error. - */ lastExecutionErrorNodeId, - /** - * The id of the node that is currently being executed (backward compatibility) - */ executingNodeId, - /** - * The list of all nodes that are currently executing - */ executingNodeIds, - /** - * The prompt that is currently being executed - */ activePrompt, - /** - * The total number of nodes to execute - */ totalNodesToExecute, - /** - * The number of nodes that have been executed - */ nodesExecuted, - /** - * The progress of the execution - */ executionProgress, - /** - * The node that is currently being executed (backward compatibility) - */ executingNode, - /** - * The progress of the executing node (backward compatibility) - */ executingNodeProgress, - /** - * All node progress states from progress_state events - */ nodeProgressStates, nodeLocationProgressStates, bindExecutionEvents, unbindExecutionEvents, storePrompt, + uniqueExecutingNodeIdStrings, // Raw executing progress data for backward compatibility in ComfyApp. _executingNodeProgress, // NodeLocatorId conversion helpers diff --git a/src/stores/modelToNodeStore.ts b/src/stores/modelToNodeStore.ts index f6c15e91a..4f7925294 100644 --- a/src/stores/modelToNodeStore.ts +++ b/src/stores/modelToNodeStore.ts @@ -33,12 +33,43 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { ) }) + /** Internal computed for efficient reverse lookup: nodeType -> category */ + const nodeTypeToCategory = computed(() => { + const lookup: Record = {} + for (const [category, providers] of Object.entries(modelToNodeMap.value)) { + for (const provider of providers) { + // Only store the first category for each node type (matches current assetService behavior) + if (!lookup[provider.nodeDef.name]) { + lookup[provider.nodeDef.name] = category + } + } + } + return lookup + }) + /** Get set of all registered node types for efficient lookup */ function getRegisteredNodeTypes(): Set { registerDefaults() return registeredNodeTypes.value } + /** + * Get the category for a given node type. + * Performs efficient O(1) lookup using cached reverse map. + * @param nodeType The node type name to find the category for + * @returns The category name, or undefined if not found + */ + function getCategoryForNodeType(nodeType: string): string | undefined { + registerDefaults() + + // Handle invalid input gracefully + if (!nodeType || typeof nodeType !== 'string') { + return undefined + } + + return nodeTypeToCategory.value[nodeType] + } + /** * Get the node provider for the given model type name. * @param modelType The name of the model type to get the node provider for. @@ -109,6 +140,7 @@ export const useModelToNodeStore = defineStore('modelToNode', () => { return { modelToNodeMap, getRegisteredNodeTypes, + getCategoryForNodeType, getNodeProvider, getAllNodeProviders, registerNodeProvider, diff --git a/src/types/litegraph-augmentation.d.ts b/src/types/litegraph-augmentation.d.ts index d404ff88f..0f5dee17d 100644 --- a/src/types/litegraph-augmentation.d.ts +++ b/src/types/litegraph-augmentation.d.ts @@ -82,7 +82,7 @@ declare module '@/lib/litegraph/src/litegraph' { } // Add interface augmentations into the class itself - // eslint-disable-next-line @typescript-eslint/no-empty-object-type + interface BaseWidget extends IBaseWidget {} interface LGraphNode { diff --git a/src/views/GraphView.vue b/src/views/GraphView.vue index a23d4fecf..bbfca6ea0 100644 --- a/src/views/GraphView.vue +++ b/src/views/GraphView.vue @@ -33,6 +33,7 @@ import { } from 'vue' import { useI18n } from 'vue-i18n' +import { runWhenGlobalIdle } from '@/base/common/async' import MenuHamburger from '@/components/MenuHamburger.vue' import UnloadWindowConfirmDialog from '@/components/dialog/UnloadWindowConfirmDialog.vue' import GraphCanvas from '@/components/graph/GraphCanvas.vue' @@ -253,33 +254,30 @@ void nextTick(() => { }) const onGraphReady = () => { - requestIdleCallback( - () => { - // Setting values now available after comfyApp.setup. - // Load keybindings. - wrapWithErrorHandling(useKeybindingService().registerUserKeybindings)() + runWhenGlobalIdle(() => { + // Setting values now available after comfyApp.setup. + // Load keybindings. + wrapWithErrorHandling(useKeybindingService().registerUserKeybindings)() - // Load server config - wrapWithErrorHandling(useServerConfigStore().loadServerConfig)( - SERVER_CONFIG_ITEMS, - settingStore.get('Comfy.Server.ServerConfigValues') - ) + // Load server config + wrapWithErrorHandling(useServerConfigStore().loadServerConfig)( + SERVER_CONFIG_ITEMS, + settingStore.get('Comfy.Server.ServerConfigValues') + ) - // Load model folders - void wrapWithErrorHandlingAsync(useModelStore().loadModelFolders)() + // Load model folders + void wrapWithErrorHandlingAsync(useModelStore().loadModelFolders)() - // Non-blocking load of node frequencies - void wrapWithErrorHandlingAsync( - useNodeFrequencyStore().loadNodeFrequencies - )() + // Non-blocking load of node frequencies + void wrapWithErrorHandlingAsync( + useNodeFrequencyStore().loadNodeFrequencies + )() - // Node defs now available after comfyApp.setup. - // Explicitly initialize nodeSearchService to avoid indexing delay when - // node search is triggered - useNodeDefStore().nodeSearchService.searchNode('') - }, - { timeout: 1000 } - ) + // Node defs now available after comfyApp.setup. + // Explicitly initialize nodeSearchService to avoid indexing delay when + // node search is triggered + useNodeDefStore().nodeSearchService.searchNode('') + }, 1000) } diff --git a/tests-ui/platform/assets/composables/useAssetBrowser.test.ts b/tests-ui/platform/assets/composables/useAssetBrowser.test.ts index d7d4f74dc..bef33733b 100644 --- a/tests-ui/platform/assets/composables/useAssetBrowser.test.ts +++ b/tests-ui/platform/assets/composables/useAssetBrowser.test.ts @@ -1,10 +1,33 @@ -import { describe, expect, it } from 'vitest' +import { beforeEach, describe, expect, it, vi } from 'vitest' import { nextTick } from 'vue' import { useAssetBrowser } from '@/platform/assets/composables/useAssetBrowser' import type { AssetItem } from '@/platform/assets/schemas/assetSchema' +import { assetService } from '@/platform/assets/services/assetService' + +vi.mock('@/platform/assets/services/assetService', () => ({ + assetService: { + getAssetDetails: vi.fn() + } +})) + +vi.mock('@/i18n', () => ({ + t: (key: string) => { + const translations: Record = { + 'assetBrowser.allModels': 'All Models', + 'assetBrowser.assets': 'Assets', + 'assetBrowser.unknown': 'unknown' + } + return translations[key] || key + }, + d: (date: Date) => date.toLocaleDateString() +})) describe('useAssetBrowser', () => { + beforeEach(() => { + vi.restoreAllMocks() + }) + // Test fixtures - minimal data focused on functionality being tested const createApiAsset = (overrides: Partial = {}): AssetItem => ({ id: 'test-id', @@ -26,8 +49,8 @@ describe('useAssetBrowser', () => { user_metadata: { description: 'Test model' } }) - const { transformAssetForDisplay } = useAssetBrowser([apiAsset]) - const result = transformAssetForDisplay(apiAsset) + const { filteredAssets } = useAssetBrowser([apiAsset]) + const result = filteredAssets.value[0] // Get the transformed asset from filteredAssets // Preserves API properties expect(result.id).toBe(apiAsset.id) @@ -49,15 +72,13 @@ describe('useAssetBrowser', () => { user_metadata: undefined }) - const { transformAssetForDisplay } = useAssetBrowser([apiAsset]) - const result = transformAssetForDisplay(apiAsset) + const { filteredAssets } = useAssetBrowser([apiAsset]) + const result = filteredAssets.value[0] expect(result.description).toBe('loras model') }) it('formats various file sizes correctly', () => { - const { transformAssetForDisplay } = useAssetBrowser([]) - const testCases = [ { size: 512, expected: '512 B' }, { size: 1536, expected: '1.5 KB' }, @@ -67,7 +88,8 @@ describe('useAssetBrowser', () => { testCases.forEach(({ size, expected }) => { const asset = createApiAsset({ size }) - const result = transformAssetForDisplay(asset) + const { filteredAssets } = useAssetBrowser([asset]) + const result = filteredAssets.value[0] expect(result.formattedSize).toBe(expected) }) }) @@ -236,18 +258,182 @@ describe('useAssetBrowser', () => { }) }) - describe('Asset Selection', () => { - it('returns selected asset UUID for efficient handling', () => { + describe('Async Asset Selection with Detail Fetching', () => { + it('should fetch asset details and call onSelect with filename when provided', async () => { + const onSelectSpy = vi.fn() const asset = createApiAsset({ - id: 'test-uuid-123', - name: 'selected_model.safetensors' + id: 'asset-123', + name: 'test-model.safetensors' }) - const { selectAsset, transformAssetForDisplay } = useAssetBrowser([asset]) - const displayAsset = transformAssetForDisplay(asset) - const result = selectAsset(displayAsset) + const detailAsset = createApiAsset({ + id: 'asset-123', + name: 'test-model.safetensors', + user_metadata: { filename: 'checkpoints/test-model.safetensors' } + }) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) - expect(result).toBe('test-uuid-123') + const { selectAssetWithCallback } = useAssetBrowser([asset]) + + await selectAssetWithCallback(asset.id, onSelectSpy) + + expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-123') + expect(onSelectSpy).toHaveBeenCalledWith( + 'checkpoints/test-model.safetensors' + ) + }) + + it('should handle missing user_metadata.filename as error', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}) + const onSelectSpy = vi.fn() + const asset = createApiAsset({ id: 'asset-456' }) + + const detailAsset = createApiAsset({ + id: 'asset-456', + user_metadata: { filename: '' } // Invalid empty filename + }) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) + + const { selectAssetWithCallback } = useAssetBrowser([asset]) + + await selectAssetWithCallback(asset.id, onSelectSpy) + + expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-456') + expect(onSelectSpy).not.toHaveBeenCalled() + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Invalid asset filename:', + expect.arrayContaining([ + expect.objectContaining({ + message: 'Filename cannot be empty' + }) + ]), + 'for asset:', + 'asset-456' + ) + }) + + it('should handle API errors gracefully', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}) + const onSelectSpy = vi.fn() + const asset = createApiAsset({ id: 'asset-789' }) + + const apiError = new Error('API Error') + vi.mocked(assetService.getAssetDetails).mockRejectedValue(apiError) + + const { selectAssetWithCallback } = useAssetBrowser([asset]) + + await selectAssetWithCallback(asset.id, onSelectSpy) + + expect(assetService.getAssetDetails).toHaveBeenCalledWith('asset-789') + expect(onSelectSpy).not.toHaveBeenCalled() + expect(consoleErrorSpy).toHaveBeenCalledWith( + expect.stringContaining('Failed to fetch asset details for asset-789'), + apiError + ) + }) + + it('should not fetch details when no callback provided', async () => { + const asset = createApiAsset({ id: 'asset-no-callback' }) + + const { selectAssetWithCallback } = useAssetBrowser([asset]) + + await selectAssetWithCallback(asset.id) + + expect(assetService.getAssetDetails).not.toHaveBeenCalled() + }) + }) + + describe('Filename Validation Security', () => { + const createValidationTest = (filename: string) => { + const testAsset = createApiAsset({ id: 'validation-test' }) + const detailAsset = createApiAsset({ + id: 'validation-test', + user_metadata: { filename } + }) + return { testAsset, detailAsset } + } + + it('accepts valid file paths with forward slashes', async () => { + const onSelectSpy = vi.fn() + const { testAsset, detailAsset } = createValidationTest( + 'models/checkpoints/v1/test-model.safetensors' + ) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) + + const { selectAssetWithCallback } = useAssetBrowser([testAsset]) + await selectAssetWithCallback(testAsset.id, onSelectSpy) + + expect(onSelectSpy).toHaveBeenCalledWith( + 'models/checkpoints/v1/test-model.safetensors' + ) + }) + + it('rejects directory traversal attacks', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}) + const onSelectSpy = vi.fn() + + const maliciousPaths = [ + '../malicious-model.safetensors', + 'models/../../../etc/passwd', + '/etc/passwd' + ] + + for (const path of maliciousPaths) { + const { testAsset, detailAsset } = createValidationTest(path) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) + + const { selectAssetWithCallback } = useAssetBrowser([testAsset]) + await selectAssetWithCallback(testAsset.id, onSelectSpy) + + expect(onSelectSpy).not.toHaveBeenCalled() + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Invalid asset filename:', + expect.arrayContaining([ + expect.objectContaining({ + message: 'Path must not start with / or contain ..' + }) + ]), + 'for asset:', + 'validation-test' + ) + } + }) + + it('rejects invalid filename characters', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}) + const onSelectSpy = vi.fn() + + const invalidChars = ['\\', ':', '*', '?', '"', '<', '>', '|'] + + for (const char of invalidChars) { + const { testAsset, detailAsset } = createValidationTest( + `bad${char}filename.safetensors` + ) + vi.mocked(assetService.getAssetDetails).mockResolvedValue(detailAsset) + + const { selectAssetWithCallback } = useAssetBrowser([testAsset]) + await selectAssetWithCallback(testAsset.id, onSelectSpy) + + expect(onSelectSpy).not.toHaveBeenCalled() + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Invalid asset filename:', + expect.arrayContaining([ + expect.objectContaining({ + message: 'Invalid filename characters' + }) + ]), + 'for asset:', + 'validation-test' + ) + } }) }) diff --git a/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts b/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts index fefeeceac..102aa7a18 100644 --- a/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts +++ b/tests-ui/platform/assets/composables/useAssetBrowserDialog.test.ts @@ -6,11 +6,18 @@ import { useDialogStore } from '@/stores/dialogStore' // Mock the dialog store vi.mock('@/stores/dialogStore') +// Mock the asset service +vi.mock('@/platform/assets/services/assetService', () => ({ + assetService: { + getAssetsForNodeType: vi.fn().mockResolvedValue([]) + } +})) + // Test factory functions interface AssetBrowserProps { nodeType: string inputName: string - onAssetSelected?: ReturnType + onAssetSelected?: (filename: string) => void } function createAssetBrowserProps( @@ -25,7 +32,7 @@ function createAssetBrowserProps( describe('useAssetBrowserDialog', () => { describe('Asset Selection Flow', () => { - it('auto-closes dialog when asset is selected', () => { + it('auto-closes dialog when asset is selected', async () => { // Create fresh mocks for this test const mockShowDialog = vi.fn() const mockCloseDialog = vi.fn() @@ -41,7 +48,7 @@ describe('useAssetBrowserDialog', () => { const onAssetSelected = vi.fn() const props = createAssetBrowserProps({ onAssetSelected }) - assetBrowserDialog.show(props) + await assetBrowserDialog.show(props) // Get the onSelect handler that was passed to the dialog const dialogCall = mockShowDialog.mock.calls[0][0] @@ -50,14 +57,14 @@ describe('useAssetBrowserDialog', () => { // Simulate asset selection onSelectHandler('selected-asset-path') - // Should call the original callback and close dialog + // Should call the original callback and trigger hide animation expect(onAssetSelected).toHaveBeenCalledWith('selected-asset-path') expect(mockCloseDialog).toHaveBeenCalledWith({ key: 'global-asset-browser' }) }) - it('closes dialog when close handler is called', () => { + it('closes dialog when close handler is called', async () => { // Create fresh mocks for this test const mockShowDialog = vi.fn() const mockCloseDialog = vi.fn() @@ -72,7 +79,7 @@ describe('useAssetBrowserDialog', () => { const assetBrowserDialog = useAssetBrowserDialog() const props = createAssetBrowserProps() - assetBrowserDialog.show(props) + await assetBrowserDialog.show(props) // Get the onClose handler that was passed to the dialog const dialogCall = mockShowDialog.mock.calls[0][0] diff --git a/tests-ui/tests/composables/node/useNodePricing.test.ts b/tests-ui/tests/composables/node/useNodePricing.test.ts index 6cd76cb75..32b18ed68 100644 --- a/tests-ui/tests/composables/node/useNodePricing.test.ts +++ b/tests-ui/tests/composables/node/useNodePricing.test.ts @@ -1894,4 +1894,159 @@ describe('useNodePricing', () => { expect(getNodeDisplayPrice(missingDuration)).toBe('Token-based') }) }) + + describe('dynamic pricing - WanTextToVideoApi', () => { + it('should return $1.50 for 10s at 1080p', () => { + const { getNodeDisplayPrice } = useNodePricing() + const node = createMockNode('WanTextToVideoApi', [ + { name: 'duration', value: '10' }, + { name: 'size', value: '1080p: 4:3 (1632x1248)' } + ]) + + const price = getNodeDisplayPrice(node) + expect(price).toBe('$1.50/Run') // 0.15 * 10 + }) + + it('should return $0.50 for 5s at 720p', () => { + const { getNodeDisplayPrice } = useNodePricing() + const node = createMockNode('WanTextToVideoApi', [ + { name: 'duration', value: 5 }, + { name: 'size', value: '720p: 16:9 (1280x720)' } + ]) + + const price = getNodeDisplayPrice(node) + expect(price).toBe('$0.50/Run') // 0.10 * 5 + }) + + it('should return $0.15 for 3s at 480p', () => { + const { getNodeDisplayPrice } = useNodePricing() + const node = createMockNode('WanTextToVideoApi', [ + { name: 'duration', value: '3' }, + { name: 'size', value: '480p: 1:1 (624x624)' } + ]) + + const price = getNodeDisplayPrice(node) + expect(price).toBe('$0.15/Run') // 0.05 * 3 + }) + + it('should fall back when widgets are missing', () => { + const { getNodeDisplayPrice } = useNodePricing() + const missingBoth = createMockNode('WanTextToVideoApi', []) + const missingSize = createMockNode('WanTextToVideoApi', [ + { name: 'duration', value: '5' } + ]) + const missingDuration = createMockNode('WanTextToVideoApi', [ + { name: 'size', value: '1080p' } + ]) + + expect(getNodeDisplayPrice(missingBoth)).toBe('$0.05-0.15/second') + expect(getNodeDisplayPrice(missingSize)).toBe('$0.05-0.15/second') + expect(getNodeDisplayPrice(missingDuration)).toBe('$0.05-0.15/second') + }) + + it('should fall back on invalid duration', () => { + const { getNodeDisplayPrice } = useNodePricing() + const node = createMockNode('WanTextToVideoApi', [ + { name: 'duration', value: 'invalid' }, + { name: 'size', value: '1080p' } + ]) + + const price = getNodeDisplayPrice(node) + expect(price).toBe('$0.05-0.15/second') + }) + + it('should fall back on unknown resolution', () => { + const { getNodeDisplayPrice } = useNodePricing() + const node = createMockNode('WanTextToVideoApi', [ + { name: 'duration', value: '10' }, + { name: 'size', value: '2K' } + ]) + + const price = getNodeDisplayPrice(node) + expect(price).toBe('$0.05-0.15/second') + }) + }) + + describe('dynamic pricing - WanImageToVideoApi', () => { + it('should return $0.80 for 8s at 720p', () => { + const { getNodeDisplayPrice } = useNodePricing() + const node = createMockNode('WanImageToVideoApi', [ + { name: 'duration', value: 8 }, + { name: 'resolution', value: '720p' } + ]) + + const price = getNodeDisplayPrice(node) + expect(price).toBe('$0.80/Run') // 0.10 * 8 + }) + + it('should return $0.60 for 12s at 480P', () => { + const { getNodeDisplayPrice } = useNodePricing() + const node = createMockNode('WanImageToVideoApi', [ + { name: 'duration', value: '12' }, + { name: 'resolution', value: '480P' } + ]) + + const price = getNodeDisplayPrice(node) + expect(price).toBe('$0.60/Run') // 0.05 * 12 + }) + + it('should return $1.50 for 10s at 1080p', () => { + const { getNodeDisplayPrice } = useNodePricing() + const node = createMockNode('WanImageToVideoApi', [ + { name: 'duration', value: '10' }, + { name: 'resolution', value: '1080p' } + ]) + + const price = getNodeDisplayPrice(node) + expect(price).toBe('$1.50/Run') // 0.15 * 10 + }) + + it('should handle "5s" string duration at 1080P', () => { + const { getNodeDisplayPrice } = useNodePricing() + const node = createMockNode('WanImageToVideoApi', [ + { name: 'duration', value: '5s' }, + { name: 'resolution', value: '1080P' } + ]) + + const price = getNodeDisplayPrice(node) + expect(price).toBe('$0.75/Run') // 0.15 * 5 + }) + + it('should fall back when widgets are missing', () => { + const { getNodeDisplayPrice } = useNodePricing() + const missingBoth = createMockNode('WanImageToVideoApi', []) + const missingRes = createMockNode('WanImageToVideoApi', [ + { name: 'duration', value: '5' } + ]) + const missingDuration = createMockNode('WanImageToVideoApi', [ + { name: 'resolution', value: '1080p' } + ]) + + expect(getNodeDisplayPrice(missingBoth)).toBe('$0.05-0.15/second') + expect(getNodeDisplayPrice(missingRes)).toBe('$0.05-0.15/second') + expect(getNodeDisplayPrice(missingDuration)).toBe('$0.05-0.15/second') + }) + + it('should fall back on invalid duration', () => { + const { getNodeDisplayPrice } = useNodePricing() + const node = createMockNode('WanImageToVideoApi', [ + { name: 'duration', value: 'invalid' }, + { name: 'resolution', value: '720p' } + ]) + + const price = getNodeDisplayPrice(node) + expect(price).toBe('$0.05-0.15/second') + }) + + it('should fall back on unknown resolution', () => { + const { getNodeDisplayPrice } = useNodePricing() + const node = createMockNode('WanImageToVideoApi', [ + { name: 'duration', value: '10' }, + { name: 'resolution', value: 'weird-res' } + ]) + + const price = getNodeDisplayPrice(node) + expect(price).toBe('$0.05-0.15/second') + }) + }) }) diff --git a/tests-ui/tests/performance/transformPerformance.test.ts b/tests-ui/tests/performance/transformPerformance.test.ts index e9f995e97..1f2fb83f7 100644 --- a/tests-ui/tests/performance/transformPerformance.test.ts +++ b/tests-ui/tests/performance/transformPerformance.test.ts @@ -14,7 +14,7 @@ const createMockCanvasContext = () => ({ const isCI = Boolean(process.env.CI) const describeIfNotCI = isCI ? describe.skip : describe -describeIfNotCI('Transform Performance', () => { +describeIfNotCI.skip('Transform Performance', () => { let transformState: ReturnType let mockCanvas: any diff --git a/tests-ui/tests/renderer/extensions/vueNodes/components/LGraphNode.spec.ts b/tests-ui/tests/renderer/extensions/vueNodes/components/LGraphNode.spec.ts index 07c0a3081..0615e9b9a 100644 --- a/tests-ui/tests/renderer/extensions/vueNodes/components/LGraphNode.spec.ts +++ b/tests-ui/tests/renderer/extensions/vueNodes/components/LGraphNode.spec.ts @@ -1,14 +1,38 @@ import { createTestingPinia } from '@pinia/testing' import { mount } from '@vue/test-utils' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { computed, ref } from 'vue' +import { computed, toValue } from 'vue' +import type { ComponentProps } from 'vue-component-type-helpers' import { createI18n } from 'vue-i18n' import type { VueNodeData } from '@/composables/graph/useGraphNodeManager' -import { SelectedNodeIdsKey } from '@/renderer/core/canvas/injectionKeys' import LGraphNode from '@/renderer/extensions/vueNodes/components/LGraphNode.vue' +import { useNodeEventHandlers } from '@/renderer/extensions/vueNodes/composables/useNodeEventHandlers' import { useVueElementTracking } from '@/renderer/extensions/vueNodes/composables/useVueNodeResizeTracking' -import { useNodeExecutionState } from '@/renderer/extensions/vueNodes/execution/useNodeExecutionState' + +const mockData = vi.hoisted(() => ({ + mockNodeIds: new Set(), + mockExecuting: false +})) + +vi.mock('@/renderer/core/canvas/canvasStore', () => { + const getCanvas = vi.fn() + const useCanvasStore = () => ({ + getCanvas, + selectedNodeIds: computed(() => mockData.mockNodeIds) + }) + return { + useCanvasStore + } +}) + +vi.mock( + '@/renderer/extensions/vueNodes/composables/useNodeEventHandlers', + () => { + const handleNodeSelect = vi.fn() + return { useNodeEventHandlers: () => ({ handleNodeSelect }) } + } +) vi.mock( '@/renderer/extensions/vueNodes/composables/useVueNodeResizeTracking', @@ -47,7 +71,7 @@ vi.mock( '@/renderer/extensions/vueNodes/execution/useNodeExecutionState', () => ({ useNodeExecutionState: vi.fn(() => ({ - executing: computed(() => false), + executing: computed(() => mockData.mockExecuting), progress: computed(() => undefined), progressPercentage: computed(() => undefined), progressState: computed(() => undefined as any), @@ -72,61 +96,56 @@ const i18n = createI18n({ } } }) +function mountLGraphNode(props: ComponentProps) { + return mount(LGraphNode, { + props, + global: { + plugins: [ + createTestingPinia({ + createSpy: vi.fn + }), + i18n + ], + stubs: { + NodeHeader: true, + NodeSlots: true, + NodeWidgets: true, + NodeContent: true, + SlotConnectionDot: true + } + } + }) +} +const mockNodeData: VueNodeData = { + id: 'test-node-123', + title: 'Test Node', + type: 'TestNode', + mode: 0, + flags: {}, + inputs: [], + outputs: [], + widgets: [], + selected: false, + executing: false +} describe('LGraphNode', () => { - const mockNodeData: VueNodeData = { - id: 'test-node-123', - title: 'Test Node', - type: 'TestNode', - mode: 0, - flags: {}, - inputs: [], - outputs: [], - widgets: [], - selected: false, - executing: false - } - - const mountLGraphNode = (props: any, selectedNodeIds = new Set()) => { - return mount(LGraphNode, { - props, - global: { - plugins: [ - createTestingPinia({ - createSpy: vi.fn - }), - i18n - ], - provide: { - [SelectedNodeIdsKey as symbol]: ref(selectedNodeIds) - }, - stubs: { - NodeHeader: true, - NodeSlots: true, - NodeWidgets: true, - NodeContent: true, - SlotConnectionDot: true - } - } - }) - } - beforeEach(() => { - vi.clearAllMocks() - // Reset to default mock - vi.mocked(useNodeExecutionState).mockReturnValue({ - executing: computed(() => false), - progress: computed(() => undefined), - progressPercentage: computed(() => undefined), - progressState: computed(() => undefined as any), - executionState: computed(() => 'idle' as const) - }) + vi.resetAllMocks() + mockData.mockNodeIds = new Set() + mockData.mockExecuting = false }) it('should call resize tracking composable with node ID', () => { mountLGraphNode({ nodeData: mockNodeData }) - expect(useVueElementTracking).toHaveBeenCalledWith('test-node-123', 'node') + expect(useVueElementTracking).toHaveBeenCalledWith( + expect.any(Function), + 'node' + ) + const idArg = vi.mocked(useVueElementTracking).mock.calls[0]?.[0] + const id = toValue(idArg) + expect(id).toEqual('test-node-123') }) it('should render with data-node-id attribute', () => { @@ -146,9 +165,6 @@ describe('LGraphNode', () => { }), i18n ], - provide: { - [SelectedNodeIdsKey as symbol]: ref(new Set()) - }, stubs: { NodeSlots: true, NodeWidgets: true, @@ -162,24 +178,15 @@ describe('LGraphNode', () => { }) it('should apply selected styling when selected prop is true', () => { - const wrapper = mountLGraphNode( - { nodeData: mockNodeData, selected: true }, - new Set(['test-node-123']) - ) + mockData.mockNodeIds = new Set(['test-node-123']) + const wrapper = mountLGraphNode({ nodeData: mockNodeData }) expect(wrapper.classes()).toContain('outline-2') expect(wrapper.classes()).toContain('outline-black') expect(wrapper.classes()).toContain('dark-theme:outline-white') }) it('should apply executing animation when executing prop is true', () => { - // Mock the execution state to return executing: true - vi.mocked(useNodeExecutionState).mockReturnValue({ - executing: computed(() => true), - progress: computed(() => undefined), - progressPercentage: computed(() => undefined), - progressState: computed(() => undefined as any), - executionState: computed(() => 'running' as const) - }) + mockData.mockExecuting = true const wrapper = mountLGraphNode({ nodeData: mockNodeData }) @@ -187,12 +194,16 @@ describe('LGraphNode', () => { }) it('should emit node-click event on pointer up', async () => { + const { handleNodeSelect } = useNodeEventHandlers() const wrapper = mountLGraphNode({ nodeData: mockNodeData }) await wrapper.trigger('pointerup') - expect(wrapper.emitted('node-click')).toHaveLength(1) - expect(wrapper.emitted('node-click')?.[0]).toHaveLength(3) - expect(wrapper.emitted('node-click')?.[0][1]).toEqual(mockNodeData) + expect(handleNodeSelect).toHaveBeenCalledOnce() + expect(handleNodeSelect).toHaveBeenCalledWith( + expect.any(PointerEvent), + mockNodeData, + expect.any(Boolean) + ) }) }) diff --git a/tests-ui/tests/renderer/extensions/vueNodes/composables/useNodeEventHandlers.test.ts b/tests-ui/tests/renderer/extensions/vueNodes/composables/useNodeEventHandlers.test.ts index e2a4bd920..dd08457eb 100644 --- a/tests-ui/tests/renderer/extensions/vueNodes/composables/useNodeEventHandlers.test.ts +++ b/tests-ui/tests/renderer/extensions/vueNodes/composables/useNodeEventHandlers.test.ts @@ -1,98 +1,82 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' -import { computed, ref } from 'vue' +import { computed, shallowRef } from 'vue' -import type { VueNodeData } from '@/composables/graph/useGraphNodeManager' -import type { useGraphNodeManager } from '@/composables/graph/useGraphNodeManager' -import type { LGraphCanvas, LGraphNode } from '@/lib/litegraph/src/litegraph' +import { + type GraphNodeManager, + type VueNodeData, + useGraphNodeManager +} from '@/composables/graph/useGraphNodeManager' +import { useVueNodeLifecycle } from '@/composables/graph/useVueNodeLifecycle' +import type { + LGraph, + LGraphCanvas, + LGraphNode +} from '@/lib/litegraph/src/litegraph' import { useCanvasStore } from '@/renderer/core/canvas/canvasStore' -import { useCanvasInteractions } from '@/renderer/core/canvas/useCanvasInteractions' import { useLayoutMutations } from '@/renderer/core/layout/operations/layoutMutations' import { useNodeEventHandlers } from '@/renderer/extensions/vueNodes/composables/useNodeEventHandlers' -vi.mock('@/renderer/core/canvas/canvasStore', () => ({ - useCanvasStore: vi.fn() -})) - -vi.mock('@/renderer/core/canvas/useCanvasInteractions', () => ({ - useCanvasInteractions: vi.fn() -})) - -vi.mock('@/renderer/core/layout/operations/layoutMutations', () => ({ - useLayoutMutations: vi.fn() -})) - -vi.mock('@/composables/graph/useGraphNodeManager', () => ({ - useGraphNodeManager: vi.fn() -})) - -function createMockCanvas(): Pick< - LGraphCanvas, - 'select' | 'deselect' | 'deselectAll' -> { - return { +vi.mock('@/renderer/core/canvas/canvasStore', () => { + const canvas: Partial = { select: vi.fn(), deselect: vi.fn(), deselectAll: vi.fn() } -} - -function createMockNode(): Pick { + const updateSelectedItems = vi.fn() return { + useCanvasStore: vi.fn(() => ({ + canvas: canvas as LGraphCanvas, + updateSelectedItems, + selectedItems: [] + })) + } +}) + +vi.mock('@/renderer/core/canvas/useCanvasInteractions', () => ({ + useCanvasInteractions: vi.fn(() => ({ + shouldHandleNodePointerEvents: computed(() => true) // Default to allowing pointer events + })) +})) + +vi.mock('@/renderer/core/layout/operations/layoutMutations', () => { + const setSource = vi.fn() + const bringNodeToFront = vi.fn() + return { + useLayoutMutations: vi.fn(() => ({ + setSource, + bringNodeToFront + })) + } +}) + +vi.mock('@/composables/graph/useGraphNodeManager', () => { + const mockNode = { id: 'node-1', selected: false, flags: { pinned: false } } -} - -function createMockNodeManager( - node: Pick -) { + const nodeManager = shallowRef({ + getNode: vi.fn(() => mockNode as Partial as LGraphNode) + } as Partial as GraphNodeManager) return { - getNode: vi.fn().mockReturnValue(node) as ReturnType< - typeof useGraphNodeManager - >['getNode'] + useGraphNodeManager: vi.fn(() => nodeManager) } -} +}) -function createMockCanvasStore( - canvas: Pick -): Pick< - ReturnType, - 'canvas' | 'selectedItems' | 'updateSelectedItems' -> { +vi.mock('@/composables/graph/useVueNodeLifecycle', () => { + const nodeManager = useGraphNodeManager(undefined as unknown as LGraph) return { - canvas: canvas as LGraphCanvas, - selectedItems: [], - updateSelectedItems: vi.fn() + useVueNodeLifecycle: vi.fn(() => ({ + nodeManager + })) } -} - -function createMockLayoutMutations(): Pick< - ReturnType, - 'setSource' | 'bringNodeToFront' -> { - return { - setSource: vi.fn(), - bringNodeToFront: vi.fn() - } -} - -function createMockCanvasInteractions(): Pick< - ReturnType, - 'shouldHandleNodePointerEvents' -> { - return { - shouldHandleNodePointerEvents: computed(() => true) // Default to allowing pointer events - } -} +}) describe('useNodeEventHandlers', () => { - let mockCanvas: ReturnType - let mockNode: ReturnType - let mockNodeManager: ReturnType - let mockCanvasStore: ReturnType - let mockLayoutMutations: ReturnType - let mockCanvasInteractions: ReturnType + const { nodeManager: mockNodeManager } = useVueNodeLifecycle() + + const mockNode = mockNodeManager.value!.getNode('fake_id') + const mockLayoutMutations = useLayoutMutations() const testNodeData: VueNodeData = { id: 'node-1', @@ -104,28 +88,13 @@ describe('useNodeEventHandlers', () => { } beforeEach(async () => { - mockNode = createMockNode() - mockCanvas = createMockCanvas() - mockNodeManager = createMockNodeManager(mockNode) - mockCanvasStore = createMockCanvasStore(mockCanvas) - mockLayoutMutations = createMockLayoutMutations() - mockCanvasInteractions = createMockCanvasInteractions() - - vi.mocked(useCanvasStore).mockReturnValue( - mockCanvasStore as ReturnType - ) - vi.mocked(useLayoutMutations).mockReturnValue( - mockLayoutMutations as ReturnType - ) - vi.mocked(useCanvasInteractions).mockReturnValue( - mockCanvasInteractions as ReturnType - ) + vi.restoreAllMocks() }) describe('handleNodeSelect', () => { it('should select single node on regular click', () => { - const nodeManager = ref(mockNodeManager) - const { handleNodeSelect } = useNodeEventHandlers(nodeManager) + const { handleNodeSelect } = useNodeEventHandlers() + const { canvas, updateSelectedItems } = useCanvasStore() const event = new PointerEvent('pointerdown', { bubbles: true, @@ -135,17 +104,17 @@ describe('useNodeEventHandlers', () => { handleNodeSelect(event, testNodeData, false) - expect(mockCanvas.deselectAll).toHaveBeenCalledOnce() - expect(mockCanvas.select).toHaveBeenCalledWith(mockNode) - expect(mockCanvasStore.updateSelectedItems).toHaveBeenCalledOnce() + expect(canvas?.deselectAll).toHaveBeenCalledOnce() + expect(canvas?.select).toHaveBeenCalledWith(mockNode) + expect(updateSelectedItems).toHaveBeenCalledOnce() }) it('should toggle selection on ctrl+click', () => { - const nodeManager = ref(mockNodeManager) - const { handleNodeSelect } = useNodeEventHandlers(nodeManager) + const { handleNodeSelect } = useNodeEventHandlers() + const { canvas } = useCanvasStore() // Test selecting unselected node with ctrl - mockNode.selected = false + mockNode!.selected = false const ctrlClickEvent = new PointerEvent('pointerdown', { bubbles: true, @@ -155,16 +124,16 @@ describe('useNodeEventHandlers', () => { handleNodeSelect(ctrlClickEvent, testNodeData, false) - expect(mockCanvas.deselectAll).not.toHaveBeenCalled() - expect(mockCanvas.select).toHaveBeenCalledWith(mockNode) + expect(canvas?.deselectAll).not.toHaveBeenCalled() + expect(canvas?.select).toHaveBeenCalledWith(mockNode) }) it('should deselect on ctrl+click of selected node', () => { - const nodeManager = ref(mockNodeManager) - const { handleNodeSelect } = useNodeEventHandlers(nodeManager) + const { handleNodeSelect } = useNodeEventHandlers() + const { canvas } = useCanvasStore() // Test deselecting selected node with ctrl - mockNode.selected = true + mockNode!.selected = true const ctrlClickEvent = new PointerEvent('pointerdown', { bubbles: true, @@ -174,15 +143,15 @@ describe('useNodeEventHandlers', () => { handleNodeSelect(ctrlClickEvent, testNodeData, false) - expect(mockCanvas.deselect).toHaveBeenCalledWith(mockNode) - expect(mockCanvas.select).not.toHaveBeenCalled() + expect(canvas?.deselect).toHaveBeenCalledWith(mockNode) + expect(canvas?.select).not.toHaveBeenCalled() }) it('should handle meta key (Cmd) on Mac', () => { - const nodeManager = ref(mockNodeManager) - const { handleNodeSelect } = useNodeEventHandlers(nodeManager) + const { handleNodeSelect } = useNodeEventHandlers() + const { canvas } = useCanvasStore() - mockNode.selected = false + mockNode!.selected = false const metaClickEvent = new PointerEvent('pointerdown', { bubbles: true, @@ -192,15 +161,14 @@ describe('useNodeEventHandlers', () => { handleNodeSelect(metaClickEvent, testNodeData, false) - expect(mockCanvas.select).toHaveBeenCalledWith(mockNode) - expect(mockCanvas.deselectAll).not.toHaveBeenCalled() + expect(canvas?.select).toHaveBeenCalledWith(mockNode) + expect(canvas?.deselectAll).not.toHaveBeenCalled() }) it('should bring node to front when not pinned', () => { - const nodeManager = ref(mockNodeManager) - const { handleNodeSelect } = useNodeEventHandlers(nodeManager) + const { handleNodeSelect } = useNodeEventHandlers() - mockNode.flags.pinned = false + mockNode!.flags.pinned = false const event = new PointerEvent('pointerdown') handleNodeSelect(event, testNodeData, false) @@ -211,49 +179,14 @@ describe('useNodeEventHandlers', () => { }) it('should not bring pinned node to front', () => { - const nodeManager = ref(mockNodeManager) - const { handleNodeSelect } = useNodeEventHandlers(nodeManager) + const { handleNodeSelect } = useNodeEventHandlers() - mockNode.flags.pinned = true + mockNode!.flags.pinned = true const event = new PointerEvent('pointerdown') handleNodeSelect(event, testNodeData, false) expect(mockLayoutMutations.bringNodeToFront).not.toHaveBeenCalled() }) - - it('should handle missing canvas gracefully', () => { - const nodeManager = ref(mockNodeManager) - const { handleNodeSelect } = useNodeEventHandlers(nodeManager) - - mockCanvasStore.canvas = null - - const event = new PointerEvent('pointerdown') - expect(() => { - handleNodeSelect(event, testNodeData, false) - }).not.toThrow() - - expect(mockCanvas.select).not.toHaveBeenCalled() - }) - - it('should handle missing node gracefully', () => { - const nodeManager = ref(mockNodeManager) - const { handleNodeSelect } = useNodeEventHandlers(nodeManager) - - vi.mocked(mockNodeManager.getNode).mockReturnValue(undefined) - - const event = new PointerEvent('pointerdown') - const nodeData = { - id: 'missing-node', - title: 'Missing Node', - type: 'test' - } as any - - expect(() => { - handleNodeSelect(event, nodeData, false) - }).not.toThrow() - - expect(mockCanvas.select).not.toHaveBeenCalled() - }) }) }) diff --git a/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts b/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts index 875919ccd..d439d98ca 100644 --- a/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts +++ b/tests-ui/tests/renderer/extensions/vueNodes/widgets/composables/useComboWidget.test.ts @@ -2,6 +2,7 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { LGraphNode } from '@/lib/litegraph/src/litegraph' import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets' +import { useAssetBrowserDialog } from '@/platform/assets/composables/useAssetBrowserDialog' import { assetService } from '@/platform/assets/services/assetService' import { useComboWidget } from '@/renderer/extensions/vueNodes/widgets/composables/useComboWidget' import type { InputSpec } from '@/schemas/nodeDef/nodeDefSchemaV2' @@ -29,13 +30,25 @@ vi.mock('@/platform/assets/services/assetService', () => ({ } })) +vi.mock('@/platform/assets/composables/useAssetBrowserDialog', () => { + const mockAssetBrowserDialogShow = vi.fn() + return { + useAssetBrowserDialog: vi.fn(() => ({ + show: mockAssetBrowserDialogShow + })) + } +}) + // Test factory functions function createMockWidget(overrides: Partial = {}): IBaseWidget { + const mockCallback = vi.fn() return { type: 'combo', options: {}, name: 'testWidget', value: undefined, + callback: mockCallback, + y: 0, ...overrides } as IBaseWidget } @@ -45,7 +58,16 @@ function createMockNode(comfyClass = 'TestNode'): LGraphNode { node.comfyClass = comfyClass // Spy on the addWidget method - vi.spyOn(node, 'addWidget').mockReturnValue(createMockWidget()) + vi.spyOn(node, 'addWidget').mockImplementation( + (type, name, value, callback) => { + const widget = createMockWidget({ type, name, value }) + // Store the callback function on the widget for testing + if (typeof callback === 'function') { + widget.callback = callback + } + return widget + } + ) return node } @@ -61,9 +83,9 @@ function createMockInputSpec(overrides: Partial = {}): InputSpec { describe('useComboWidget', () => { beforeEach(() => { vi.clearAllMocks() - // Reset to defaults mockSettingStoreGet.mockReturnValue(false) vi.mocked(assetService.isAssetBrowserEligible).mockReturnValue(false) + vi.mocked(useAssetBrowserDialog).mockClear() }) it('should handle undefined spec', () => { diff --git a/tests-ui/tests/services/assetService.test.ts b/tests-ui/tests/services/assetService.test.ts index d96ef765b..7a719c4e4 100644 --- a/tests-ui/tests/services/assetService.test.ts +++ b/tests-ui/tests/services/assetService.test.ts @@ -4,6 +4,8 @@ import type { AssetItem } from '@/platform/assets/schemas/assetSchema' import { assetService } from '@/platform/assets/services/assetService' import { api } from '@/scripts/api' +const mockGetCategoryForNodeType = vi.fn() + vi.mock('@/stores/modelToNodeStore', () => ({ useModelToNodeStore: vi.fn(() => ({ getRegisteredNodeTypes: vi.fn( @@ -14,7 +16,13 @@ vi.mock('@/stores/modelToNodeStore', () => ({ 'VAELoader', 'TestNode' ]) - ) + ), + getCategoryForNodeType: mockGetCategoryForNodeType, + modelToNodeMap: { + checkpoints: [{ nodeDef: { name: 'CheckpointLoaderSimple' } }], + loras: [{ nodeDef: { name: 'LoraLoader' } }], + vae: [{ nodeDef: { name: 'VAELoader' } }] + } })) })) @@ -210,4 +218,87 @@ describe('assetService', () => { ).toBe(false) }) }) + + describe('getAssetsForNodeType', () => { + beforeEach(() => { + mockGetCategoryForNodeType.mockClear() + }) + + it('should return empty array for unregistered node types', async () => { + mockGetCategoryForNodeType.mockReturnValue(undefined) + + const result = await assetService.getAssetsForNodeType('UnknownNode') + + expect(mockGetCategoryForNodeType).toHaveBeenCalledWith('UnknownNode') + expect(result).toEqual([]) + }) + + it('should use getCategoryForNodeType for efficient category lookup', async () => { + mockGetCategoryForNodeType.mockReturnValue('checkpoints') + const testAssets = [MOCK_ASSETS.checkpoints] + mockApiResponse(testAssets) + + const result = await assetService.getAssetsForNodeType( + 'CheckpointLoaderSimple' + ) + + expect(mockGetCategoryForNodeType).toHaveBeenCalledWith( + 'CheckpointLoaderSimple' + ) + expect(result).toEqual(testAssets) + + // Verify API call includes correct category + expect(api.fetchApi).toHaveBeenCalledWith( + '/assets?include_tags=models,checkpoints' + ) + }) + + it('should return empty array when no category found', async () => { + mockGetCategoryForNodeType.mockReturnValue(undefined) + + const result = await assetService.getAssetsForNodeType('TestNode') + + expect(result).toEqual([]) + expect(api.fetchApi).not.toHaveBeenCalled() + }) + + it('should handle API errors gracefully', async () => { + mockGetCategoryForNodeType.mockReturnValue('loras') + mockApiError(500, 'Internal Server Error') + + await expect( + assetService.getAssetsForNodeType('LoraLoader') + ).rejects.toThrow( + 'Unable to load assets for LoraLoader: Server returned 500. Please try again.' + ) + }) + + it('should return all assets without filtering for different categories', async () => { + // Test checkpoints + mockGetCategoryForNodeType.mockReturnValue('checkpoints') + const checkpointAssets = [MOCK_ASSETS.checkpoints] + mockApiResponse(checkpointAssets) + + let result = await assetService.getAssetsForNodeType( + 'CheckpointLoaderSimple' + ) + expect(result).toEqual(checkpointAssets) + + // Test loras + mockGetCategoryForNodeType.mockReturnValue('loras') + const loraAssets = [MOCK_ASSETS.loras] + mockApiResponse(loraAssets) + + result = await assetService.getAssetsForNodeType('LoraLoader') + expect(result).toEqual(loraAssets) + + // Test vae + mockGetCategoryForNodeType.mockReturnValue('vae') + const vaeAssets = [MOCK_ASSETS.vae] + mockApiResponse(vaeAssets) + + result = await assetService.getAssetsForNodeType('VAELoader') + expect(result).toEqual(vaeAssets) + }) + }) }) diff --git a/tests-ui/tests/store/modelToNodeStore.test.ts b/tests-ui/tests/store/modelToNodeStore.test.ts index 179c76b9e..b07c34a41 100644 --- a/tests-ui/tests/store/modelToNodeStore.test.ts +++ b/tests-ui/tests/store/modelToNodeStore.test.ts @@ -19,7 +19,7 @@ const EXPECTED_DEFAULT_TYPES = [ 'gligen' ] as const -type NodeDefStoreType = typeof import('@/stores/nodeDefStore') +type NodeDefStoreType = ReturnType // Create minimal but valid ComfyNodeDefImpl for testing function createMockNodeDef(name: string): ComfyNodeDefImpl { @@ -343,6 +343,107 @@ describe('useModelToNodeStore', () => { }) }) + describe('getCategoryForNodeType', () => { + it('should return category for known node type', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + expect( + modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple') + ).toBe('checkpoints') + expect(modelToNodeStore.getCategoryForNodeType('LoraLoader')).toBe( + 'loras' + ) + expect(modelToNodeStore.getCategoryForNodeType('VAELoader')).toBe('vae') + }) + + it('should return undefined for unknown node type', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + expect( + modelToNodeStore.getCategoryForNodeType('NonExistentNode') + ).toBeUndefined() + expect(modelToNodeStore.getCategoryForNodeType('')).toBeUndefined() + }) + + it('should return first category when node type exists in multiple categories', () => { + const modelToNodeStore = useModelToNodeStore() + + // Test with a node that exists in the defaults but add our own first + // Since defaults register 'StyleModelLoader' in 'style_models', + // we verify our custom registrations come after defaults in Object.entries iteration + const result = modelToNodeStore.getCategoryForNodeType('StyleModelLoader') + expect(result).toBe('style_models') // This proves the method works correctly + + // Now test that custom registrations after defaults also work + modelToNodeStore.quickRegister( + 'unicorn_styles', + 'StyleModelLoader', + 'param1' + ) + const result2 = + modelToNodeStore.getCategoryForNodeType('StyleModelLoader') + // Should still be style_models since it was registered first by defaults + expect(result2).toBe('style_models') + }) + + it('should trigger lazy registration when called before registerDefaults', () => { + const modelToNodeStore = useModelToNodeStore() + + const result = modelToNodeStore.getCategoryForNodeType( + 'CheckpointLoaderSimple' + ) + expect(result).toBe('checkpoints') + }) + + it('should be performant for repeated lookups', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + // Measure performance without assuming implementation + const start = performance.now() + for (let i = 0; i < 1000; i++) { + modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple') + } + const end = performance.now() + + // Should be fast enough for UI responsiveness + expect(end - start).toBeLessThan(10) + }) + + it('should handle invalid input types gracefully', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + // These should not throw but return undefined + expect( + modelToNodeStore.getCategoryForNodeType(null as any) + ).toBeUndefined() + expect( + modelToNodeStore.getCategoryForNodeType(undefined as any) + ).toBeUndefined() + expect( + modelToNodeStore.getCategoryForNodeType(123 as any) + ).toBeUndefined() + }) + + it('should be case-sensitive for node type matching', () => { + const modelToNodeStore = useModelToNodeStore() + modelToNodeStore.registerDefaults() + + expect( + modelToNodeStore.getCategoryForNodeType('checkpointloadersimple') + ).toBeUndefined() + expect( + modelToNodeStore.getCategoryForNodeType('CHECKPOINTLOADERSIMPLE') + ).toBeUndefined() + expect( + modelToNodeStore.getCategoryForNodeType('CheckpointLoaderSimple') + ).toBe('checkpoints') + }) + }) + describe('edge cases', () => { it('should handle empty string model type', () => { const modelToNodeStore = useModelToNodeStore() diff --git a/tsconfig.json b/tsconfig.json index de5cb4def..97346f56e 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -29,14 +29,15 @@ "rootDir": "./" }, "include": [ - "src/**/*", + ".storybook/**/*", + "eslint.config.ts", + "global.d.ts", + "knip.config.ts", "src/**/*.vue", + "src/**/*", "src/types/**/*.d.ts", "tests-ui/**/*", - "global.d.ts", - "eslint.config.ts", "vite.config.mts", - "knip.config.ts", - ".storybook/**/*" + "vitest.config.ts", ] } diff --git a/vite.config.mts b/vite.config.mts index 0ca062273..25cd730aa 100644 --- a/vite.config.mts +++ b/vite.config.mts @@ -31,7 +31,8 @@ export default defineConfig({ ignored: [ '**/coverage/**', '**/playwright-report/**', - '**/*.{test,spec}.ts' + '**/*.{test,spec}.ts', + '*.config.{ts,mts}' ] }, proxy: { diff --git a/vitest.config.ts b/vitest.config.ts index 36fdb1a00..02565497e 100644 --- a/vitest.config.ts +++ b/vitest.config.ts @@ -32,7 +32,8 @@ export default defineConfig({ '**/.{idea,git,cache,output,temp}/**', '**/{karma,rollup,webpack,vite,vitest,jest,ava,babel,nyc,cypress,tsup,build,eslint,prettier}.config.*', 'src/lib/litegraph/test/**' - ] + ], + silent: 'passed-only' }, resolve: { alias: {