diff --git a/src/utils/executionUtil.ts b/src/utils/executionUtil.ts index feed67cbc..3b061aa46 100644 --- a/src/utils/executionUtil.ts +++ b/src/utils/executionUtil.ts @@ -1,5 +1,9 @@ -import type { LGraph, LGraphNode, NodeId } from '@comfyorg/litegraph' -import { LGraphEventMode } from '@comfyorg/litegraph' +import type { LGraph, NodeId } from '@comfyorg/litegraph' +import { + ExecutableNodeDTO, + LGraphEventMode, + SubgraphNode +} from '@comfyorg/litegraph' import type { ComfyApiWorkflow, @@ -74,20 +78,31 @@ export const graphToPrompt = async ( workflow.extra ??= {} workflow.extra.frontendVersion = __COMFYUI_FRONTEND_VERSION__ + const computedNodeDtos = graph + .computeExecutionOrder(false) + .map( + (node) => + new ExecutableNodeDTO( + node, + [], + node instanceof SubgraphNode ? node : undefined + ) + ) + let output: ComfyApiWorkflow = {} // Process nodes in order of execution - for (const outerNode of graph.computeExecutionOrder(false)) { - const skipNode = + for (const outerNode of computedNodeDtos) { + // Don't serialize muted nodes + if ( outerNode.mode === LGraphEventMode.NEVER || outerNode.mode === LGraphEventMode.BYPASS - const innerNodes = - !skipNode && outerNode.getInnerNodes - ? outerNode.getInnerNodes() - : [outerNode] - for (const node of innerNodes) { + ) { + continue + } + + for (const node of outerNode.getInnerNodes()) { if ( node.isVirtualNode || - // Don't serialize muted nodes node.mode === LGraphEventMode.NEVER || node.mode === LGraphEventMode.BYPASS ) { @@ -120,58 +135,14 @@ export const graphToPrompt = async ( // Store all node links for (const [i, input] of node.inputs.entries()) { - let parent: LGraphNode | null | undefined = node.getInputNode(i) - if (!parent) continue + const resolvedInput = node.resolveInput(i) + if (!resolvedInput) continue - let link = node.getInputLink(i) - while ( - parent?.mode === LGraphEventMode.BYPASS || - parent?.isVirtualNode - ) { - if (!link) break - - if (parent.isVirtualNode) { - link = parent.getInputLink(link.origin_slot) - if (!link) break - - parent = parent.isSubgraphNode() - ? parent.resolveSubgraphOutputLink(link.origin_slot)?.outputNode - : parent.getInputNode(link.target_slot) - - if (!parent) break - } else if (!parent.inputs) { - // Maintains existing behaviour if parent.getInputLink is overriden - break - } else if (parent.mode === LGraphEventMode.BYPASS) { - // Bypass nodes by finding first input with matching type - const parentInputIndexes = Object.keys(parent.inputs).map(Number) - // Prioritise exact slot index - const indexes = [link.origin_slot].concat(parentInputIndexes) - - const matchingIndex = indexes.find( - (index) => parent?.inputs[index]?.type === input.type - ) - // No input types match - if (matchingIndex === undefined) break - - link = parent.getInputLink(matchingIndex) - if (link) parent = parent.getInputNode(matchingIndex) - } - } - - if (link) { - if (parent?.updateLink) { - // Subgraph node / groupNode callback; deprecated, should be replaced - link = parent.updateLink(link) - } - if (link) { - inputs[input.name] = [ - String(link.origin_id), - // @ts-expect-error link.origin_slot is already number. - parseInt(link.origin_slot) - ] - } - } + inputs[input.name] = [ + String(resolvedInput.origin_id), + // @ts-expect-error link.origin_slot is already number. + parseInt(resolvedInput.origin_slot) + ] } output[String(node.id)] = { diff --git a/src/utils/litegraphUtil.ts b/src/utils/litegraphUtil.ts index 865c4f6c6..738fbd7ca 100644 --- a/src/utils/litegraphUtil.ts +++ b/src/utils/litegraphUtil.ts @@ -1,6 +1,10 @@ import { ColorOption, LGraph, Reroute } from '@comfyorg/litegraph' import { LGraphGroup, LGraphNode, isColorable } from '@comfyorg/litegraph' -import type { ISerialisedGraph } from '@comfyorg/litegraph/dist/types/serialisation' +import type { + ExportedSubgraph, + ISerialisableNodeInput, + ISerialisedGraph +} from '@comfyorg/litegraph/dist/types/serialisation' import type { IBaseWidget, IComboWidget @@ -167,12 +171,11 @@ export function fixLinkInputSlots(graph: LGraph) { * This should match the serialization format of legacy widget conversion. * * @param graph - The graph to compress widget input slots for. + * @throws If an infinite loop is detected. */ export function compressWidgetInputSlots(graph: ISerialisedGraph) { for (const node of graph.nodes) { - node.inputs = node.inputs?.filter( - (input) => !(input.widget && input.link === null) - ) + node.inputs = node.inputs?.filter(matchesLegacyApi) for (const [inputIndex, input] of node.inputs?.entries() ?? []) { if (input.link) { @@ -183,4 +186,44 @@ export function compressWidgetInputSlots(graph: ISerialisedGraph) { } } } + + compressSubgraphWidgetInputSlots(graph.definitions?.subgraphs) +} + +function matchesLegacyApi(input: ISerialisableNodeInput) { + return !(input.widget && input.link === null) +} + +/** + * Duplication to handle the legacy link arrays in the root workflow. + * @see compressWidgetInputSlots + * @param subgraph The subgraph to compress widget input slots for. + */ +function compressSubgraphWidgetInputSlots( + subgraphs: ExportedSubgraph[] | undefined, + visited = new WeakSet() +) { + if (!subgraphs) return + + for (const subgraph of subgraphs) { + if (visited.has(subgraph)) throw new Error('Infinite loop detected') + visited.add(subgraph) + + if (subgraph.nodes) { + for (const node of subgraph.nodes) { + node.inputs = node.inputs?.filter(matchesLegacyApi) + + if (!subgraph.links) continue + + for (const [inputIndex, input] of node.inputs?.entries() ?? []) { + if (input.link) { + const link = subgraph.links.find((link) => link.id === input.link) + if (link) link.target_slot = inputIndex + } + } + } + } + + compressSubgraphWidgetInputSlots(subgraph.definitions?.subgraphs, visited) + } }