diff --git a/src/litegraph.ts b/src/litegraph.ts index 36176191f..080152021 100644 --- a/src/litegraph.ts +++ b/src/litegraph.ts @@ -136,7 +136,7 @@ export { LGraphNode, type NodeId } from "./LGraphNode" export { type LinkId, LLink } from "./LLink" export { clamp, createBounds } from "./measure" export { Reroute, type RerouteId } from "./Reroute" -export { type ExecutableLGraphNode, ExecutableNodeDTO } from "./subgraph/ExecutableNodeDTO" +export { type ExecutableLGraphNode, ExecutableNodeDTO, type ExecutionId } from "./subgraph/ExecutableNodeDTO" export { SubgraphNode } from "./subgraph/SubgraphNode" export type { CanvasPointerEvent } from "./types/events" export { diff --git a/src/subgraph/ExecutableNodeDTO.ts b/src/subgraph/ExecutableNodeDTO.ts index 36fd851d9..373c7c9c9 100644 --- a/src/subgraph/ExecutableNodeDTO.ts +++ b/src/subgraph/ExecutableNodeDTO.ts @@ -11,12 +11,14 @@ import { LGraphEventMode } from "@/litegraph" import { Subgraph } from "./Subgraph" +export type ExecutionId = string + /** * Interface describing the data transfer objects used when compiling a graph for execution. */ -export type ExecutableLGraphNode = Omit +export type ExecutableLGraphNode = Omit -type NodeAndInput = { +type ResolvedInput = { node: ExecutableLGraphNode origin_id: NodeId origin_slot: number @@ -35,7 +37,7 @@ export class ExecutableNodeDTO implements ExecutableLGraphNode { inputs: { linkId: number | null, name: string, type: ISlotType }[] /** Backing field for {@link id}. */ - #id: NodeId + #id: ExecutionId /** * The path to the acutal node through subgraph instances, represented as a list of all subgraph node IDs (instances), @@ -79,6 +81,8 @@ export class ExecutableNodeDTO implements ExecutableLGraphNode { readonly node: LGraphNode | SubgraphNode, /** A list of subgraph instance node IDs from the root graph to the containing instance. @see {@link id} */ readonly subgraphNodePath: readonly NodeId[], + /** A flattened map of all DTOs in this node network. Subgraph instances have been expanded into their inner nodes. */ + readonly nodesByExecutionId: Map, /** The actual subgraph instance that contains this node, otherise undefined. */ readonly subgraphNode?: SubgraphNode, ) { @@ -101,7 +105,7 @@ export class ExecutableNodeDTO implements ExecutableLGraphNode { /** Returns either the DTO itself, or the DTOs of the inner nodes of the subgraph. */ getInnerNodes(): ExecutableLGraphNode[] { - return this.subgraphNode ? this.subgraphNode.getInnerNodes() : [this] + return this.subgraphNode ? this.subgraphNode.getInnerNodes(this.nodesByExecutionId, this.subgraphNodePath) : [this] } /** @@ -112,7 +116,7 @@ export class ExecutableNodeDTO implements ExecutableLGraphNode { * If overriding, ensure that the set is passed on all recursive calls. * @returns The node and the origin ID / slot index of the output. */ - resolveInput(slot: number, visited = new Set()): NodeAndInput | undefined { + resolveInput(slot: number, visited = new Set()): ResolvedInput | undefined { const uniqueId = `${this.subgraphNode?.subgraph.id}:${this.node.id}[I]${slot}` if (visited.has(uniqueId)) throw new RecursionError(`While resolving subgraph input [${uniqueId}]`) visited.add(uniqueId) @@ -140,10 +144,10 @@ export class ExecutableNodeDTO implements ExecutableLGraphNode { const outerLink = subgraphNode.graph.getLink(linkId) if (!outerLink) throw new InvalidLinkError(`No outer link found for slot [${link.origin_slot}] ${input.name}`) - // Translate subgraph node IDs to instances (not worth optimising yet) - const subgraphNodes = this.graph.rootGraph.resolveSubgraphIdPath(this.subgraphNodePath) + const subgraphNodeExecutionId = this.subgraphNodePath.join(":") + const subgraphNodeDto = this.nodesByExecutionId.get(subgraphNodeExecutionId) + if (!subgraphNodeDto) throw new Error(`No subgraph node DTO found for id [${subgraphNodeExecutionId}]`) - const subgraphNodeDto = new ExecutableNodeDTO(subgraphNode, this.subgraphNodePath.slice(0, -1), subgraphNodes.at(-2)) return subgraphNodeDto.resolveInput(outerLink.target_slot, visited) } @@ -151,7 +155,9 @@ export class ExecutableNodeDTO implements ExecutableLGraphNode { const outputNode = this.graph.getNodeById(link.origin_id) if (!outputNode) throw new InvalidLinkError(`No input node found for id [${this.id}] slot [${slot}] ${input.name}`) - const outputNodeDto = new ExecutableNodeDTO(outputNode, this.subgraphNodePath, subgraphNode) + const outputNodeExecutionId = [...this.subgraphNodePath, outputNode.id].join(":") + const outputNodeDto = this.nodesByExecutionId.get(outputNodeExecutionId) + if (!outputNodeDto) throw new Error(`No output node DTO found for id [${outputNodeExecutionId}]`) return outputNodeDto.resolveOutput(link.origin_slot, input.type, visited) } @@ -163,7 +169,7 @@ export class ExecutableNodeDTO implements ExecutableLGraphNode { * @param visited A set of unique IDs to guard against infinite recursion. See {@link resolveInput}. * @returns The node and the origin ID / slot index of the output. */ - resolveOutput(slot: number, type: ISlotType, visited: Set): NodeAndInput | undefined { + resolveOutput(slot: number, type: ISlotType, visited: Set): ResolvedInput | undefined { const uniqueId = `${this.subgraphNode?.subgraph.id}:${this.node.id}[O]${slot}` if (visited.has(uniqueId)) throw new RecursionError(`While resolving subgraph output [${uniqueId}]`) visited.add(uniqueId) @@ -200,7 +206,9 @@ export class ExecutableNodeDTO implements ExecutableLGraphNode { const outputNode = this.graph.getNodeById(virtualLink.origin_id) if (!outputNode) throw new InvalidLinkError(`Virtual node failed to resolve parent [${this.id}] slot [${slot}]`) - const outputNodeDto = new ExecutableNodeDTO(outputNode, this.subgraphNodePath, this.subgraphNode) + const outputNodeExecutionId = [...this.subgraphNodePath, outputNode.id].join(":") + const outputNodeDto = this.nodesByExecutionId.get(outputNodeExecutionId) + if (!outputNodeDto) throw new Error(`No output node DTO found for id [${outputNode.id}]`) return outputNodeDto.resolveOutput(virtualLink.origin_slot, type, visited) } @@ -222,7 +230,7 @@ export class ExecutableNodeDTO implements ExecutableLGraphNode { * @param visited A set of unique IDs to guard against infinite recursion. See {@link resolveInput}. * @returns A DTO for the node, and the origin ID / slot index of the output. */ - #resolveSubgraphOutput(slot: number, type: ISlotType, visited: Set): NodeAndInput | undefined { + #resolveSubgraphOutput(slot: number, type: ISlotType, visited: Set): ResolvedInput | undefined { const { node } = this const output = node.outputs.at(slot) @@ -237,7 +245,10 @@ export class ExecutableNodeDTO implements ExecutableLGraphNode { if (!innerNode) throw new Error(`No output node found for id [${this.id}] slot [${slot}] ${output.name}`) // Recurse into the subgraph - const innerNodeDto = new ExecutableNodeDTO(innerNode, [...this.subgraphNodePath, node.id], node) + const innerNodeExecutionId = [...this.subgraphNodePath, node.id, innerNode.id].join(":") + const innerNodeDto = this.nodesByExecutionId.get(innerNodeExecutionId) + if (!innerNodeDto) throw new Error(`No inner node DTO found for id [${innerNodeExecutionId}]`) + return innerNodeDto.resolveOutput(innerResolved.link.origin_slot, type, visited) } } diff --git a/src/subgraph/SubgraphNode.ts b/src/subgraph/SubgraphNode.ts index 95f905d99..a5295b94e 100644 --- a/src/subgraph/SubgraphNode.ts +++ b/src/subgraph/SubgraphNode.ts @@ -14,7 +14,7 @@ import { NodeInputSlot } from "@/node/NodeInputSlot" import { NodeOutputSlot } from "@/node/NodeOutputSlot" import { toConcreteWidget } from "@/widgets/widgetMap" -import { type ExecutableLGraphNode, ExecutableNodeDTO } from "./ExecutableNodeDTO" +import { type ExecutableLGraphNode, ExecutableNodeDTO, type ExecutionId } from "./ExecutableNodeDTO" /** * An instance of a {@link Subgraph}, displayed as a node on the containing (parent) graph. @@ -273,26 +273,34 @@ export class SubgraphNode extends LGraphNode implements BaseLGraph { console.debug(`[SubgraphNode.resolveSubgraphOutputLink] No inner link found for output slot [${slot}] ${outputSlot.name}`, this) } - /** @internal Used to flatten the subgraph before execution. Recursive; call with no args. */ + /** @internal Used to flatten the subgraph before execution. */ getInnerNodes( - /** The list of nodes to add to. */ - nodes: ExecutableLGraphNode[] = [], - /** The set of visited nodes. */ - visited = new Set(), + /** The set of computed node DTOs for this execution. */ + executableNodes: Map, /** The path of subgraph node IDs. */ subgraphNodePath: readonly NodeId[] = [], + /** Internal recursion param. The list of nodes to add to. */ + nodes: ExecutableLGraphNode[] = [], + /** Internal recursion param. The set of visited nodes. */ + visited = new Set(), ): ExecutableLGraphNode[] { if (visited.has(this)) throw new RecursionError("while flattening subgraph") visited.add(this) const subgraphInstanceIdPath = [...subgraphNodePath, this.id] + // Store the subgraph node DTO + const parentSubgraphNode = this.graph.rootGraph.resolveSubgraphIdPath(subgraphNodePath).at(-1) + const subgraphNodeDto = new ExecutableNodeDTO(this, subgraphNodePath, executableNodes, parentSubgraphNode) + executableNodes.set(subgraphNodeDto.id, subgraphNodeDto) + for (const node of this.subgraph.nodes) { if ("getInnerNodes" in node) { - node.getInnerNodes(nodes, new Set(visited), subgraphInstanceIdPath) + node.getInnerNodes(executableNodes, subgraphInstanceIdPath, nodes, new Set(visited)) } else { // Create minimal DTOs rather than cloning the node - const aVeryRealNode = new ExecutableNodeDTO(node, subgraphInstanceIdPath, this) + const aVeryRealNode = new ExecutableNodeDTO(node, subgraphInstanceIdPath, executableNodes, this) + executableNodes.set(aVeryRealNode.id, aVeryRealNode) nodes.push(aVeryRealNode) } }