diff --git a/src/extensions/core/uploadAudio.ts b/src/extensions/core/uploadAudio.ts index 94a16d04d..afd52f4cf 100644 --- a/src/extensions/core/uploadAudio.ts +++ b/src/extensions/core/uploadAudio.ts @@ -1,6 +1,7 @@ import type { LGraphNode } from '@comfyorg/litegraph' import type { IStringWidget } from '@comfyorg/litegraph/dist/types/widgets' +import { useChainCallback } from '@/composables/functional/useChainCallback' import { useNodeDragAndDrop } from '@/composables/node/useNodeDragAndDrop' import { useNodeFileInput } from '@/composables/node/useNodeFileInput' import { useNodePaste } from '@/composables/node/useNodePaste' @@ -9,6 +10,8 @@ import type { ResultItemType } from '@/schemas/apiSchema' import type { ComfyNodeDef } from '@/schemas/nodeDefSchema' import type { DOMWidget } from '@/scripts/domWidget' import { useToastStore } from '@/stores/toastStore' +import { NodeLocatorId } from '@/types' +import { getNodeByLocatorId } from '@/utils/graphTraversalUtil' import { api } from '../../scripts/api' import { app } from '../../scripts/app' @@ -137,14 +140,27 @@ app.registerExtension({ audioUIWidget.element.classList.remove('empty-audio-widget') } } + + audioUIWidget.onRemove = useChainCallback( + audioUIWidget.onRemove, + () => { + if (!audioUIWidget.element) return + audioUIWidget.element.pause() + audioUIWidget.element.src = '' + audioUIWidget.element.remove() + } + ) + return { widget: audioUIWidget } } } }, - onNodeOutputsUpdated(nodeOutputs: Record) { - for (const [nodeId, output] of Object.entries(nodeOutputs)) { - const node = app.graph.getNodeById(nodeId) + onNodeOutputsUpdated(nodeOutputs: Record) { + for (const [nodeLocatorId, output] of Object.entries(nodeOutputs)) { if ('audio' in output) { + const node = getNodeByLocatorId(app.graph, nodeLocatorId) + if (!node) continue + // @ts-expect-error fixme ts strict error const audioUIWidget = node.widgets.find( (w) => w.name === 'audioUI' diff --git a/src/scripts/app.ts b/src/scripts/app.ts index e7edebbc0..2061a5f6a 100644 --- a/src/scripts/app.ts +++ b/src/scripts/app.ts @@ -47,6 +47,7 @@ import { useDomWidgetStore } from '@/stores/domWidgetStore' import { useExecutionStore } from '@/stores/executionStore' import { useExtensionStore } from '@/stores/extensionStore' import { useFirebaseAuthStore } from '@/stores/firebaseAuthStore' +import { useNodeOutputStore } from '@/stores/imagePreviewStore' import { KeyComboImpl, useKeybindingStore } from '@/stores/keybindingStore' import { useModelStore } from '@/stores/modelStore' import { SYSTEM_NODE_DEFS, useNodeDefStore } from '@/stores/nodeDefStore' @@ -60,6 +61,10 @@ import type { ComfyExtension, MissingNodeType } from '@/types/comfy' import { ExtensionManager } from '@/types/extensionTypes' import { ColorAdjustOptions, adjustColor } from '@/utils/colorUtil' import { graphToPrompt } from '@/utils/executionUtil' +import { + getNodeByExecutionId, + triggerCallbackOnAllNodes +} from '@/utils/graphTraversalUtil' import { executeWidgetsCallback, fixLinkInputSlots, @@ -640,29 +645,21 @@ export class ComfyApp { }) api.addEventListener('executed', ({ detail }) => { - const output = this.nodeOutputs[detail.display_node || detail.node] - if (detail.merge && output) { - for (const k in detail.output ?? {}) { - const v = output[k] - if (v instanceof Array) { - output[k] = v.concat(detail.output[k]) - } else { - output[k] = detail.output[k] - } - } - } else { - this.nodeOutputs[detail.display_node || detail.node] = detail.output - } - const node = this.graph.getNodeById(detail.display_node || detail.node) - if (node) { - if (node.onExecuted) node.onExecuted(detail.output) + const nodeOutputStore = useNodeOutputStore() + const executionId = String(detail.display_node || detail.node) + + nodeOutputStore.setNodeOutputsByExecutionId(executionId, detail.output, { + merge: detail.merge + }) + + const node = getNodeByExecutionId(this.graph, executionId) + if (node && node.onExecuted) { + node.onExecuted(detail.output) } }) api.addEventListener('execution_start', () => { - this.graph.nodes.forEach((node) => { - if (node.onExecutionStart) node.onExecutionStart() - }) + triggerCallbackOnAllNodes(this.graph, 'onExecutionStart') }) api.addEventListener('execution_error', ({ detail }) => { @@ -690,11 +687,13 @@ export class ComfyApp { api.addEventListener('b_preview_with_metadata', ({ detail }) => { // Enhanced preview with explicit node context const { blob, displayNodeId } = detail + const { setNodePreviewsByExecutionId, revokePreviewsByExecutionId } = + useNodeOutputStore() // Ensure clean up if `executing` event is missed. - this.revokePreviews(displayNodeId) + revokePreviewsByExecutionId(displayNodeId) const blobUrl = URL.createObjectURL(blob) - // Preview cleanup is now handled in progress_state event to support multiple concurrent previews - this.nodePreviewImages[displayNodeId] = [blobUrl] + // Preview cleanup is handled in progress_state event to support multiple concurrent previews + setNodePreviewsByExecutionId(displayNodeId, [blobUrl]) }) api.init() @@ -1673,25 +1672,13 @@ export class ComfyApp { } } - /** - * Frees memory allocated to image preview blobs for a specific node, by revoking the URLs associated with them. - * @param nodeId ID of the node to revoke all preview images of - */ - revokePreviews(nodeId: NodeId) { - if (!this.nodePreviewImages[nodeId]?.[Symbol.iterator]) return - for (const url of this.nodePreviewImages[nodeId]) { - URL.revokeObjectURL(url) - } - } /** * Clean current state */ clean() { this.nodeOutputs = {} - for (const id of Object.keys(this.nodePreviewImages)) { - this.revokePreviews(id) - } - this.nodePreviewImages = {} + const { revokeAllPreviews } = useNodeOutputStore() + revokeAllPreviews() const executionStore = useExecutionStore() executionStore.lastNodeErrors = null executionStore.lastExecutionError = null diff --git a/src/stores/executionStore.ts b/src/stores/executionStore.ts index 377528769..bc33b75cd 100644 --- a/src/stores/executionStore.ts +++ b/src/stores/executionStore.ts @@ -24,6 +24,7 @@ import type { } from '@/schemas/comfyWorkflowSchema' import { api } from '@/scripts/api' import { app } from '@/scripts/app' +import { useNodeOutputStore } from '@/stores/imagePreviewStore' import type { NodeLocatorId } from '@/types/nodeIdentification' import { createNodeLocatorId } from '@/types/nodeIdentification' @@ -229,9 +230,9 @@ export const useExecutionStore = defineStore('execution', () => { api.addEventListener('progress_state', handleProgressState) api.addEventListener('status', handleStatus) api.addEventListener('execution_error', handleExecutionError) + api.addEventListener('progress_text', handleProgressText) + api.addEventListener('display_component', handleDisplayComponent) } - api.addEventListener('progress_text', handleProgressText) - api.addEventListener('display_component', handleDisplayComponent) function unbindExecutionEvents() { api.removeEventListener('execution_start', handleExecutionStart) @@ -244,6 +245,7 @@ export const useExecutionStore = defineStore('execution', () => { api.removeEventListener('status', handleStatus) api.removeEventListener('execution_error', handleExecutionError) api.removeEventListener('progress_text', handleProgressText) + api.removeEventListener('display_component', handleDisplayComponent) } function handleExecutionStart(e: CustomEvent) { @@ -294,8 +296,8 @@ export const useExecutionStore = defineStore('execution', () => { // Note that we're doing the *actual* node id instead of the display node id // here intentionally. That way, we don't clear the preview every time a new node // within an expanded graph starts executing. - app.revokePreviews(nodeId) - delete app.nodePreviewImages[nodeId] + const { revokePreviewsByExecutionId } = useNodeOutputStore() + revokePreviewsByExecutionId(nodeId) } } diff --git a/src/stores/imagePreviewStore.ts b/src/stores/imagePreviewStore.ts index b6d88c9bb..b2c19b8f9 100644 --- a/src/stores/imagePreviewStore.ts +++ b/src/stores/imagePreviewStore.ts @@ -8,6 +8,9 @@ import { } from '@/schemas/apiSchema' import { api } from '@/scripts/api' import { app } from '@/scripts/app' +import { useExecutionStore } from '@/stores/executionStore' +import { useWorkflowStore } from '@/stores/workflowStore' +import type { NodeLocatorId } from '@/types/nodeIdentification' import { parseFilePath } from '@/utils/formatUtil' import { isVideoNode } from '@/utils/litegraphUtil' @@ -22,17 +25,22 @@ const createOutputs = ( } } +interface SetOutputOptions { + merge?: boolean +} + export const useNodeOutputStore = defineStore('nodeOutput', () => { - const getNodeId = (node: LGraphNode): string => node.id.toString() + const { nodeIdToNodeLocatorId } = useWorkflowStore() + const { executionIdToNodeLocatorId } = useExecutionStore() function getNodeOutputs( node: LGraphNode ): ExecutedWsMessage['output'] | undefined { - return app.nodeOutputs[getNodeId(node)] + return app.nodeOutputs[nodeIdToNodeLocatorId(node.id)] } function getNodePreviews(node: LGraphNode): string[] | undefined { - return app.nodePreviewImages[getNodeId(node)] + return app.nodePreviewImages[nodeIdToNodeLocatorId(node.id)] } /** @@ -86,6 +94,35 @@ export const useNodeOutputStore = defineStore('nodeOutput', () => { }) } + /** + * Internal function to set outputs by NodeLocatorId. + * Handles the merge logic when needed. + */ + function setOutputsByLocatorId( + nodeLocatorId: NodeLocatorId, + outputs: ExecutedWsMessage['output'] | ResultItem, + options: SetOutputOptions = {} + ) { + if (options.merge) { + const existingOutput = app.nodeOutputs[nodeLocatorId] + if (existingOutput && outputs) { + for (const k in outputs) { + const existingValue = existingOutput[k] + const newValue = (outputs as Record)[k] + + if (Array.isArray(existingValue) && Array.isArray(newValue)) { + existingOutput[k] = existingValue.concat(newValue) + } else { + existingOutput[k] = newValue + } + } + return + } + } + + app.nodeOutputs[nodeLocatorId] = outputs + } + function setNodeOutputs( node: LGraphNode, filenames: string | string[] | ResultItem, @@ -96,24 +133,149 @@ export const useNodeOutputStore = defineStore('nodeOutput', () => { ) { if (!filenames || !node) return - const nodeId = getNodeId(node) - if (typeof filenames === 'string') { - app.nodeOutputs[nodeId] = createOutputs([filenames], folder, isAnimated) + setNodeOutputsByNodeId( + node.id, + createOutputs([filenames], folder, isAnimated) + ) } else if (!Array.isArray(filenames)) { - app.nodeOutputs[nodeId] = filenames + setNodeOutputsByNodeId(node.id, filenames) } else { const resultItems = createOutputs(filenames, folder, isAnimated) if (!resultItems?.images?.length) return - app.nodeOutputs[nodeId] = resultItems + setNodeOutputsByNodeId(node.id, resultItems) } } + /** + * Set node outputs by execution ID (hierarchical ID from backend). + * Converts the execution ID to a NodeLocatorId before storing. + * + * @param executionId - The execution ID (e.g., "123:456:789" or "789") + * @param outputs - The outputs to store + * @param options - Options for setting outputs + * @param options.merge - If true, merge with existing outputs (arrays are concatenated) + */ + function setNodeOutputsByExecutionId( + executionId: string, + outputs: ExecutedWsMessage['output'] | ResultItem, + options: SetOutputOptions = {} + ) { + const nodeLocatorId = executionIdToNodeLocatorId(executionId) + if (!nodeLocatorId) return + + setOutputsByLocatorId(nodeLocatorId, outputs, options) + } + + /** + * Set node outputs by node ID. + * Uses the current graph context to create the appropriate NodeLocatorId. + * + * @param nodeId - The node ID + * @param outputs - The outputs to store + * @param options - Options for setting outputs + * @param options.merge - If true, merge with existing outputs (arrays are concatenated) + */ + function setNodeOutputsByNodeId( + nodeId: string | number, + outputs: ExecutedWsMessage['output'] | ResultItem, + options: SetOutputOptions = {} + ) { + const nodeLocatorId = nodeIdToNodeLocatorId(nodeId) + if (!nodeLocatorId) return + + setOutputsByLocatorId(nodeLocatorId, outputs, options) + } + + /** + * Set node preview images by execution ID (hierarchical ID from backend). + * Converts the execution ID to a NodeLocatorId before storing. + * + * @param executionId - The execution ID (e.g., "123:456:789" or "789") + * @param previewImages - Array of preview image URLs to store + */ + function setNodePreviewsByExecutionId( + executionId: string, + previewImages: string[] + ) { + const nodeLocatorId = executionIdToNodeLocatorId(executionId) + if (!nodeLocatorId) return + + app.nodePreviewImages[nodeLocatorId] = previewImages + } + + /** + * Set node preview images by node ID. + * Uses the current graph context to create the appropriate NodeLocatorId. + * + * @param nodeId - The node ID + * @param previewImages - Array of preview image URLs to store + */ + function setNodePreviewsByNodeId( + nodeId: string | number, + previewImages: string[] + ) { + const nodeLocatorId = nodeIdToNodeLocatorId(nodeId) + app.nodePreviewImages[nodeLocatorId] = previewImages + } + + /** + * Revoke preview images by execution ID. + * Frees memory allocated to image preview blobs by revoking the URLs. + * + * @param executionId - The execution ID + */ + function revokePreviewsByExecutionId(executionId: string) { + const nodeLocatorId = executionIdToNodeLocatorId(executionId) + if (!nodeLocatorId) return + + revokePreviewsByLocatorId(nodeLocatorId) + } + + /** + * Revoke preview images by node locator ID. + * Frees memory allocated to image preview blobs by revoking the URLs. + * + * @param nodeLocatorId - The node locator ID + */ + function revokePreviewsByLocatorId(nodeLocatorId: NodeLocatorId) { + const previews = app.nodePreviewImages[nodeLocatorId] + if (!previews?.[Symbol.iterator]) return + + for (const url of previews) { + URL.revokeObjectURL(url) + } + + delete app.nodePreviewImages[nodeLocatorId] + } + + /** + * Revoke all preview images. + * Frees memory allocated to all image preview blobs. + */ + function revokeAllPreviews() { + for (const nodeLocatorId of Object.keys(app.nodePreviewImages)) { + const previews = app.nodePreviewImages[nodeLocatorId] + if (!previews?.[Symbol.iterator]) continue + + for (const url of previews) { + URL.revokeObjectURL(url) + } + } + app.nodePreviewImages = {} + } + return { getNodeOutputs, getNodeImageUrls, getNodePreviews, setNodeOutputs, + setNodeOutputsByExecutionId, + setNodeOutputsByNodeId, + setNodePreviewsByExecutionId, + setNodePreviewsByNodeId, + revokePreviewsByExecutionId, + revokeAllPreviews, getPreviewParam } }) diff --git a/src/stores/queueStore.ts b/src/stores/queueStore.ts index 376063958..316c3564a 100644 --- a/src/stores/queueStore.ts +++ b/src/stores/queueStore.ts @@ -14,6 +14,8 @@ import type { import type { ComfyWorkflowJSON, NodeId } from '@/schemas/comfyWorkflowSchema' import { api } from '@/scripts/api' import type { ComfyApp } from '@/scripts/app' +import { useExtensionService } from '@/services/extensionService' +import { useNodeOutputStore } from '@/stores/imagePreviewStore' // Task type used in the API. export type APITaskType = 'queue' | 'history' @@ -377,7 +379,18 @@ export class TaskItemImpl { } await app.loadGraphData(toRaw(this.workflow)) if (this.outputs) { - app.nodeOutputs = toRaw(this.outputs) + const nodeOutputsStore = useNodeOutputStore() + const rawOutputs = toRaw(this.outputs) + for (const nodeExecutionId in rawOutputs) { + nodeOutputsStore.setNodeOutputsByExecutionId( + nodeExecutionId, + rawOutputs[nodeExecutionId] + ) + } + useExtensionService().invokeExtensions( + 'onNodeOutputsUpdated', + app.nodeOutputs + ) } } diff --git a/src/utils/graphTraversalUtil.ts b/src/utils/graphTraversalUtil.ts new file mode 100644 index 000000000..32f3e5727 --- /dev/null +++ b/src/utils/graphTraversalUtil.ts @@ -0,0 +1,243 @@ +import type { LGraph, LGraphNode, Subgraph } from '@comfyorg/litegraph' + +import type { NodeLocatorId } from '@/types/nodeIdentification' +import { parseNodeLocatorId } from '@/types/nodeIdentification' + +/** + * Parses an execution ID into its component parts. + * + * @param executionId - The execution ID (e.g., "123:456:789" or "789") + * @returns Array of node IDs in the path, or null if invalid + */ +export function parseExecutionId(executionId: string): string[] | null { + if (!executionId || typeof executionId !== 'string') return null + return executionId.split(':').filter((part) => part.length > 0) +} + +/** + * Extracts the local node ID from an execution ID. + * + * @param executionId - The execution ID (e.g., "123:456:789" or "789") + * @returns The local node ID or null if invalid + */ +export function getLocalNodeIdFromExecutionId( + executionId: string +): string | null { + const parts = parseExecutionId(executionId) + return parts ? parts[parts.length - 1] : null +} + +/** + * Extracts the subgraph path from an execution ID. + * + * @param executionId - The execution ID (e.g., "123:456:789" or "789") + * @returns Array of subgraph node IDs (excluding the final node ID), or empty array + */ +export function getSubgraphPathFromExecutionId(executionId: string): string[] { + const parts = parseExecutionId(executionId) + return parts ? parts.slice(0, -1) : [] +} + +/** + * Visits each node in a graph (non-recursive, single level). + * + * @param graph - The graph to visit nodes from + * @param visitor - Function called for each node + */ +export function visitGraphNodes( + graph: LGraph | Subgraph, + visitor: (node: LGraphNode) => void +): void { + for (const node of graph.nodes) { + visitor(node) + } +} + +/** + * Traverses a path of subgraphs to reach a target graph. + * + * @param startGraph - The graph to start from + * @param path - Array of subgraph node IDs to traverse + * @returns The target graph or null if path is invalid + */ +export function traverseSubgraphPath( + startGraph: LGraph | Subgraph, + path: string[] +): LGraph | Subgraph | null { + let currentGraph: LGraph | Subgraph = startGraph + + for (const nodeId of path) { + const node = currentGraph.getNodeById(nodeId) + if (!node?.isSubgraphNode?.() || !node.subgraph) return null + currentGraph = node.subgraph + } + + return currentGraph +} + +/** + * Traverses all nodes in a graph hierarchy (including subgraphs) and invokes + * a callback on each node that has the specified property. + * + * @param graph - The root graph to start traversal from + * @param callbackProperty - The name of the callback property to invoke on each node + */ +export function triggerCallbackOnAllNodes( + graph: LGraph | Subgraph, + callbackProperty: keyof LGraphNode +): void { + visitGraphNodes(graph, (node) => { + // Recursively process subgraphs first + if (node.isSubgraphNode?.() && node.subgraph) { + triggerCallbackOnAllNodes(node.subgraph, callbackProperty) + } + + // Invoke callback if it exists on the node + const callback = node[callbackProperty] + if (typeof callback === 'function') { + callback.call(node) + } + }) +} + +/** + * Collects all nodes in a graph hierarchy (including subgraphs) into a flat array. + * + * @param graph - The root graph to collect nodes from + * @param filter - Optional filter function to include only specific nodes + * @returns Array of all nodes in the graph hierarchy + */ +export function collectAllNodes( + graph: LGraph | Subgraph, + filter?: (node: LGraphNode) => boolean +): LGraphNode[] { + const nodes: LGraphNode[] = [] + + visitGraphNodes(graph, (node) => { + // Recursively collect from subgraphs + if (node.isSubgraphNode?.() && node.subgraph) { + nodes.push(...collectAllNodes(node.subgraph, filter)) + } + + // Add node if it passes the filter (or no filter provided) + if (!filter || filter(node)) { + nodes.push(node) + } + }) + + return nodes +} + +/** + * Finds a node by ID anywhere in the graph hierarchy. + * + * @param graph - The root graph to search + * @param nodeId - The ID of the node to find + * @returns The node if found, null otherwise + */ +export function findNodeInHierarchy( + graph: LGraph | Subgraph, + nodeId: string | number +): LGraphNode | null { + // Check current graph + const node = graph.getNodeById(nodeId) + if (node) return node + + // Search in subgraphs + for (const node of graph.nodes) { + if (node.isSubgraphNode?.() && node.subgraph) { + const found = findNodeInHierarchy(node.subgraph, nodeId) + if (found) return found + } + } + + return null +} + +/** + * Find a subgraph by its UUID anywhere in the graph hierarchy. + * + * @param graph - The root graph to search + * @param targetUuid - The UUID of the subgraph to find + * @returns The subgraph if found, null otherwise + */ +export function findSubgraphByUuid( + graph: LGraph | Subgraph, + targetUuid: string +): Subgraph | null { + // Check all nodes in the current graph + for (const node of graph._nodes) { + if (node.isSubgraphNode?.() && node.subgraph) { + if (node.subgraph.id === targetUuid) { + return node.subgraph + } + // Recursively search in nested subgraphs + const found = findSubgraphByUuid(node.subgraph, targetUuid) + if (found) return found + } + } + return null +} + +/** + * Get a node by its execution ID from anywhere in the graph hierarchy. + * Execution IDs use hierarchical format like "123:456:789" for nested nodes. + * + * @param rootGraph - The root graph to search from + * @param executionId - The execution ID (e.g., "123:456:789" or "789") + * @returns The node if found, null otherwise + */ +export function getNodeByExecutionId( + rootGraph: LGraph, + executionId: string +): LGraphNode | null { + if (!rootGraph) return null + + const localNodeId = getLocalNodeIdFromExecutionId(executionId) + if (!localNodeId) return null + + const subgraphPath = getSubgraphPathFromExecutionId(executionId) + + // If no subgraph path, it's in the root graph + if (subgraphPath.length === 0) { + return rootGraph.getNodeById(localNodeId) || null + } + + // Traverse to the target subgraph + const targetGraph = traverseSubgraphPath(rootGraph, subgraphPath) + if (!targetGraph) return null + + // Get the node from the target graph + return targetGraph.getNodeById(localNodeId) || null +} + +/** + * Get a node by its locator ID from anywhere in the graph hierarchy. + * Locator IDs use UUID format like "uuid:nodeId" for subgraph nodes. + * + * @param rootGraph - The root graph to search from + * @param locatorId - The locator ID (e.g., "uuid:123" or "123") + * @returns The node if found, null otherwise + */ +export function getNodeByLocatorId( + rootGraph: LGraph, + locatorId: NodeLocatorId | string +): LGraphNode | null { + if (!rootGraph) return null + + const parsedIds = parseNodeLocatorId(locatorId) + if (!parsedIds) return null + + const { subgraphUuid, localNodeId } = parsedIds + + // If no subgraph UUID, it's in the root graph + if (!subgraphUuid) { + return rootGraph.getNodeById(localNodeId) || null + } + + // Find the subgraph with the matching UUID + const targetSubgraph = findSubgraphByUuid(rootGraph, subgraphUuid) + if (!targetSubgraph) return null + + return targetSubgraph.getNodeById(localNodeId) || null +} diff --git a/tests-ui/tests/utils/graphTraversalUtil.test.ts b/tests-ui/tests/utils/graphTraversalUtil.test.ts new file mode 100644 index 000000000..d485055bd --- /dev/null +++ b/tests-ui/tests/utils/graphTraversalUtil.test.ts @@ -0,0 +1,486 @@ +import type { LGraph, LGraphNode, Subgraph } from '@comfyorg/litegraph' +import { describe, expect, it, vi } from 'vitest' + +import { + collectAllNodes, + findNodeInHierarchy, + findSubgraphByUuid, + getLocalNodeIdFromExecutionId, + getNodeByExecutionId, + getNodeByLocatorId, + getSubgraphPathFromExecutionId, + parseExecutionId, + traverseSubgraphPath, + triggerCallbackOnAllNodes, + visitGraphNodes +} from '@/utils/graphTraversalUtil' + +// Mock node factory +function createMockNode( + id: string | number, + options: { + isSubgraph?: boolean + subgraph?: Subgraph + callback?: () => void + } = {} +): LGraphNode { + return { + id, + isSubgraphNode: options.isSubgraph ? () => true : undefined, + subgraph: options.subgraph, + onExecutionStart: options.callback + } as unknown as LGraphNode +} + +// Mock graph factory +function createMockGraph(nodes: LGraphNode[]): LGraph { + return { + _nodes: nodes, + nodes: nodes, + getNodeById: (id: string | number) => + nodes.find((n) => String(n.id) === String(id)) || null + } as unknown as LGraph +} + +// Mock subgraph factory +function createMockSubgraph(id: string, nodes: LGraphNode[]): Subgraph { + return { + id, + _nodes: nodes, + nodes: nodes, + getNodeById: (nodeId: string | number) => + nodes.find((n) => String(n.id) === String(nodeId)) || null + } as unknown as Subgraph +} + +describe('graphTraversalUtil', () => { + describe('Pure utility functions', () => { + describe('parseExecutionId', () => { + it('should parse simple execution ID', () => { + expect(parseExecutionId('123')).toEqual(['123']) + }) + + it('should parse complex execution ID', () => { + expect(parseExecutionId('123:456:789')).toEqual(['123', '456', '789']) + }) + + it('should handle empty parts', () => { + expect(parseExecutionId('123::789')).toEqual(['123', '789']) + }) + + it('should return null for invalid input', () => { + expect(parseExecutionId('')).toBeNull() + expect(parseExecutionId(null as any)).toBeNull() + expect(parseExecutionId(undefined as any)).toBeNull() + }) + }) + + describe('getLocalNodeIdFromExecutionId', () => { + it('should extract local node ID from simple ID', () => { + expect(getLocalNodeIdFromExecutionId('123')).toBe('123') + }) + + it('should extract local node ID from complex ID', () => { + expect(getLocalNodeIdFromExecutionId('123:456:789')).toBe('789') + }) + + it('should return null for invalid input', () => { + expect(getLocalNodeIdFromExecutionId('')).toBeNull() + }) + }) + + describe('getSubgraphPathFromExecutionId', () => { + it('should return empty array for root node', () => { + expect(getSubgraphPathFromExecutionId('123')).toEqual([]) + }) + + it('should return subgraph path for nested node', () => { + expect(getSubgraphPathFromExecutionId('123:456:789')).toEqual([ + '123', + '456' + ]) + }) + + it('should return empty array for invalid input', () => { + expect(getSubgraphPathFromExecutionId('')).toEqual([]) + }) + }) + + describe('visitGraphNodes', () => { + it('should visit all nodes in graph', () => { + const visited: number[] = [] + const nodes = [createMockNode(1), createMockNode(2), createMockNode(3)] + const graph = createMockGraph(nodes) + + visitGraphNodes(graph, (node) => { + visited.push(node.id as number) + }) + + expect(visited).toEqual([1, 2, 3]) + }) + + it('should handle empty graph', () => { + const visited: number[] = [] + const graph = createMockGraph([]) + + visitGraphNodes(graph, (node) => { + visited.push(node.id as number) + }) + + expect(visited).toEqual([]) + }) + }) + + describe('traverseSubgraphPath', () => { + it('should return start graph for empty path', () => { + const graph = createMockGraph([]) + const result = traverseSubgraphPath(graph, []) + expect(result).toBe(graph) + }) + + it('should traverse single level', () => { + const subgraph = createMockSubgraph('sub-uuid', []) + const node = createMockNode('1', { isSubgraph: true, subgraph }) + const graph = createMockGraph([node]) + + const result = traverseSubgraphPath(graph, ['1']) + expect(result).toBe(subgraph) + }) + + it('should traverse multiple levels', () => { + const deepSubgraph = createMockSubgraph('deep-uuid', []) + const midNode = createMockNode('2', { + isSubgraph: true, + subgraph: deepSubgraph + }) + const midSubgraph = createMockSubgraph('mid-uuid', [midNode]) + const topNode = createMockNode('1', { + isSubgraph: true, + subgraph: midSubgraph + }) + const graph = createMockGraph([topNode]) + + const result = traverseSubgraphPath(graph, ['1', '2']) + expect(result).toBe(deepSubgraph) + }) + + it('should return null for invalid path', () => { + const graph = createMockGraph([createMockNode('1')]) + const result = traverseSubgraphPath(graph, ['999']) + expect(result).toBeNull() + }) + }) + }) + + describe('Main functions', () => { + describe('triggerCallbackOnAllNodes', () => { + it('should trigger callbacks on all nodes in a flat graph', () => { + const callback1 = vi.fn() + const callback2 = vi.fn() + + const node1 = createMockNode(1, { callback: callback1 }) + const node2 = createMockNode(2, { callback: callback2 }) + const node3 = createMockNode(3) // No callback + + const graph = createMockGraph([node1, node2, node3]) + + triggerCallbackOnAllNodes(graph, 'onExecutionStart') + + expect(callback1).toHaveBeenCalledOnce() + expect(callback2).toHaveBeenCalledOnce() + }) + + it('should trigger callbacks on nodes in subgraphs', () => { + const callback1 = vi.fn() + const callback2 = vi.fn() + const callback3 = vi.fn() + + // Create a subgraph with one node + const subNode = createMockNode(100, { callback: callback3 }) + const subgraph = createMockSubgraph('sub-uuid', [subNode]) + + // Create main graph with two nodes, one being a subgraph + const node1 = createMockNode(1, { callback: callback1 }) + const node2 = createMockNode(2, { + isSubgraph: true, + subgraph, + callback: callback2 + }) + + const graph = createMockGraph([node1, node2]) + + triggerCallbackOnAllNodes(graph, 'onExecutionStart') + + expect(callback1).toHaveBeenCalledOnce() + expect(callback2).toHaveBeenCalledOnce() + expect(callback3).toHaveBeenCalledOnce() + }) + + it('should handle nested subgraphs', () => { + const callbacks = [vi.fn(), vi.fn(), vi.fn(), vi.fn()] + + // Create deeply nested structure + const deepNode = createMockNode(300, { callback: callbacks[3] }) + const deepSubgraph = createMockSubgraph('deep-uuid', [deepNode]) + + const midNode1 = createMockNode(200, { callback: callbacks[2] }) + const midNode2 = createMockNode(201, { + isSubgraph: true, + subgraph: deepSubgraph + }) + const midSubgraph = createMockSubgraph('mid-uuid', [midNode1, midNode2]) + + const node1 = createMockNode(1, { callback: callbacks[0] }) + const node2 = createMockNode(2, { + isSubgraph: true, + subgraph: midSubgraph, + callback: callbacks[1] + }) + + const graph = createMockGraph([node1, node2]) + + triggerCallbackOnAllNodes(graph, 'onExecutionStart') + + callbacks.forEach((cb) => expect(cb).toHaveBeenCalledOnce()) + }) + }) + + describe('collectAllNodes', () => { + it('should collect all nodes from a flat graph', () => { + const nodes = [createMockNode(1), createMockNode(2), createMockNode(3)] + + const graph = createMockGraph(nodes) + const collected = collectAllNodes(graph) + + expect(collected).toHaveLength(3) + expect(collected.map((n) => n.id)).toEqual([1, 2, 3]) + }) + + it('should collect nodes from subgraphs', () => { + const subNode = createMockNode(100) + const subgraph = createMockSubgraph('sub-uuid', [subNode]) + + const nodes = [ + createMockNode(1), + createMockNode(2, { isSubgraph: true, subgraph }) + ] + + const graph = createMockGraph(nodes) + const collected = collectAllNodes(graph) + + expect(collected).toHaveLength(3) + expect(collected.map((n) => n.id)).toContain(100) + }) + + it('should filter nodes when filter function provided', () => { + const nodes = [createMockNode(1), createMockNode(2), createMockNode(3)] + + const graph = createMockGraph(nodes) + const collected = collectAllNodes(graph, (node) => Number(node.id) > 1) + + expect(collected).toHaveLength(2) + expect(collected.map((n) => n.id)).toEqual([2, 3]) + }) + }) + + describe('findNodeInHierarchy', () => { + it('should find node in root graph', () => { + const nodes = [createMockNode(1), createMockNode(2), createMockNode(3)] + + const graph = createMockGraph(nodes) + const found = findNodeInHierarchy(graph, 2) + + expect(found).toBeTruthy() + expect(found?.id).toBe(2) + }) + + it('should find node in subgraph', () => { + const subNode = createMockNode(100) + const subgraph = createMockSubgraph('sub-uuid', [subNode]) + + const nodes = [ + createMockNode(1), + createMockNode(2, { isSubgraph: true, subgraph }) + ] + + const graph = createMockGraph(nodes) + const found = findNodeInHierarchy(graph, 100) + + expect(found).toBeTruthy() + expect(found?.id).toBe(100) + }) + + it('should return null for non-existent node', () => { + const nodes = [createMockNode(1), createMockNode(2)] + const graph = createMockGraph(nodes) + + const found = findNodeInHierarchy(graph, 999) + expect(found).toBeNull() + }) + }) + + describe('findSubgraphByUuid', () => { + it('should find subgraph by UUID', () => { + const targetUuid = 'target-uuid' + const subgraph = createMockSubgraph(targetUuid, []) + + const nodes = [ + createMockNode(1), + createMockNode(2, { isSubgraph: true, subgraph }) + ] + + const graph = createMockGraph(nodes) + const found = findSubgraphByUuid(graph, targetUuid) + + expect(found).toBe(subgraph) + expect(found?.id).toBe(targetUuid) + }) + + it('should find nested subgraph', () => { + const targetUuid = 'deep-uuid' + const deepSubgraph = createMockSubgraph(targetUuid, []) + + const midSubgraph = createMockSubgraph('mid-uuid', [ + createMockNode(200, { isSubgraph: true, subgraph: deepSubgraph }) + ]) + + const graph = createMockGraph([ + createMockNode(1, { isSubgraph: true, subgraph: midSubgraph }) + ]) + + const found = findSubgraphByUuid(graph, targetUuid) + + expect(found).toBe(deepSubgraph) + expect(found?.id).toBe(targetUuid) + }) + + it('should return null for non-existent UUID', () => { + const subgraph = createMockSubgraph('some-uuid', []) + const graph = createMockGraph([ + createMockNode(1, { isSubgraph: true, subgraph }) + ]) + + const found = findSubgraphByUuid(graph, 'non-existent-uuid') + expect(found).toBeNull() + }) + }) + + describe('getNodeByExecutionId', () => { + it('should find node in root graph', () => { + const nodes = [createMockNode('123'), createMockNode('456')] + + const graph = createMockGraph(nodes) + const found = getNodeByExecutionId(graph, '123') + + expect(found).toBeTruthy() + expect(found?.id).toBe('123') + }) + + it('should find node in subgraph using execution path', () => { + const targetNode = createMockNode('789') + const subgraph = createMockSubgraph('sub-uuid', [targetNode]) + + const subgraphNode = createMockNode('456', { + isSubgraph: true, + subgraph + }) + + const graph = createMockGraph([createMockNode('123'), subgraphNode]) + + const found = getNodeByExecutionId(graph, '456:789') + + expect(found).toBe(targetNode) + expect(found?.id).toBe('789') + }) + + it('should handle deeply nested execution paths', () => { + const targetNode = createMockNode('999') + const deepSubgraph = createMockSubgraph('deep-uuid', [targetNode]) + + const midNode = createMockNode('456', { + isSubgraph: true, + subgraph: deepSubgraph + }) + const midSubgraph = createMockSubgraph('mid-uuid', [midNode]) + + const topNode = createMockNode('123', { + isSubgraph: true, + subgraph: midSubgraph + }) + + const graph = createMockGraph([topNode]) + + const found = getNodeByExecutionId(graph, '123:456:999') + + expect(found).toBe(targetNode) + expect(found?.id).toBe('999') + }) + + it('should return null for invalid path', () => { + const subgraph = createMockSubgraph('sub-uuid', [createMockNode('789')]) + const graph = createMockGraph([ + createMockNode('456', { isSubgraph: true, subgraph }) + ]) + + // Wrong path - node 123 doesn't exist + const found = getNodeByExecutionId(graph, '123:789') + expect(found).toBeNull() + }) + + it('should return null for invalid execution ID', () => { + const graph = createMockGraph([createMockNode('123')]) + const found = getNodeByExecutionId(graph, '') + expect(found).toBeNull() + }) + }) + + describe('getNodeByLocatorId', () => { + it('should find node in root graph', () => { + const nodes = [createMockNode('123'), createMockNode('456')] + + const graph = createMockGraph(nodes) + const found = getNodeByLocatorId(graph, '123') + + expect(found).toBeTruthy() + expect(found?.id).toBe('123') + }) + + it('should find node in subgraph using UUID format', () => { + const targetUuid = 'a1b2c3d4-e5f6-7890-abcd-ef1234567890' + const targetNode = createMockNode('789') + const subgraph = createMockSubgraph(targetUuid, [targetNode]) + + const graph = createMockGraph([ + createMockNode('123'), + createMockNode('456', { isSubgraph: true, subgraph }) + ]) + + const locatorId = `${targetUuid}:789` + const found = getNodeByLocatorId(graph, locatorId) + + expect(found).toBe(targetNode) + expect(found?.id).toBe('789') + }) + + it('should return null for invalid locator ID', () => { + const graph = createMockGraph([createMockNode('123')]) + + const found = getNodeByLocatorId(graph, 'invalid:::format') + expect(found).toBeNull() + }) + + it('should return null when subgraph UUID not found', () => { + const subgraph = createMockSubgraph('some-uuid', [ + createMockNode('789') + ]) + const graph = createMockGraph([ + createMockNode('456', { isSubgraph: true, subgraph }) + ]) + + const locatorId = 'non-existent-uuid:789' + const found = getNodeByLocatorId(graph, locatorId) + expect(found).toBeNull() + }) + }) + }) +})