diff --git a/src/extensions/core/groupNode.ts b/src/extensions/core/groupNode.ts index ccc973390..800255af1 100644 --- a/src/extensions/core/groupNode.ts +++ b/src/extensions/core/groupNode.ts @@ -1,5 +1,11 @@ -import { LiteGraph } from '@comfyorg/litegraph' -import { LGraphNode, type NodeId } from '@comfyorg/litegraph/dist/LGraphNode' +import { + type ExecutableLGraphNode, + type ExecutionId, + LGraphNode, + LiteGraph, + SubgraphNode +} from '@comfyorg/litegraph' +import { type NodeId } from '@comfyorg/litegraph/dist/LGraphNode' import { t } from '@/i18n' import { @@ -13,6 +19,8 @@ import { useNodeDefStore } from '@/stores/nodeDefStore' import { useToastStore } from '@/stores/toastStore' import { useWidgetStore } from '@/stores/widgetStore' import { ComfyExtension } from '@/types/comfy' +import { ExecutableGroupNodeChildDTO } from '@/utils/executableGroupNodeChildDTO' +import { GROUP } from '@/utils/executableGroupNodeDto' import { deserialiseAndCreate, serialise } from '@/utils/vintageClipboard' import { api } from '../../scripts/api' @@ -26,8 +34,6 @@ type GroupNodeWorkflowData = { nodes: ComfyNode[] } -const GROUP = Symbol() - // v1 Prefix + Separator: workflow/ // v2 Prefix + Separator: workflow> (ComfyUI_frontend v1.2.63) const PREFIX = 'workflow' @@ -813,6 +819,7 @@ export class GroupNodeHandler { innerNodeIndex++ ) { const innerNode = this.innerNodes[innerNodeIndex] + innerNode.graph ??= this.node.graph for (const w of innerNode.widgets ?? []) { if (w.type === 'converted-widget') { @@ -899,7 +906,20 @@ export class GroupNodeHandler { return link } - this.node.getInnerNodes = () => { + /** @internal Used to flatten the subgraph before execution. Recursive; call with no args. */ + this.node.getInnerNodes = ( + computedNodeDtos: Map, + /** The path of subgraph node IDs. */ + subgraphNodePath: readonly NodeId[] = [], + /** The list of nodes to add to. */ + nodes: ExecutableLGraphNode[] = [], + /** The set of visited nodes. */ + visited = new Set() + ): ExecutableLGraphNode[] => { + if (visited.has(this.node)) + throw new Error('RecursionError: while flattening subgraph') + visited.add(this.node) + if (!this.innerNodes) { // @ts-expect-error fixme ts strict error this.node.setInnerNodes( @@ -910,6 +930,8 @@ export class GroupNodeHandler { innerNode.configure(n) // @ts-expect-error fixme ts strict error innerNode.id = `${this.node.id}:${i}` + // @ts-expect-error fixme ts strict error + innerNode.graph = this.node.graph return innerNode }) ) @@ -917,7 +939,31 @@ export class GroupNodeHandler { this.updateInnerWidgets() - return this.innerNodes + const subgraphInstanceIdPath = [...subgraphNodePath, this.node.id] + + // Assertion: Deprecated, does not matter. + const subgraphNode = (this.node.graph?.getNodeById( + subgraphNodePath.at(-1) + ) ?? undefined) as SubgraphNode | undefined + + for (const node of this.innerNodes) { + node.graph ??= this.node.graph + + // Create minimal DTOs rather than cloning the node + const currentId = String(node.id) + node.id = currentId.split(':').at(-1) + const aVeryRealNode = new ExecutableGroupNodeChildDTO( + node, + subgraphInstanceIdPath, + computedNodeDtos, + subgraphNode + ) + node.id = currentId + aVeryRealNode.groupNodeHandler = this + + nodes.push(aVeryRealNode) + } + return nodes } // @ts-expect-error fixme ts strict error @@ -1503,6 +1549,9 @@ export class GroupNodeHandler { this.linkOutputs(node, i) app.graph.remove(node) + + // Set internal ID to what is expected after workflow is reloaded + node.id = `${this.node.id}:${i}` } this.linkInputs() @@ -1608,8 +1657,14 @@ async function convertSelectedNodesToGroupNode() { if (nodes.length === 1) { throw new Error('Please select multiple nodes to convert to group node') } - if (nodes.some((n) => GroupNodeHandler.isGroupNode(n))) { - throw new Error('Selected nodes contain a group node') + + for (const node of nodes) { + if (node instanceof SubgraphNode) { + throw new Error('Selected nodes contain a subgraph node') + } + if (GroupNodeHandler.isGroupNode(node)) { + throw new Error('Selected nodes contain a group node') + } } return await GroupNodeHandler.fromNodes(nodes) } diff --git a/src/types/litegraph-augmentation.d.ts b/src/types/litegraph-augmentation.d.ts index ac6757f73..68e2ac8f4 100644 --- a/src/types/litegraph-augmentation.d.ts +++ b/src/types/litegraph-augmentation.d.ts @@ -69,7 +69,7 @@ declare module '@comfyorg/litegraph/dist/interfaces' { * ComfyUI extensions of litegraph */ declare module '@comfyorg/litegraph' { - import type { ExecutableLGraphNode } from '@comfyorg/litegraph' + import type { ExecutableLGraphNode, ExecutionId } from '@comfyorg/litegraph' import type { IBaseWidget } from '@comfyorg/litegraph/dist/types/widgets' interface LGraphNodeConstructor { @@ -99,8 +99,10 @@ declare module '@comfyorg/litegraph' { setInnerNodes?(nodes: LGraphNode[]): void /** Originally a group node API. */ getInnerNodes?( + nodesByExecutionId: Map, + subgraphNodePath?: readonly NodeId[], nodes?: ExecutableLGraphNode[], - subgraphs?: WeakSet + subgraphs?: Set ): ExecutableLGraphNode[] /** @deprecated groupNode */ convertToNodes?(): LGraphNode[] diff --git a/src/utils/executableGroupNodeChildDTO.ts b/src/utils/executableGroupNodeChildDTO.ts new file mode 100644 index 000000000..1832adfdd --- /dev/null +++ b/src/utils/executableGroupNodeChildDTO.ts @@ -0,0 +1,53 @@ +import { + type ExecutableLGraphNode, + ExecutableNodeDTO, + type ExecutionId, + type LGraphNode, + type NodeId, + type SubgraphNode +} from '@comfyorg/litegraph' + +import type { GroupNodeHandler } from '@/extensions/core/groupNode' + +export class ExecutableGroupNodeChildDTO extends ExecutableNodeDTO { + groupNodeHandler?: GroupNodeHandler + + constructor( + /** The actual node that this DTO wraps. */ + node: LGraphNode | SubgraphNode, + /** A list of subgraph instance node IDs from the root graph to the containing instance. @see {@link id} */ + subgraphNodePath: readonly NodeId[], + /** A flattened map of all DTOs in this node network. Subgraph instances have been expanded into their inner nodes. */ + nodesByExecutionId: Map, + /** The actual subgraph instance that contains this node, otherise undefined. */ + subgraphNode?: SubgraphNode | undefined, + groupNodeHandler?: GroupNodeHandler + ) { + super(node, subgraphNodePath, nodesByExecutionId, subgraphNode) + this.groupNodeHandler = groupNodeHandler + } + + override resolveInput(slot: number) { + const inputNode = this.node.getInputNode(slot) + if (!inputNode) return + + const link = this.node.getInputLink(slot) + if (!link) throw new Error('Failed to get input link') + + const id = String(inputNode.id).split(':').at(-1) + if (id === undefined) throw new Error('Invalid input node id') + + const inputNodeDto = this.nodesByExecutionId?.get(id) + if (!inputNodeDto) { + throw new Error( + `Failed to get input node ${id} for group node child ${this.id} with slot ${slot}` + ) + } + + return { + node: inputNodeDto, + origin_id: inputNode.id, + origin_slot: link.origin_slot + } + } +} diff --git a/src/utils/executableGroupNodeDto.ts b/src/utils/executableGroupNodeDto.ts new file mode 100644 index 000000000..09044dcce --- /dev/null +++ b/src/utils/executableGroupNodeDto.ts @@ -0,0 +1,71 @@ +import { + type ExecutableLGraphNode, + ExecutableNodeDTO, + type ISlotType, + LGraphEventMode, + type LGraphNode +} from '@comfyorg/litegraph' + +export const GROUP = Symbol() + +export function isGroupNode(node: LGraphNode): boolean { + return node.constructor?.nodeData?.[GROUP] !== undefined +} + +export class ExecutableGroupNodeDTO extends ExecutableNodeDTO { + override get isVirtualNode(): true { + return true + } + + override getInnerNodes(): ExecutableLGraphNode[] { + return this.node.getInnerNodes?.(this.nodesByExecutionId) ?? [] + } + + override resolveOutput(slot: number, type: ISlotType, visited: Set) { + // Temporary duplication: Bypass nodes are bypassed using the first input with matching type + if (this.mode === LGraphEventMode.BYPASS) { + const { inputs } = this + + // Bypass nodes by finding first input with matching type + const parentInputIndexes = Object.keys(inputs).map(Number) + // Prioritise exact slot index + const indexes = [slot, ...parentInputIndexes] + const matchingIndex = indexes.find((i) => inputs[i]?.type === type) + + // No input types match + if (matchingIndex === undefined) return + + return this.resolveInput(matchingIndex, visited) + } + + const linkId = this.node.outputs[slot]?.links?.at(0) + const link = this.node.graph?.getLink(linkId) + if (!link) { + throw new Error( + `Failed to get link for group node ${this.node.id} with link ${linkId}` + ) + } + + const updated = this.node.updateLink?.(link) + if (!updated) { + throw new Error( + `Failed to update link for group node ${this.node.id} with link ${linkId}` + ) + } + + const node = this.node + .getInnerNodes?.(this.nodesByExecutionId) + .find((node) => node.id === updated.origin_id) + if (!node) { + throw new Error( + `Failed to get node for group node ${this.node.id} with link ${linkId}` + ) + } + + return { + node, + origin_id: `${this.id}:${(updated.origin_id as string).split(':').at(-1)}`, + origin_slot: updated.origin_slot + } + } +} diff --git a/src/utils/executionUtil.ts b/src/utils/executionUtil.ts index 3b061aa46..fec35d9bd 100644 --- a/src/utils/executionUtil.ts +++ b/src/utils/executionUtil.ts @@ -1,4 +1,9 @@ -import type { LGraph, NodeId } from '@comfyorg/litegraph' +import type { + ExecutableLGraphNode, + ExecutionId, + LGraph, + NodeId +} from '@comfyorg/litegraph' import { ExecutableNodeDTO, LGraphEventMode, @@ -10,6 +15,7 @@ import type { ComfyWorkflowJSON } from '@/schemas/comfyWorkflowSchema' +import { ExecutableGroupNodeDTO, isGroupNode } from './executableGroupNodeDto' import { compressWidgetInputSlots } from './litegraphUtil' /** @@ -54,7 +60,9 @@ export const graphToPrompt = async ( const { sortNodes = false, queueNodeIds } = options for (const node of graph.computeExecutionOrder(false)) { - const innerNodes = node.getInnerNodes ? node.getInnerNodes() : [node] + const innerNodes = node.getInnerNodes + ? node.getInnerNodes(new Map()) + : [node] for (const innerNode of innerNodes) { if (innerNode.isVirtualNode) { innerNode.applyToGraph?.() @@ -78,82 +86,80 @@ export const graphToPrompt = async ( workflow.extra ??= {} workflow.extra.frontendVersion = __COMFYUI_FRONTEND_VERSION__ - const computedNodeDtos = graph - .computeExecutionOrder(false) - .map( - (node) => - new ExecutableNodeDTO( + const nodeDtoMap = new Map() + for (const node of graph.computeExecutionOrder(false)) { + const dto: ExecutableLGraphNode = isGroupNode(node) + ? new ExecutableGroupNodeDTO(node, [], nodeDtoMap) + : new ExecutableNodeDTO( node, [], + nodeDtoMap, node instanceof SubgraphNode ? node : undefined ) - ) + + for (const innerNode of dto.getInnerNodes()) { + nodeDtoMap.set(innerNode.id, innerNode) + } + + nodeDtoMap.set(dto.id, dto) + } let output: ComfyApiWorkflow = {} // Process nodes in order of execution - for (const outerNode of computedNodeDtos) { + for (const node of nodeDtoMap.values()) { // Don't serialize muted nodes if ( - outerNode.mode === LGraphEventMode.NEVER || - outerNode.mode === LGraphEventMode.BYPASS + node.isVirtualNode || + node.mode === LGraphEventMode.NEVER || + node.mode === LGraphEventMode.BYPASS ) { continue } - for (const node of outerNode.getInnerNodes()) { - if ( - node.isVirtualNode || - node.mode === LGraphEventMode.NEVER || - node.mode === LGraphEventMode.BYPASS - ) { - continue + const inputs: ComfyApiWorkflow[string]['inputs'] = {} + const { widgets } = node + + // Store all widget values + if (widgets) { + for (const [i, widget] of widgets.entries()) { + if (!widget.name || widget.options?.serialize === false) continue + + const widgetValue = widget.serializeValue + ? await widget.serializeValue(node, i) + : widget.value + // By default, Array values are reserved to represent node connections. + // We need to wrap the array as an object to avoid the misinterpretation + // of the array as a node connection. + // The backend automatically unwraps the object to an array during + // execution. + inputs[widget.name] = Array.isArray(widgetValue) + ? { + __value__: widgetValue + } + : widgetValue } + } - const inputs: ComfyApiWorkflow[string]['inputs'] = {} - const { widgets } = node + // Store all node links + for (const [i, input] of node.inputs.entries()) { + const resolvedInput = node.resolveInput(i) + if (!resolvedInput) continue - // Store all widget values - if (widgets) { - for (const [i, widget] of widgets.entries()) { - if (!widget.name || widget.options?.serialize === false) continue + inputs[input.name] = [ + String(resolvedInput.origin_id), + // @ts-expect-error link.origin_slot is already number. + parseInt(resolvedInput.origin_slot) + ] + } - const widgetValue = widget.serializeValue - ? await widget.serializeValue(node, i) - : widget.value - // By default, Array values are reserved to represent node connections. - // We need to wrap the array as an object to avoid the misinterpretation - // of the array as a node connection. - // The backend automatically unwraps the object to an array during - // execution. - inputs[widget.name] = Array.isArray(widgetValue) - ? { - __value__: widgetValue - } - : widgetValue - } - } - - // Store all node links - for (const [i, input] of node.inputs.entries()) { - const resolvedInput = node.resolveInput(i) - if (!resolvedInput) continue - - inputs[input.name] = [ - String(resolvedInput.origin_id), - // @ts-expect-error link.origin_slot is already number. - parseInt(resolvedInput.origin_slot) - ] - } - - output[String(node.id)] = { - inputs, - // TODO(huchenlei): Filter out all nodes that cannot be mapped to a - // comfyClass. - class_type: node.comfyClass!, - // Ignored by the backend. - _meta: { - title: node.title - } + output[String(node.id)] = { + inputs, + // TODO(huchenlei): Filter out all nodes that cannot be mapped to a + // comfyClass. + class_type: node.comfyClass!, + // Ignored by the backend. + _meta: { + title: node.title } } }