diff --git a/src/composables/canvas/useSelectedLiteGraphItems.ts b/src/composables/canvas/useSelectedLiteGraphItems.ts index 3d44fe089..32c29a1b2 100644 --- a/src/composables/canvas/useSelectedLiteGraphItems.ts +++ b/src/composables/canvas/useSelectedLiteGraphItems.ts @@ -1,6 +1,16 @@ -import { Positionable, Reroute } from '@comfyorg/litegraph' +import { + LGraphEventMode, + LGraphNode, + Positionable, + Reroute +} from '@comfyorg/litegraph' +import { app } from '@/scripts/app' import { useCanvasStore } from '@/stores/graphStore' +import { + collectFromNodes, + traverseNodesDepthFirst +} from '@/utils/graphTraversalUtil' /** * Composable for handling selected LiteGraph items filtering and operations. @@ -61,11 +71,92 @@ export function useSelectedLiteGraphItems() { return getSelectableItems().size > 1 } + /** + * Get only the selected nodes (LGraphNode instances) from the canvas. + * This filters out other types of selected items like groups or reroutes. + * If a selected node is a subgraph, this also includes all nodes within it. + * @returns Array of selected LGraphNode instances and their descendants. + */ + const getSelectedNodes = (): LGraphNode[] => { + const selectedNodes = app.canvas.selected_nodes + if (!selectedNodes) return [] + + // Convert selected_nodes object to array, preserving order + const nodeArray: LGraphNode[] = [] + for (const i in selectedNodes) { + nodeArray.push(selectedNodes[i]) + } + + // Check if any selected nodes are subgraphs + const hasSubgraphs = nodeArray.some( + (node) => node.isSubgraphNode?.() && node.subgraph + ) + + // If no subgraphs, just return the array directly to preserve order + if (!hasSubgraphs) { + return nodeArray + } + + // Use collectFromNodes to get all nodes including those in subgraphs + return collectFromNodes(nodeArray) + } + + /** + * Toggle the execution mode of all selected nodes with unified subgraph behavior. + * + * Top-level behavior (selected nodes): Standard toggle logic + * - If the selected node is already in the specified mode → set to ALWAYS + * - Otherwise → set to the specified mode + * + * Subgraph behavior (children of selected subgraph nodes): Unified state application + * - All children inherit the same mode that their parent subgraph node was set to + * - This creates predictable behavior: if you toggle a subgraph to "mute", + * ALL nodes inside become muted, regardless of their previous individual states + * + * @param mode - The LGraphEventMode to toggle to (e.g., NEVER for mute, BYPASS for bypass) + */ + const toggleSelectedNodesMode = (mode: LGraphEventMode): void => { + const selectedNodes = app.canvas.selected_nodes + if (!selectedNodes) return + + // Convert selected_nodes object to array + const selectedNodeArray: LGraphNode[] = [] + for (const i in selectedNodes) { + selectedNodeArray.push(selectedNodes[i]) + } + + // Process each selected node independently to determine its target state and apply to children + selectedNodeArray.forEach((selectedNode) => { + // Apply standard toggle logic to the selected node itself + const newModeForSelectedNode = + selectedNode.mode === mode ? LGraphEventMode.ALWAYS : mode + + selectedNode.mode = newModeForSelectedNode + + // If this selected node is a subgraph, apply the same mode uniformly to all its children + // This ensures predictable behavior: all children get the same state as their parent + if (selectedNode.isSubgraphNode?.() && selectedNode.subgraph) { + traverseNodesDepthFirst([selectedNode], { + visitor: (node) => { + // Skip the parent node since we already handled it above + if (node === selectedNode) return undefined + + // Apply the parent's new mode to all children uniformly + node.mode = newModeForSelectedNode + return undefined + } + }) + } + }) + } + return { isIgnoredItem, filterSelectableItems, getSelectableItems, hasSelectableItems, - hasMultipleSelectableItems + hasMultipleSelectableItems, + getSelectedNodes, + toggleSelectedNodesMode } } diff --git a/src/composables/useCoreCommands.ts b/src/composables/useCoreCommands.ts index d18e96f00..26b795b62 100644 --- a/src/composables/useCoreCommands.ts +++ b/src/composables/useCoreCommands.ts @@ -7,6 +7,7 @@ import { import { Point } from '@comfyorg/litegraph' import { useFirebaseAuthActions } from '@/composables/auth/useFirebaseAuthActions' +import { useSelectedLiteGraphItems } from '@/composables/canvas/useSelectedLiteGraphItems' import { DEFAULT_DARK_COLOR_PALETTE, DEFAULT_LIGHT_COLOR_PALETTE @@ -46,30 +47,10 @@ export function useCoreCommands(): ComfyCommand[] { const toastStore = useToastStore() const canvasStore = useCanvasStore() const executionStore = useExecutionStore() + const { getSelectedNodes, toggleSelectedNodesMode } = + useSelectedLiteGraphItems() const getTracker = () => workflowStore.activeWorkflow?.changeTracker - const getSelectedNodes = (): LGraphNode[] => { - const selectedNodes = app.canvas.selected_nodes - const result: LGraphNode[] = [] - if (selectedNodes) { - for (const i in selectedNodes) { - const node = selectedNodes[i] - result.push(node) - } - } - return result - } - - const toggleSelectedNodesMode = (mode: LGraphEventMode) => { - getSelectedNodes().forEach((node) => { - if (node.mode === mode) { - node.mode = LGraphEventMode.ALWAYS - } else { - node.mode = mode - } - }) - } - const moveSelectedNodes = ( positionUpdater: (pos: Point, gridSize: number) => Point ) => { diff --git a/src/utils/graphTraversalUtil.ts b/src/utils/graphTraversalUtil.ts index 3adb87726..db8c57607 100644 --- a/src/utils/graphTraversalUtil.ts +++ b/src/utils/graphTraversalUtil.ts @@ -388,20 +388,34 @@ export function getAllNonIoNodesInSubgraph(subgraph: Subgraph): LGraphNode[] { return subgraph.nodes.filter((node) => !isSubgraphIoNode(node)) } +/** + * Options for traverseNodesDepthFirst function + */ +export interface TraverseNodesOptions { + /** Function called for each node during traversal */ + visitor?: (node: LGraphNode, context: T) => T + /** Initial context value */ + initialContext?: T + /** Whether to traverse into subgraph nodes (default: true) */ + expandSubgraphs?: boolean +} + /** * Performs depth-first traversal of nodes and their subgraphs. * Generic visitor pattern that can be used for various node processing tasks. * * @param nodes - Starting nodes for traversal - * @param visitor - Function called for each node with its context - * @param expandSubgraphs - Whether to traverse into subgraph nodes (default: true) + * @param options - Optional traversal configuration */ -export function traverseNodesDepthFirst( +export function traverseNodesDepthFirst( nodes: LGraphNode[], - visitor: (node: LGraphNode, context: T) => T, - initialContext: T, - expandSubgraphs: boolean = true + options?: TraverseNodesOptions ): void { + const { + visitor = () => undefined as T, + initialContext = undefined as T, + expandSubgraphs = true + } = options || {} type StackItem = { node: LGraphNode; context: T } const stack: StackItem[] = [] @@ -429,28 +443,42 @@ export function traverseNodesDepthFirst( } } +/** + * Options for collectFromNodes function + */ +export interface CollectFromNodesOptions { + /** Function that returns data to collect for each node */ + collector?: (node: LGraphNode, context: C) => T | null + /** Function that builds context for child nodes */ + contextBuilder?: (node: LGraphNode, parentContext: C) => C + /** Initial context value */ + initialContext?: C + /** Whether to traverse into subgraph nodes (default: true) */ + expandSubgraphs?: boolean +} + /** * Collects nodes with custom data during depth-first traversal. * Generic collector that can gather any type of data per node. * * @param nodes - Starting nodes for traversal - * @param collector - Function that returns data to collect for each node - * @param contextBuilder - Function that builds context for child nodes - * @param expandSubgraphs - Whether to traverse into subgraph nodes + * @param options - Optional collection configuration * @returns Array of collected data */ -export function collectFromNodes( +export function collectFromNodes( nodes: LGraphNode[], - collector: (node: LGraphNode, context: C) => T | null, - contextBuilder: (node: LGraphNode, parentContext: C) => C, - initialContext: C, - expandSubgraphs: boolean = true + options?: CollectFromNodesOptions ): T[] { + const { + collector = (node: LGraphNode) => node as unknown as T, + contextBuilder = () => undefined as C, + initialContext = undefined as C, + expandSubgraphs = true + } = options || {} const results: T[] = [] - traverseNodesDepthFirst( - nodes, - (node, context) => { + traverseNodesDepthFirst(nodes, { + visitor: (node, context) => { const data = collector(node, context) if (data !== null) { results.push(data) @@ -459,7 +487,7 @@ export function collectFromNodes( }, initialContext, expandSubgraphs - ) + }) return results } @@ -474,19 +502,16 @@ export function collectFromNodes( export function getExecutionIdsForSelectedNodes( selectedNodes: LGraphNode[] ): NodeExecutionId[] { - return collectFromNodes( - selectedNodes, - // Collector: build execution ID for each node - (node, parentExecutionId: string): NodeExecutionId => { + return collectFromNodes(selectedNodes, { + collector: (node, parentExecutionId) => { const nodeId = String(node.id) return parentExecutionId ? `${parentExecutionId}:${nodeId}` : nodeId }, - // Context builder: pass execution ID to children - (node, parentExecutionId: string) => { + contextBuilder: (node, parentExecutionId) => { const nodeId = String(node.id) return parentExecutionId ? `${parentExecutionId}:${nodeId}` : nodeId }, - '', // Initial context: empty parent execution ID - true // Expand subgraphs - ) + initialContext: '', + expandSubgraphs: true + }) } diff --git a/tests-ui/tests/composables/canvas/useSelectedLiteGraphItems.test.ts b/tests-ui/tests/composables/canvas/useSelectedLiteGraphItems.test.ts index 4bbbc906b..f42c9c8ed 100644 --- a/tests-ui/tests/composables/canvas/useSelectedLiteGraphItems.test.ts +++ b/tests-ui/tests/composables/canvas/useSelectedLiteGraphItems.test.ts @@ -1,14 +1,34 @@ -import { Positionable, Reroute } from '@comfyorg/litegraph' +import { + LGraphEventMode, + LGraphNode, + Positionable, + Reroute +} from '@comfyorg/litegraph' import { createPinia, setActivePinia } from 'pinia' import { beforeEach, describe, expect, it, vi } from 'vitest' import { useSelectedLiteGraphItems } from '@/composables/canvas/useSelectedLiteGraphItems' +import { app } from '@/scripts/app' import { useCanvasStore } from '@/stores/graphStore' +// Mock the app module +vi.mock('@/scripts/app', () => ({ + app: { + canvas: { + selected_nodes: null + } + } +})) + // Mock the litegraph module vi.mock('@comfyorg/litegraph', () => ({ Reroute: class Reroute { constructor() {} + }, + LGraphEventMode: { + ALWAYS: 0, + NEVER: 2, + BYPASS: 4 } })) @@ -181,6 +201,142 @@ describe('useSelectedLiteGraphItems', () => { }) }) + describe('node-specific methods', () => { + it('getSelectedNodes should return only LGraphNode instances', () => { + const { getSelectedNodes } = useSelectedLiteGraphItems() + const node1 = { id: 1, mode: LGraphEventMode.ALWAYS } as LGraphNode + const node2 = { id: 2, mode: LGraphEventMode.NEVER } as LGraphNode + + // Mock app.canvas.selected_nodes + app.canvas.selected_nodes = { '0': node1, '1': node2 } + + const selectedNodes = getSelectedNodes() + expect(selectedNodes).toHaveLength(2) + expect(selectedNodes[0]).toBe(node1) + expect(selectedNodes[1]).toBe(node2) + }) + + it('getSelectedNodes should return empty array when no nodes selected', () => { + const { getSelectedNodes } = useSelectedLiteGraphItems() + + // @ts-expect-error - Testing null case + app.canvas.selected_nodes = null + + const selectedNodes = getSelectedNodes() + expect(selectedNodes).toHaveLength(0) + }) + + it('toggleSelectedNodesMode should toggle node modes correctly', () => { + const { toggleSelectedNodesMode } = useSelectedLiteGraphItems() + const node1 = { id: 1, mode: LGraphEventMode.ALWAYS } as LGraphNode + const node2 = { id: 2, mode: LGraphEventMode.NEVER } as LGraphNode + + app.canvas.selected_nodes = { '0': node1, '1': node2 } + + // Toggle to NEVER mode + toggleSelectedNodesMode(LGraphEventMode.NEVER) + + // node1 should change from ALWAYS to NEVER + // node2 should change from NEVER to ALWAYS (since it was already NEVER) + expect(node1.mode).toBe(LGraphEventMode.NEVER) + expect(node2.mode).toBe(LGraphEventMode.ALWAYS) + }) + + it('toggleSelectedNodesMode should set mode to ALWAYS when already in target mode', () => { + const { toggleSelectedNodesMode } = useSelectedLiteGraphItems() + const node = { id: 1, mode: LGraphEventMode.BYPASS } as LGraphNode + + app.canvas.selected_nodes = { '0': node } + + // Toggle to BYPASS mode (node is already BYPASS) + toggleSelectedNodesMode(LGraphEventMode.BYPASS) + + // Should change to ALWAYS + expect(node.mode).toBe(LGraphEventMode.ALWAYS) + }) + + it('getSelectedNodes should include nodes from subgraphs', () => { + const { getSelectedNodes } = useSelectedLiteGraphItems() + const subNode1 = { id: 11, mode: LGraphEventMode.ALWAYS } as LGraphNode + const subNode2 = { id: 12, mode: LGraphEventMode.NEVER } as LGraphNode + const subgraphNode = { + id: 1, + mode: LGraphEventMode.ALWAYS, + isSubgraphNode: () => true, + subgraph: { + nodes: [subNode1, subNode2] + } + } as unknown as LGraphNode + const regularNode = { id: 2, mode: LGraphEventMode.NEVER } as LGraphNode + + app.canvas.selected_nodes = { '0': subgraphNode, '1': regularNode } + + const selectedNodes = getSelectedNodes() + expect(selectedNodes).toHaveLength(4) // subgraphNode + 2 sub nodes + regularNode + expect(selectedNodes).toContainEqual(subgraphNode) + expect(selectedNodes).toContainEqual(regularNode) + expect(selectedNodes).toContainEqual(subNode1) + expect(selectedNodes).toContainEqual(subNode2) + }) + + it('toggleSelectedNodesMode should apply unified state to subgraph children', () => { + const { toggleSelectedNodesMode } = useSelectedLiteGraphItems() + const subNode1 = { id: 11, mode: LGraphEventMode.ALWAYS } as LGraphNode + const subNode2 = { id: 12, mode: LGraphEventMode.NEVER } as LGraphNode + const subgraphNode = { + id: 1, + mode: LGraphEventMode.ALWAYS, + isSubgraphNode: () => true, + subgraph: { + nodes: [subNode1, subNode2] + } + } as unknown as LGraphNode + const regularNode = { id: 2, mode: LGraphEventMode.BYPASS } as LGraphNode + + app.canvas.selected_nodes = { '0': subgraphNode, '1': regularNode } + + // Toggle to NEVER mode + toggleSelectedNodesMode(LGraphEventMode.NEVER) + + // Selected nodes follow standard toggle logic: + // subgraphNode: ALWAYS -> NEVER (since ALWAYS != NEVER) + expect(subgraphNode.mode).toBe(LGraphEventMode.NEVER) + // regularNode: BYPASS -> NEVER (since BYPASS != NEVER) + expect(regularNode.mode).toBe(LGraphEventMode.NEVER) + + // Subgraph children get unified state (same as their parent): + // Both children should now be NEVER, regardless of their previous states + expect(subNode1.mode).toBe(LGraphEventMode.NEVER) // was ALWAYS, now NEVER + expect(subNode2.mode).toBe(LGraphEventMode.NEVER) // was NEVER, stays NEVER + }) + + it('toggleSelectedNodesMode should toggle to ALWAYS when subgraph is already in target mode', () => { + const { toggleSelectedNodesMode } = useSelectedLiteGraphItems() + const subNode1 = { id: 11, mode: LGraphEventMode.ALWAYS } as LGraphNode + const subNode2 = { id: 12, mode: LGraphEventMode.BYPASS } as LGraphNode + const subgraphNode = { + id: 1, + mode: LGraphEventMode.NEVER, // Already in NEVER mode + isSubgraphNode: () => true, + subgraph: { + nodes: [subNode1, subNode2] + } + } as unknown as LGraphNode + + app.canvas.selected_nodes = { '0': subgraphNode } + + // Toggle to NEVER mode (but subgraphNode is already NEVER) + toggleSelectedNodesMode(LGraphEventMode.NEVER) + + // Selected subgraph should toggle to ALWAYS (since it was already NEVER) + expect(subgraphNode.mode).toBe(LGraphEventMode.ALWAYS) + + // All children should also get ALWAYS (unified with parent's new state) + expect(subNode1.mode).toBe(LGraphEventMode.ALWAYS) + expect(subNode2.mode).toBe(LGraphEventMode.ALWAYS) + }) + }) + describe('dynamic behavior', () => { it('methods should reflect changes when selectedItems change', () => { const { diff --git a/tests-ui/tests/utils/graphTraversalUtil.test.ts b/tests-ui/tests/utils/graphTraversalUtil.test.ts index 48ff7351b..185f7454c 100644 --- a/tests-ui/tests/utils/graphTraversalUtil.test.ts +++ b/tests-ui/tests/utils/graphTraversalUtil.test.ts @@ -820,14 +820,13 @@ describe('graphTraversalUtil', () => { createMockNode('3') ] - traverseNodesDepthFirst( - nodes, - (node, context) => { + traverseNodesDepthFirst(nodes, { + visitor: (node, context) => { visited.push(`${node.id}:${context}`) return `${context}-${node.id}` }, - 'root' - ) + initialContext: 'root' + }) expect(visited).toEqual(['3:root', '2:root', '1:root']) // DFS processes in LIFO order }) @@ -841,14 +840,13 @@ describe('graphTraversalUtil', () => { createMockNode('2', { isSubgraph: true, subgraph }) ] - traverseNodesDepthFirst( - nodes, - (node, depth: number) => { + traverseNodesDepthFirst(nodes, { + visitor: (node, depth: number) => { visited.push(`${node.id}:${depth}`) return depth + 1 }, - 0 - ) + initialContext: 0 + }) expect(visited).toEqual(['2:0', 'sub1:1', '1:0']) // DFS: last node first, then its children }) @@ -862,15 +860,14 @@ describe('graphTraversalUtil', () => { createMockNode('2', { isSubgraph: true, subgraph }) ] - traverseNodesDepthFirst( - nodes, - (node, context) => { + traverseNodesDepthFirst(nodes, { + visitor: (node, context) => { visited.push(String(node.id)) return context }, - null, - false - ) + initialContext: null, + expandSubgraphs: false + }) expect(visited).toEqual(['2', '1']) // DFS processes in LIFO order expect(visited).not.toContain('sub1') @@ -893,14 +890,13 @@ describe('graphTraversalUtil', () => { subgraph: midSubgraph }) - traverseNodesDepthFirst( - [topNode], - (node, path: string) => { + traverseNodesDepthFirst([topNode], { + visitor: (node, path: string) => { visited.push(`${node.id}:${path}`) return path ? `${path}/${node.id}` : String(node.id) }, - '' - ) + initialContext: '' + }) expect(visited).toEqual(['100:', '200:100', '300:100/200']) }) @@ -914,12 +910,11 @@ describe('graphTraversalUtil', () => { createMockNode('3') ] - const results = collectFromNodes( - nodes, - (node) => `node-${node.id}`, - (_node, context) => context, - null - ) + const results = collectFromNodes(nodes, { + collector: (node) => `node-${node.id}`, + contextBuilder: (_node, context) => context, + initialContext: null + }) expect(results).toEqual(['node-3', 'node-2', 'node-1']) // DFS processes in LIFO order }) @@ -931,12 +926,11 @@ describe('graphTraversalUtil', () => { createMockNode('3') ] - const results = collectFromNodes( - nodes, - (node) => (Number(node.id) > 1 ? `node-${node.id}` : null), - (_node, context) => context, - null - ) + const results = collectFromNodes(nodes, { + collector: (node) => (Number(node.id) > 1 ? `node-${node.id}` : null), + contextBuilder: (_node, context) => context, + initialContext: null + }) expect(results).toEqual(['node-3', 'node-2']) // DFS processes in LIFO order, node-1 filtered out }) @@ -949,13 +943,12 @@ describe('graphTraversalUtil', () => { createMockNode('2', { isSubgraph: true, subgraph }) ] - const results = collectFromNodes( - nodes, - (node, prefix: string) => `${prefix}${node.id}`, - (node, prefix: string) => `${prefix}${node.id}-`, - 'node-', - true - ) + const results = collectFromNodes(nodes, { + collector: (node, prefix: string) => `${prefix}${node.id}`, + contextBuilder: (node, prefix: string) => `${prefix}${node.id}-`, + initialContext: 'node-', + expandSubgraphs: true + }) expect(results).toEqual([ 'node-2', @@ -973,13 +966,12 @@ describe('graphTraversalUtil', () => { createMockNode('2', { isSubgraph: true, subgraph }) ] - const results = collectFromNodes( - nodes, - (node) => String(node.id), - (_node, context) => context, - null, - false - ) + const results = collectFromNodes(nodes, { + collector: (node) => String(node.id), + contextBuilder: (_node, context) => context, + initialContext: null, + expandSubgraphs: false + }) expect(results).toEqual(['2', '1']) // DFS processes in LIFO order })