mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-03-09 23:20:04 +00:00
Add Subgraphs (#1000)
This commit is contained in:
39
src/subgraph/EmptySubgraphInput.ts
Normal file
39
src/subgraph/EmptySubgraphInput.ts
Normal file
@@ -0,0 +1,39 @@
|
||||
import type { SubgraphInputNode } from "./SubgraphInputNode"
|
||||
import type { INodeInputSlot, Point } from "@/interfaces"
|
||||
import type { LGraphNode } from "@/LGraphNode"
|
||||
import type { RerouteId } from "@/Reroute"
|
||||
|
||||
import { LLink } from "@/LLink"
|
||||
import { nextUniqueName } from "@/strings"
|
||||
import { zeroUuid } from "@/utils/uuid"
|
||||
|
||||
import { SubgraphInput } from "./SubgraphInput"
|
||||
|
||||
/**
|
||||
* A virtual slot that simply creates a new input slot when connected to.
|
||||
*/
|
||||
export class EmptySubgraphInput extends SubgraphInput {
|
||||
declare parent: SubgraphInputNode
|
||||
|
||||
constructor(parent: SubgraphInputNode) {
|
||||
super({
|
||||
id: zeroUuid,
|
||||
name: "",
|
||||
type: "",
|
||||
}, parent)
|
||||
}
|
||||
|
||||
override connect(slot: INodeInputSlot, node: LGraphNode, afterRerouteId?: RerouteId): LLink | undefined {
|
||||
const { subgraph } = this.parent
|
||||
const existingNames = subgraph.inputs.map(x => x.name)
|
||||
|
||||
const name = nextUniqueName(slot.name, existingNames)
|
||||
const input = subgraph.addInput(name, String(slot.type))
|
||||
return input.connect(slot, node, afterRerouteId)
|
||||
}
|
||||
|
||||
override get labelPos(): Point {
|
||||
const [x, y, , height] = this.boundingRect
|
||||
return [x, y + height * 0.5]
|
||||
}
|
||||
}
|
||||
39
src/subgraph/EmptySubgraphOutput.ts
Normal file
39
src/subgraph/EmptySubgraphOutput.ts
Normal file
@@ -0,0 +1,39 @@
|
||||
import type { SubgraphOutputNode } from "./SubgraphOutputNode"
|
||||
import type { INodeOutputSlot, Point } from "@/interfaces"
|
||||
import type { LGraphNode } from "@/LGraphNode"
|
||||
import type { RerouteId } from "@/Reroute"
|
||||
|
||||
import { LLink } from "@/LLink"
|
||||
import { nextUniqueName } from "@/strings"
|
||||
import { zeroUuid } from "@/utils/uuid"
|
||||
|
||||
import { SubgraphOutput } from "./SubgraphOutput"
|
||||
|
||||
/**
|
||||
* A virtual slot that simply creates a new output slot when connected to.
|
||||
*/
|
||||
export class EmptySubgraphOutput extends SubgraphOutput {
|
||||
declare parent: SubgraphOutputNode
|
||||
|
||||
constructor(parent: SubgraphOutputNode) {
|
||||
super({
|
||||
id: zeroUuid,
|
||||
name: "",
|
||||
type: "",
|
||||
}, parent)
|
||||
}
|
||||
|
||||
override connect(slot: INodeOutputSlot, node: LGraphNode, afterRerouteId?: RerouteId): LLink | undefined {
|
||||
const { subgraph } = this.parent
|
||||
const existingNames = subgraph.outputs.map(x => x.name)
|
||||
|
||||
const name = nextUniqueName(slot.name, existingNames)
|
||||
const output = subgraph.addOutput(name, String(slot.type))
|
||||
return output.connect(slot, node, afterRerouteId)
|
||||
}
|
||||
|
||||
override get labelPos(): Point {
|
||||
const [x, y, , height] = this.boundingRect
|
||||
return [x, y + height * 0.5]
|
||||
}
|
||||
}
|
||||
232
src/subgraph/ExecutableNodeDTO.ts
Normal file
232
src/subgraph/ExecutableNodeDTO.ts
Normal file
@@ -0,0 +1,232 @@
|
||||
import type { SubgraphNode } from "./SubgraphNode"
|
||||
import type { CallbackParams, CallbackReturn, ISlotType } from "@/interfaces"
|
||||
import type { LGraph } from "@/LGraph"
|
||||
import type { LGraphNode, NodeId } from "@/LGraphNode"
|
||||
|
||||
import { InvalidLinkError } from "@/infrastructure/InvalidLinkError"
|
||||
import { NullGraphError } from "@/infrastructure/NullGraphError"
|
||||
import { RecursionError } from "@/infrastructure/RecursionError"
|
||||
import { SlotIndexError } from "@/infrastructure/SlotIndexError"
|
||||
import { LGraphEventMode } from "@/litegraph"
|
||||
|
||||
import { Subgraph } from "./Subgraph"
|
||||
|
||||
/**
|
||||
* Interface describing the data transfer objects used when compiling a graph for execution.
|
||||
*/
|
||||
export type ExecutableLGraphNode = Omit<ExecutableNodeDTO, "graph" | "node" | "subgraphNodePath" | "subgraphNode">
|
||||
|
||||
type NodeAndInput = {
|
||||
node: ExecutableLGraphNode
|
||||
origin_id: NodeId
|
||||
origin_slot: number
|
||||
}
|
||||
|
||||
/**
|
||||
* Concrete implementation of {@link ExecutableLGraphNode}.
|
||||
* @remarks This is the class that is used to create the data transfer objects for executable nodes.
|
||||
*/
|
||||
export class ExecutableNodeDTO implements ExecutableLGraphNode {
|
||||
applyToGraph?(...args: CallbackParams<typeof this.node.applyToGraph>): CallbackReturn<typeof this.node.applyToGraph>
|
||||
|
||||
/** The graph that this node is a part of. */
|
||||
readonly graph: LGraph | Subgraph
|
||||
|
||||
inputs: { linkId: number | null, name: string, type: ISlotType }[]
|
||||
|
||||
/** Backing field for {@link id}. */
|
||||
#id: NodeId
|
||||
|
||||
/**
|
||||
* The path to the acutal node through subgraph instances, represented as a list of all subgraph node IDs (instances),
|
||||
* followed by the actual original node ID within the subgraph. Each segment is separated by `:`.
|
||||
*
|
||||
* e.g. `1:2:3`:
|
||||
* - `1` is the node ID of the first subgraph node in the parent workflow
|
||||
* - `2` is the node ID of the second subgraph node in the first subgraph
|
||||
* - `3` is the node ID of the actual node in the subgraph definition
|
||||
*/
|
||||
get id() {
|
||||
return this.#id
|
||||
}
|
||||
|
||||
get type() {
|
||||
return this.node.type
|
||||
}
|
||||
|
||||
get title() {
|
||||
return this.node.title
|
||||
}
|
||||
|
||||
get mode() {
|
||||
return this.node.mode
|
||||
}
|
||||
|
||||
get comfyClass() {
|
||||
return this.node.comfyClass
|
||||
}
|
||||
|
||||
get isVirtualNode() {
|
||||
return this.node.isVirtualNode
|
||||
}
|
||||
|
||||
get widgets() {
|
||||
return this.node.widgets
|
||||
}
|
||||
|
||||
constructor(
|
||||
/** The actual node that this DTO wraps. */
|
||||
readonly node: LGraphNode | SubgraphNode,
|
||||
/** A list of subgraph instance node IDs from the root graph to the containing instance. @see {@link id} */
|
||||
readonly subgraphNodePath: readonly NodeId[],
|
||||
/** The actual subgraph instance that contains this node, otherise undefined. */
|
||||
readonly subgraphNode?: SubgraphNode,
|
||||
) {
|
||||
if (!node.graph) throw new NullGraphError()
|
||||
|
||||
// Set the internal ID of the DTO
|
||||
this.#id = [...this.subgraphNodePath, this.node.id].join(":")
|
||||
this.graph = node.graph
|
||||
this.inputs = this.node.inputs.map(x => ({
|
||||
linkId: x.link,
|
||||
name: x.name,
|
||||
type: x.type,
|
||||
}))
|
||||
|
||||
// Only create a wrapper if the node has an applyToGraph method
|
||||
if (this.node.applyToGraph) {
|
||||
this.applyToGraph = (...args) => this.node.applyToGraph?.(...args)
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns either the DTO itself, or the DTOs of the inner nodes of the subgraph. */
|
||||
getInnerNodes(): ExecutableLGraphNode[] {
|
||||
return this.subgraphNode ? this.subgraphNode.getInnerNodes() : [this]
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves the executable node & link IDs for a given input slot.
|
||||
* @param slot The slot index of the input.
|
||||
* @param visited Leave empty unless overriding this method.
|
||||
* A set of unique IDs, used to guard against infinite recursion.
|
||||
* If overriding, ensure that the set is passed on all recursive calls.
|
||||
* @returns The node and the origin ID / slot index of the output.
|
||||
*/
|
||||
resolveInput(slot: number, visited = new Set<string>()): NodeAndInput | undefined {
|
||||
const uniqueId = `${this.subgraphNode?.subgraph.id}:${this.node.id}[I]${slot}`
|
||||
if (visited.has(uniqueId)) throw new RecursionError(`While resolving subgraph input [${uniqueId}]`)
|
||||
visited.add(uniqueId)
|
||||
|
||||
const input = this.inputs.at(slot)
|
||||
if (!input) throw new SlotIndexError(`No input found for flattened id [${this.id}] slot [${slot}]`)
|
||||
|
||||
// Nothing connected
|
||||
if (input.linkId == null) return
|
||||
|
||||
const link = this.graph.getLink(input.linkId)
|
||||
if (!link) throw new InvalidLinkError(`No link found in parent graph for id [${this.id}] slot [${slot}] ${input.name}`)
|
||||
|
||||
const { subgraphNode } = this
|
||||
|
||||
// Link goes up and out of this subgraph
|
||||
if (subgraphNode && link.originIsIoNode) {
|
||||
const subgraphNodeInput = subgraphNode.inputs.at(link.origin_slot)
|
||||
if (!subgraphNodeInput) throw new SlotIndexError(`No input found for slot [${link.origin_slot}] ${input.name}`)
|
||||
|
||||
// Nothing connected
|
||||
const linkId = subgraphNodeInput.link
|
||||
if (linkId == null) return
|
||||
|
||||
const outerLink = subgraphNode.graph.getLink(linkId)
|
||||
if (!outerLink) throw new InvalidLinkError(`No outer link found for slot [${link.origin_slot}] ${input.name}`)
|
||||
|
||||
// Translate subgraph node IDs to instances (not worth optimising yet)
|
||||
const subgraphNodes = this.graph.rootGraph.resolveSubgraphIdPath(this.subgraphNodePath)
|
||||
|
||||
const subgraphNodeDto = new ExecutableNodeDTO(subgraphNode, this.subgraphNodePath.slice(0, -1), subgraphNodes.at(-2))
|
||||
return subgraphNodeDto.resolveInput(outerLink.target_slot, visited)
|
||||
}
|
||||
|
||||
// Not part of a subgraph; use the original link
|
||||
const outputNode = this.graph.getNodeById(link.origin_id)
|
||||
if (!outputNode) throw new InvalidLinkError(`No input node found for id [${this.id}] slot [${slot}] ${input.name}`)
|
||||
|
||||
const outputNodeDto = new ExecutableNodeDTO(outputNode, this.subgraphNodePath, subgraphNode)
|
||||
|
||||
return outputNodeDto.resolveOutput(link.origin_slot, input.type, visited)
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines whether this output is a valid endpoint for a link (non-virtual, non-bypass).
|
||||
* @param slot The slot index of the output.
|
||||
* @param type The type of the input
|
||||
* @param visited A set of unique IDs to guard against infinite recursion. See {@link resolveInput}.
|
||||
* @returns The node and the origin ID / slot index of the output.
|
||||
*/
|
||||
resolveOutput(slot: number, type: ISlotType, visited: Set<string>): NodeAndInput | undefined {
|
||||
const uniqueId = `${this.subgraphNode?.subgraph.id}:${this.node.id}[O]${slot}`
|
||||
if (visited.has(uniqueId)) throw new RecursionError(`While resolving subgraph output [${uniqueId}]`)
|
||||
visited.add(uniqueId)
|
||||
|
||||
// Upstreamed: Bypass nodes are bypassed using the first input with matching type
|
||||
if (this.mode === LGraphEventMode.BYPASS) {
|
||||
const { inputs } = this
|
||||
|
||||
// Bypass nodes by finding first input with matching type
|
||||
const parentInputIndexes = Object.keys(inputs).map(Number)
|
||||
// Prioritise exact slot index
|
||||
const indexes = [slot, ...parentInputIndexes]
|
||||
const matchingIndex = indexes.find(i => inputs[i]?.type === type)
|
||||
|
||||
// No input types match
|
||||
if (matchingIndex === undefined) {
|
||||
console.debug(`[ExecutableNodeDTO.resolveOutput] No input types match type [${type}] for id [${this.id}] slot [${slot}]`, this)
|
||||
return
|
||||
}
|
||||
|
||||
return this.resolveInput(matchingIndex, visited)
|
||||
}
|
||||
|
||||
const { node } = this
|
||||
if (node.isSubgraphNode()) return this.#resolveSubgraphOutput(slot, type, visited)
|
||||
|
||||
// Upstreamed: Other virtual nodes are bypassed using the same input/output index (slots must match)
|
||||
if (node.isVirtualNode) {
|
||||
if (this.inputs.at(slot)) return this.resolveInput(slot, visited)
|
||||
|
||||
// Virtual nodes without a matching input should be discarded.
|
||||
return
|
||||
}
|
||||
|
||||
return {
|
||||
node: this,
|
||||
origin_id: this.id,
|
||||
origin_slot: slot,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves the link inside a subgraph node, from the subgraph IO node to the node inside the subgraph.
|
||||
* @param slot The slot index of the output on the subgraph node.
|
||||
* @param visited A set of unique IDs to guard against infinite recursion. See {@link resolveInput}.
|
||||
* @returns A DTO for the node, and the origin ID / slot index of the output.
|
||||
*/
|
||||
#resolveSubgraphOutput(slot: number, type: ISlotType, visited: Set<string>): NodeAndInput | undefined {
|
||||
const { node } = this
|
||||
const output = node.outputs.at(slot)
|
||||
|
||||
if (!output) throw new SlotIndexError(`No output found for flattened id [${this.id}] slot [${slot}]`)
|
||||
if (!node.isSubgraphNode()) throw new TypeError(`Node is not a subgraph node: ${node.id}`)
|
||||
|
||||
// Link inside the subgraph
|
||||
const innerResolved = node.resolveSubgraphOutputLink(slot)
|
||||
if (!innerResolved) return
|
||||
|
||||
const innerNode = innerResolved.outputNode
|
||||
if (!innerNode) throw new Error(`No output node found for id [${this.id}] slot [${slot}] ${output.name}`)
|
||||
|
||||
// Recurse into the subgraph
|
||||
const innerNodeDto = new ExecutableNodeDTO(innerNode, [...this.subgraphNodePath, node.id], node)
|
||||
return innerNodeDto.resolveOutput(innerResolved.link.origin_slot, type, visited)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
import type { ExportedSubgraph, ExposedWidget, Serialisable, SerialisableGraph } from "@/types/serialisation"
|
||||
import type { DefaultConnectionColors } from "@/interfaces"
|
||||
import type { LGraphCanvas } from "@/LGraphCanvas"
|
||||
import type { ExportedSubgraph, ExposedWidget, ISerialisedGraph, Serialisable, SerialisableGraph } from "@/types/serialisation"
|
||||
|
||||
import { type BaseLGraph, LGraph } from "@/LGraph"
|
||||
import { createUuidv4, type LGraphNode } from "@/litegraph"
|
||||
|
||||
import { SubgraphInput } from "./SubgraphInput"
|
||||
import { SubgraphInputNode } from "./SubgraphInputNode"
|
||||
@@ -12,44 +15,229 @@ export type GraphOrSubgraph = LGraph | Subgraph
|
||||
|
||||
/** A subgraph definition. */
|
||||
export class Subgraph extends LGraph implements BaseLGraph, Serialisable<ExportedSubgraph> {
|
||||
/** Limits the number of levels / depth that subgraphs may be nested. Prevents uncontrolled programmatic nesting. */
|
||||
static MAX_NESTED_SUBGRAPHS = 1000
|
||||
|
||||
/** The display name of the subgraph. */
|
||||
name: string
|
||||
name: string = "Unnamed Subgraph"
|
||||
|
||||
readonly inputNode = new SubgraphInputNode(this)
|
||||
readonly outputNode = new SubgraphOutputNode(this)
|
||||
|
||||
/** Ordered list of inputs to the subgraph itself. Similar to a reroute, with the input side in the graph, and the output side in the subgraph. */
|
||||
readonly inputs: SubgraphInput[]
|
||||
readonly inputs: SubgraphInput[] = []
|
||||
/** Ordered list of outputs from the subgraph itself. Similar to a reroute, with the input side in the subgraph, and the output side in the graph. */
|
||||
readonly outputs: SubgraphOutput[]
|
||||
readonly outputs: SubgraphOutput[] = []
|
||||
/** A list of node widgets displayed in the parent graph, on the subgraph object. */
|
||||
readonly widgets: ExposedWidget[]
|
||||
readonly widgets: ExposedWidget[] = []
|
||||
|
||||
#rootGraph: LGraph
|
||||
override get rootGraph(): LGraph {
|
||||
return this.parents[0]
|
||||
}
|
||||
|
||||
/** @inheritdoc */
|
||||
get pathToRootGraph(): readonly [LGraph, ...Subgraph[]] {
|
||||
return [...this.parents, this]
|
||||
return this.#rootGraph
|
||||
}
|
||||
|
||||
constructor(
|
||||
readonly parents: readonly [LGraph, ...Subgraph[]],
|
||||
rootGraph: LGraph,
|
||||
data: ExportedSubgraph,
|
||||
) {
|
||||
if (!parents.length) throw new Error("Subgraph must have at least one parent")
|
||||
if (!rootGraph) throw new Error("Root graph is required")
|
||||
|
||||
const cloned = structuredClone(data)
|
||||
const { name, inputs, outputs, widgets } = cloned
|
||||
super()
|
||||
|
||||
this.name = name
|
||||
this.inputs = inputs?.map(x => new SubgraphInput(x, this.inputNode)) ?? []
|
||||
this.outputs = outputs?.map(x => new SubgraphOutput(x, this.outputNode)) ?? []
|
||||
this.widgets = widgets ?? []
|
||||
this.#rootGraph = rootGraph
|
||||
|
||||
this.configure(cloned)
|
||||
const cloned = structuredClone(data)
|
||||
this._configureBase(cloned)
|
||||
this.#configureSubgraph(cloned)
|
||||
}
|
||||
|
||||
getIoNodeOnPos(x: number, y: number): SubgraphInputNode | SubgraphOutputNode | undefined {
|
||||
const { inputNode, outputNode } = this
|
||||
if (inputNode.containsPoint([x, y])) return inputNode
|
||||
if (outputNode.containsPoint([x, y])) return outputNode
|
||||
}
|
||||
|
||||
#configureSubgraph(data: ISerialisedGraph & ExportedSubgraph | SerialisableGraph & ExportedSubgraph): void {
|
||||
const { name, inputs, outputs, widgets } = data
|
||||
|
||||
this.name = name
|
||||
if (inputs) {
|
||||
this.inputs.length = 0
|
||||
for (const input of inputs) {
|
||||
this.inputs.push(new SubgraphInput(input, this.inputNode))
|
||||
}
|
||||
}
|
||||
|
||||
if (outputs) {
|
||||
this.outputs.length = 0
|
||||
for (const output of outputs) {
|
||||
this.outputs.push(new SubgraphOutput(output, this.outputNode))
|
||||
}
|
||||
}
|
||||
|
||||
if (widgets) {
|
||||
this.widgets.length = 0
|
||||
for (const widget of widgets) {
|
||||
this.widgets.push(widget)
|
||||
}
|
||||
}
|
||||
|
||||
this.inputNode.configure(data.inputNode)
|
||||
this.outputNode.configure(data.outputNode)
|
||||
}
|
||||
|
||||
override configure(data: ISerialisedGraph & ExportedSubgraph | SerialisableGraph & ExportedSubgraph, keep_old?: boolean): boolean | undefined {
|
||||
const r = super.configure(data, keep_old)
|
||||
|
||||
this.#configureSubgraph(data)
|
||||
return r
|
||||
}
|
||||
|
||||
override attachCanvas(canvas: LGraphCanvas): void {
|
||||
super.attachCanvas(canvas)
|
||||
canvas.subgraph = this
|
||||
}
|
||||
|
||||
addInput(name: string, type: string): SubgraphInput {
|
||||
const input = new SubgraphInput({
|
||||
id: createUuidv4(),
|
||||
name,
|
||||
type,
|
||||
}, this.inputNode)
|
||||
|
||||
this.inputs.push(input)
|
||||
|
||||
const subgraphId = this.id
|
||||
this.#forAllNodes((node) => {
|
||||
if (node.type === subgraphId) {
|
||||
node.addInput(name, type)
|
||||
}
|
||||
})
|
||||
|
||||
return input
|
||||
}
|
||||
|
||||
addOutput(name: string, type: string): SubgraphOutput {
|
||||
const output = new SubgraphOutput({
|
||||
id: createUuidv4(),
|
||||
name,
|
||||
type,
|
||||
}, this.outputNode)
|
||||
|
||||
this.outputs.push(output)
|
||||
|
||||
const subgraphId = this.id
|
||||
this.#forAllNodes((node) => {
|
||||
if (node.type === subgraphId) {
|
||||
node.addOutput(name, type)
|
||||
}
|
||||
})
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
#forAllNodes(callback: (node: LGraphNode) => void): void {
|
||||
forNodes(this.rootGraph.nodes)
|
||||
for (const subgraph of this.rootGraph.subgraphs.values()) {
|
||||
forNodes(subgraph.nodes)
|
||||
}
|
||||
|
||||
function forNodes(nodes: LGraphNode[]) {
|
||||
for (const node of nodes) {
|
||||
callback(node)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Renames an input slot in the subgraph.
|
||||
* @param input The input slot to rename.
|
||||
* @param name The new name for the input slot.
|
||||
*/
|
||||
renameInput(input: SubgraphInput, name: string): void {
|
||||
input.label = name
|
||||
const index = this.inputs.indexOf(input)
|
||||
if (index === -1) throw new Error("Input not found")
|
||||
|
||||
this.#forAllNodes((node) => {
|
||||
if (node.type === this.id) {
|
||||
node.inputs[index].label = name
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Renames an output slot in the subgraph.
|
||||
* @param output The output slot to rename.
|
||||
* @param name The new name for the output slot.
|
||||
*/
|
||||
renameOutput(output: SubgraphOutput, name: string): void {
|
||||
output.label = name
|
||||
const index = this.outputs.indexOf(output)
|
||||
if (index === -1) throw new Error("Output not found")
|
||||
|
||||
this.#forAllNodes((node) => {
|
||||
if (node.type === this.id) {
|
||||
node.outputs[index].label = name
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes an input slot from the subgraph.
|
||||
* @param input The input slot to remove.
|
||||
*/
|
||||
removeInput(input: SubgraphInput): void {
|
||||
input.disconnect()
|
||||
|
||||
const index = this.inputs.indexOf(input)
|
||||
if (index === -1) throw new Error("Input not found")
|
||||
|
||||
this.inputs.splice(index, 1)
|
||||
|
||||
const { length } = this.inputs
|
||||
for (let i = index; i < length; i++) {
|
||||
this.inputs[i].decrementSlots("inputs")
|
||||
}
|
||||
|
||||
this.#forAllNodes((node) => {
|
||||
if (node.type === this.id) {
|
||||
node.removeInput(index)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes an output slot from the subgraph.
|
||||
* @param output The output slot to remove.
|
||||
*/
|
||||
removeOutput(output: SubgraphOutput): void {
|
||||
output.disconnect()
|
||||
|
||||
const index = this.outputs.indexOf(output)
|
||||
if (index === -1) throw new Error("Output not found")
|
||||
|
||||
this.outputs.splice(index, 1)
|
||||
|
||||
const { length } = this.outputs
|
||||
for (let i = index; i < length; i++) {
|
||||
this.outputs[i].decrementSlots("outputs")
|
||||
}
|
||||
|
||||
this.#forAllNodes((node) => {
|
||||
if (node.type === this.id) {
|
||||
node.removeOutput(index)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
draw(ctx: CanvasRenderingContext2D, colorContext: DefaultConnectionColors): void {
|
||||
this.inputNode.draw(ctx, colorContext)
|
||||
this.outputNode.draw(ctx, colorContext)
|
||||
}
|
||||
|
||||
clone(): Subgraph {
|
||||
return new Subgraph(this.rootGraph, this.asSerialisable())
|
||||
}
|
||||
|
||||
override asSerialisable(): ExportedSubgraph & Required<Pick<SerialisableGraph, "nodes" | "groups" | "extra">> {
|
||||
|
||||
@@ -1,53 +1,65 @@
|
||||
import type { EmptySubgraphInput } from "./EmptySubgraphInput"
|
||||
import type { EmptySubgraphOutput } from "./EmptySubgraphOutput"
|
||||
import type { Subgraph } from "./Subgraph"
|
||||
import type { SubgraphInput } from "./SubgraphInput"
|
||||
import type { SubgraphOutput } from "./SubgraphOutput"
|
||||
import type { Point, Positionable, ReadOnlyRect, Rect } from "@/interfaces"
|
||||
import type { LinkConnector } from "@/canvas/LinkConnector"
|
||||
import type { DefaultConnectionColors, Hoverable, Point, Positionable } from "@/interfaces"
|
||||
import type { NodeId } from "@/LGraphNode"
|
||||
import type { ExportedSubgraphIONode, Serialisable } from "@/types/serialisation"
|
||||
|
||||
import { isPointInRect, snapPoint } from "@/measure"
|
||||
import { Rectangle } from "@/infrastructure/Rectangle"
|
||||
import { type CanvasColour, type CanvasPointer, type CanvasPointerEvent, type IContextMenuValue, LiteGraph } from "@/litegraph"
|
||||
import { snapPoint } from "@/measure"
|
||||
import { CanvasItem } from "@/types/globalEnums"
|
||||
|
||||
export abstract class SubgraphIONodeBase implements Positionable, Serialisable<ExportedSubgraphIONode> {
|
||||
export abstract class SubgraphIONodeBase<TSlot extends SubgraphInput | SubgraphOutput> implements Positionable, Hoverable, Serialisable<ExportedSubgraphIONode> {
|
||||
static margin = 10
|
||||
static defaultWidth = 100
|
||||
static minWidth = 100
|
||||
static roundedRadius = 10
|
||||
|
||||
readonly #boundingRect: Float32Array = new Float32Array(4)
|
||||
readonly #pos: Point = this.#boundingRect.subarray(0, 2)
|
||||
readonly #size: Point = this.#boundingRect.subarray(2, 4)
|
||||
readonly #boundingRect: Rectangle = new Rectangle()
|
||||
|
||||
abstract readonly id: NodeId
|
||||
|
||||
get boundingRect(): Rect {
|
||||
get boundingRect(): Rectangle {
|
||||
return this.#boundingRect
|
||||
}
|
||||
|
||||
selected: boolean = false
|
||||
pinned: boolean = false
|
||||
readonly removable = false
|
||||
|
||||
isPointerOver: boolean = false
|
||||
|
||||
abstract readonly emptySlot: EmptySubgraphInput | EmptySubgraphOutput
|
||||
|
||||
get pos() {
|
||||
return this.#pos
|
||||
return this.boundingRect.pos
|
||||
}
|
||||
|
||||
set pos(value) {
|
||||
if (!value || value.length < 2) return
|
||||
|
||||
this.#pos[0] = value[0]
|
||||
this.#pos[1] = value[1]
|
||||
this.boundingRect.pos = value
|
||||
}
|
||||
|
||||
get size() {
|
||||
return this.#size
|
||||
return this.boundingRect.size
|
||||
}
|
||||
|
||||
set size(value) {
|
||||
if (!value || value.length < 2) return
|
||||
|
||||
this.#size[0] = value[0]
|
||||
this.#size[1] = value[1]
|
||||
this.boundingRect.size = value
|
||||
}
|
||||
|
||||
abstract readonly slots: SubgraphInput[] | SubgraphOutput[]
|
||||
protected get sideLineWidth(): number {
|
||||
return this.isPointerOver ? 2.5 : 2
|
||||
}
|
||||
|
||||
protected get sideStrokeStyle(): CanvasColour {
|
||||
return this.isPointerOver ? "white" : "#efefef"
|
||||
}
|
||||
|
||||
abstract readonly slots: TSlot[]
|
||||
abstract get allSlots(): TSlot[]
|
||||
|
||||
constructor(
|
||||
/** The subgraph that this node belongs to. */
|
||||
@@ -64,19 +76,210 @@ export abstract class SubgraphIONodeBase implements Positionable, Serialisable<E
|
||||
return this.pinned ? false : snapPoint(this.pos, snapTo)
|
||||
}
|
||||
|
||||
abstract onPointerDown(e: CanvasPointerEvent, pointer: CanvasPointer, linkConnector: LinkConnector): void
|
||||
|
||||
// #region Hoverable
|
||||
|
||||
containsPoint(point: Point): boolean {
|
||||
return isPointInRect(point, this.boundingRect)
|
||||
return this.boundingRect.containsPoint(point)
|
||||
}
|
||||
|
||||
abstract get slotAnchorX(): number
|
||||
|
||||
onPointerMove(e: CanvasPointerEvent): CanvasItem {
|
||||
const containsPoint = this.boundingRect.containsXy(e.canvasX, e.canvasY)
|
||||
let underPointer = containsPoint ? CanvasItem.SubgraphIoNode : CanvasItem.Nothing
|
||||
|
||||
if (containsPoint) {
|
||||
if (!this.isPointerOver) this.onPointerEnter()
|
||||
|
||||
for (const slot of this.allSlots) {
|
||||
slot.onPointerMove(e)
|
||||
if (slot.isPointerOver) underPointer |= CanvasItem.SubgraphIoSlot
|
||||
}
|
||||
} else if (this.isPointerOver) {
|
||||
this.onPointerLeave()
|
||||
}
|
||||
return underPointer
|
||||
}
|
||||
|
||||
onPointerEnter() {
|
||||
this.isPointerOver = true
|
||||
}
|
||||
|
||||
onPointerLeave() {
|
||||
this.isPointerOver = false
|
||||
|
||||
for (const slot of this.slots) {
|
||||
slot.isPointerOver = false
|
||||
}
|
||||
}
|
||||
|
||||
// #endregion Hoverable
|
||||
|
||||
/**
|
||||
* Renames an IO slot in the subgraph.
|
||||
* @param slot The slot to rename.
|
||||
* @param name The new name for the slot.
|
||||
*/
|
||||
abstract renameSlot(slot: TSlot, name: string): void
|
||||
|
||||
/**
|
||||
* Removes an IO slot from the subgraph.
|
||||
* @param slot The slot to remove.
|
||||
*/
|
||||
abstract removeSlot(slot: TSlot): void
|
||||
|
||||
/**
|
||||
* Gets the slot at a given position in canvas space.
|
||||
* @param x The x coordinate of the position.
|
||||
* @param y The y coordinate of the position.
|
||||
* @returns The slot at the given position, otherwise `undefined`.
|
||||
*/
|
||||
getSlotInPosition(x: number, y: number): TSlot | undefined {
|
||||
for (const slot of this.allSlots) {
|
||||
if (slot.boundingRect.containsXy(x, y)) {
|
||||
return slot
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Shows the context menu for an IO slot.
|
||||
* @param slot The slot to show the context menu for.
|
||||
* @param event The event that triggered the context menu.
|
||||
*/
|
||||
protected showSlotContextMenu(slot: TSlot, event: CanvasPointerEvent): void {
|
||||
const options: IContextMenuValue[] = this.#getSlotMenuOptions(slot)
|
||||
if (!(options.length > 0)) return
|
||||
|
||||
new LiteGraph.ContextMenu(
|
||||
options,
|
||||
{
|
||||
event: event as any,
|
||||
title: slot.name || "Subgraph Output",
|
||||
callback: (item: IContextMenuValue) => {
|
||||
this.#onSlotMenuAction(item, slot, event)
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the context menu options for an IO slot.
|
||||
* @param slot The slot to get the context menu options for.
|
||||
* @returns The context menu options.
|
||||
*/
|
||||
#getSlotMenuOptions(slot: TSlot): IContextMenuValue[] {
|
||||
const options: IContextMenuValue[] = []
|
||||
|
||||
// Disconnect option if slot has connections
|
||||
if (slot !== this.emptySlot && slot.linkIds.length > 0) {
|
||||
options.push({ content: "Disconnect Links", value: "disconnect" })
|
||||
}
|
||||
|
||||
// Remove / rename slot option (except for the empty slot)
|
||||
if (slot !== this.emptySlot) {
|
||||
options.push(
|
||||
{ content: "Remove Slot", value: "remove" },
|
||||
{ content: "Rename Slot", value: "rename" },
|
||||
)
|
||||
}
|
||||
|
||||
return options
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles the action for an IO slot context menu.
|
||||
* @param selectedItem The item that was selected from the context menu.
|
||||
* @param slot The slot
|
||||
* @param event The event that triggered the context menu.
|
||||
*/
|
||||
#onSlotMenuAction(selectedItem: IContextMenuValue, slot: TSlot, event: CanvasPointerEvent): void {
|
||||
switch (selectedItem.value) {
|
||||
// Disconnect all links from this output
|
||||
case "disconnect":
|
||||
slot.disconnect()
|
||||
break
|
||||
|
||||
// Remove the slot
|
||||
case "remove":
|
||||
if (slot !== this.emptySlot) {
|
||||
this.removeSlot(slot)
|
||||
}
|
||||
break
|
||||
|
||||
// Rename the slot
|
||||
case "rename":
|
||||
if (slot !== this.emptySlot) {
|
||||
this.subgraph.canvasAction(c => c.prompt(
|
||||
"Slot name",
|
||||
slot.name,
|
||||
(newName: string) => {
|
||||
if (newName) this.renameSlot(slot, newName)
|
||||
},
|
||||
event,
|
||||
))
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
this.subgraph.setDirtyCanvas(true)
|
||||
}
|
||||
|
||||
/** Arrange the slots in this node. */
|
||||
arrange(): void {
|
||||
const { minWidth, roundedRadius } = SubgraphIONodeBase
|
||||
const [, y] = this.boundingRect
|
||||
const x = this.slotAnchorX
|
||||
const { size } = this
|
||||
|
||||
let maxWidth = minWidth
|
||||
let currentY = y + roundedRadius
|
||||
|
||||
for (const slot of this.allSlots) {
|
||||
const [slotWidth, slotHeight] = slot.measure()
|
||||
slot.arrange([x, currentY, slotWidth, slotHeight])
|
||||
|
||||
currentY += slotHeight
|
||||
if (slotWidth > maxWidth) maxWidth = slotWidth
|
||||
}
|
||||
|
||||
size[0] = maxWidth + 2 * roundedRadius
|
||||
size[1] = currentY - y + roundedRadius
|
||||
}
|
||||
|
||||
draw(ctx: CanvasRenderingContext2D, colorContext: DefaultConnectionColors): void {
|
||||
const { lineWidth, strokeStyle, fillStyle, font, textBaseline } = ctx
|
||||
this.drawProtected(ctx, colorContext)
|
||||
Object.assign(ctx, { lineWidth, strokeStyle, fillStyle, font, textBaseline })
|
||||
}
|
||||
|
||||
/** @internal Leaves {@link ctx} dirty. */
|
||||
protected abstract drawProtected(ctx: CanvasRenderingContext2D, colorContext: DefaultConnectionColors): void
|
||||
|
||||
/** @internal Leaves {@link ctx} dirty. */
|
||||
protected drawSlots(ctx: CanvasRenderingContext2D, colorContext: DefaultConnectionColors): void {
|
||||
ctx.fillStyle = "#AAA"
|
||||
ctx.font = "12px Arial"
|
||||
ctx.textBaseline = "middle"
|
||||
|
||||
for (const slot of this.allSlots) {
|
||||
slot.draw({ ctx, colorContext })
|
||||
slot.drawLabel(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
configure(data: ExportedSubgraphIONode): void {
|
||||
this.#boundingRect.set(data.bounding)
|
||||
this.pinned = data.pinned ?? false
|
||||
}
|
||||
|
||||
asSerialisable(): ExportedSubgraphIONode {
|
||||
return {
|
||||
id: this.id,
|
||||
bounding: serialiseRect(this.boundingRect),
|
||||
bounding: this.boundingRect.export(),
|
||||
pinned: this.pinned ? true : undefined,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function serialiseRect(rect: ReadOnlyRect): [number, number, number, number] {
|
||||
return [rect[0], rect[1], rect[2], rect[3]]
|
||||
}
|
||||
|
||||
@@ -1,8 +1,104 @@
|
||||
import type { Point, ReadOnlyRect } from "@/interfaces"
|
||||
import type { SubgraphInputNode } from "./SubgraphInputNode"
|
||||
import type { INodeInputSlot, Point, ReadOnlyRect } from "@/interfaces"
|
||||
import type { LGraphNode } from "@/LGraphNode"
|
||||
import type { RerouteId } from "@/Reroute"
|
||||
|
||||
import { LLink } from "@/LLink"
|
||||
import { NodeSlotType } from "@/types/globalEnums"
|
||||
|
||||
import { SubgraphSlot } from "./SubgraphSlotBase"
|
||||
|
||||
/**
|
||||
* An input "slot" from a parent graph into a subgraph.
|
||||
*
|
||||
* IMPORTANT: A subgraph "input" is both an input AND an output. It creates an extra link connection point between
|
||||
* a parent graph and a subgraph, so is conceptually similar to a reroute.
|
||||
*
|
||||
* This can be a little confusing, but is easier to visualise when imagining editing a subgraph.
|
||||
* You have "Subgraph Inputs", because they are coming into the subgraph, which then connect to "node inputs".
|
||||
*
|
||||
* Functionally, however, when editing a subgraph, that "subgraph input" is the "origin" or "output side" of a link.
|
||||
*/
|
||||
export class SubgraphInput extends SubgraphSlot {
|
||||
declare parent: SubgraphInputNode
|
||||
|
||||
override connect(slot: INodeInputSlot, node: LGraphNode, afterRerouteId?: RerouteId): LLink | undefined {
|
||||
const { subgraph } = this.parent
|
||||
|
||||
// Allow nodes to block connection
|
||||
const inputIndex = node.inputs.indexOf(slot)
|
||||
if (node.onConnectInput?.(inputIndex, this.type, this, this.parent, -1) === false) return
|
||||
|
||||
// if (slot instanceof SubgraphOutput) {
|
||||
// // Subgraph IO nodes have no special handling at present.
|
||||
// return new LLink(
|
||||
// ++subgraph.state.lastLinkId,
|
||||
// this.type,
|
||||
// this.parent.id,
|
||||
// this.parent.slots.indexOf(this),
|
||||
// node.id,
|
||||
// inputIndex,
|
||||
// afterRerouteId,
|
||||
// )
|
||||
// }
|
||||
|
||||
// Disconnect target input, if it is already connected.
|
||||
if (slot.link != null) {
|
||||
subgraph.beforeChange()
|
||||
const link = subgraph.getLink(slot.link)
|
||||
this.parent._disconnectNodeInput(node, slot, link)
|
||||
}
|
||||
|
||||
const link = new LLink(
|
||||
++subgraph.state.lastLinkId,
|
||||
slot.type,
|
||||
this.parent.id,
|
||||
this.parent.slots.indexOf(this),
|
||||
node.id,
|
||||
inputIndex,
|
||||
afterRerouteId,
|
||||
)
|
||||
|
||||
// Add to graph links list
|
||||
subgraph._links.set(link.id, link)
|
||||
|
||||
// Set link ID in each slot
|
||||
this.linkIds.push(link.id)
|
||||
slot.link = link.id
|
||||
|
||||
// Reroutes
|
||||
const reroutes = LLink.getReroutes(subgraph, link)
|
||||
for (const reroute of reroutes) {
|
||||
reroute.linkIds.add(link.id)
|
||||
if (reroute.floating) delete reroute.floating
|
||||
reroute._dragging = undefined
|
||||
}
|
||||
|
||||
// If this is the terminus of a floating link, remove it
|
||||
const lastReroute = reroutes.at(-1)
|
||||
if (lastReroute) {
|
||||
for (const linkId of lastReroute.floatingLinkIds) {
|
||||
const link = subgraph.floatingLinks.get(linkId)
|
||||
if (link?.parentId === lastReroute.id) {
|
||||
subgraph.removeFloatingLink(link)
|
||||
}
|
||||
}
|
||||
}
|
||||
subgraph._version++
|
||||
|
||||
node.onConnectionsChange?.(
|
||||
NodeSlotType.INPUT,
|
||||
inputIndex,
|
||||
true,
|
||||
link,
|
||||
slot,
|
||||
)
|
||||
|
||||
subgraph.afterChange()
|
||||
|
||||
return link
|
||||
}
|
||||
|
||||
get labelPos(): Point {
|
||||
const [x, y, , height] = this.boundingRect
|
||||
return [x, y + height * 0.5]
|
||||
|
||||
@@ -1,12 +1,187 @@
|
||||
import type { Positionable } from "@/interfaces"
|
||||
import type { NodeId } from "@/LGraphNode"
|
||||
import type { SubgraphInput } from "./SubgraphInput"
|
||||
import type { LinkConnector } from "@/canvas/LinkConnector"
|
||||
import type { CanvasPointer } from "@/CanvasPointer"
|
||||
import type { DefaultConnectionColors, INodeInputSlot, ISlotType, Positionable } from "@/interfaces"
|
||||
import type { LGraphNode, NodeId } from "@/LGraphNode"
|
||||
import type { RerouteId } from "@/Reroute"
|
||||
import type { CanvasPointerEvent } from "@/types/events"
|
||||
import type { NodeLike } from "@/types/NodeLike"
|
||||
|
||||
import { SUBGRAPH_INPUT_ID } from "@/constants"
|
||||
import { Rectangle } from "@/infrastructure/Rectangle"
|
||||
import { LLink } from "@/LLink"
|
||||
import { NodeSlotType } from "@/types/globalEnums"
|
||||
import { findFreeSlotOfType } from "@/utils/collections"
|
||||
|
||||
import { EmptySubgraphInput } from "./EmptySubgraphInput"
|
||||
import { SubgraphIONodeBase } from "./SubgraphIONodeBase"
|
||||
|
||||
export class SubgraphInputNode extends SubgraphIONodeBase implements Positionable {
|
||||
readonly id: NodeId = -10
|
||||
export class SubgraphInputNode extends SubgraphIONodeBase<SubgraphInput> implements Positionable {
|
||||
readonly id: NodeId = SUBGRAPH_INPUT_ID
|
||||
|
||||
readonly emptySlot: EmptySubgraphInput = new EmptySubgraphInput(this)
|
||||
|
||||
get slots() {
|
||||
return this.subgraph.inputs
|
||||
}
|
||||
|
||||
override get allSlots(): SubgraphInput[] {
|
||||
return [...this.slots, this.emptySlot]
|
||||
}
|
||||
|
||||
get slotAnchorX() {
|
||||
const [x, , width] = this.boundingRect
|
||||
return x + width - SubgraphIONodeBase.roundedRadius
|
||||
}
|
||||
|
||||
override onPointerDown(e: CanvasPointerEvent, pointer: CanvasPointer, linkConnector: LinkConnector): void {
|
||||
// Left-click handling for dragging connections
|
||||
if (e.button === 0) {
|
||||
for (const slot of this.allSlots) {
|
||||
const slotBounds = Rectangle.fromCentre(slot.pos, slot.boundingRect.height)
|
||||
|
||||
if (slotBounds.containsXy(e.canvasX, e.canvasY)) {
|
||||
pointer.onDragStart = () => {
|
||||
linkConnector.dragNewFromSubgraphInput(this.subgraph, this, slot)
|
||||
}
|
||||
pointer.onDragEnd = (eUp) => {
|
||||
linkConnector.dropLinks(this.subgraph, eUp)
|
||||
}
|
||||
pointer.finally = () => {
|
||||
linkConnector.reset(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check for right-click
|
||||
} else if (e.button === 2) {
|
||||
const slot = this.getSlotInPosition(e.canvasX, e.canvasY)
|
||||
if (slot) this.showSlotContextMenu(slot, e)
|
||||
}
|
||||
}
|
||||
|
||||
/** @inheritdoc */
|
||||
override renameSlot(slot: SubgraphInput, name: string): void {
|
||||
this.subgraph.renameInput(slot, name)
|
||||
}
|
||||
|
||||
/** @inheritdoc */
|
||||
override removeSlot(slot: SubgraphInput): void {
|
||||
this.subgraph.removeInput(slot)
|
||||
}
|
||||
|
||||
canConnectTo(inputNode: NodeLike, input: INodeInputSlot, fromSlot: SubgraphInput): boolean {
|
||||
return inputNode.canConnectTo(this, input, fromSlot)
|
||||
}
|
||||
|
||||
connectSlots(fromSlot: SubgraphInput, inputNode: LGraphNode, input: INodeInputSlot, afterRerouteId: RerouteId | undefined): LLink {
|
||||
const { subgraph } = this
|
||||
|
||||
const outputIndex = this.slots.indexOf(fromSlot)
|
||||
const inputIndex = inputNode.inputs.indexOf(input)
|
||||
|
||||
if (outputIndex === -1 || inputIndex === -1) throw new Error("Invalid slot indices.")
|
||||
|
||||
return new LLink(
|
||||
++subgraph.state.lastLinkId,
|
||||
input.type || fromSlot.type,
|
||||
this.id,
|
||||
outputIndex,
|
||||
inputNode.id,
|
||||
inputIndex,
|
||||
afterRerouteId,
|
||||
)
|
||||
}
|
||||
|
||||
// #region Legacy LGraphNode compatibility
|
||||
|
||||
connectByType(
|
||||
slot: number,
|
||||
target_node: LGraphNode,
|
||||
target_slotType: ISlotType,
|
||||
optsIn?: { afterRerouteId?: RerouteId },
|
||||
): LLink | undefined {
|
||||
const inputSlot = target_node.findInputByType(target_slotType)
|
||||
if (!inputSlot) return
|
||||
|
||||
return this.slots[slot].connect(inputSlot.slot, target_node, optsIn?.afterRerouteId)
|
||||
}
|
||||
|
||||
findOutputSlot(name: string): SubgraphInput | undefined {
|
||||
return this.slots.find(output => output.name === name)
|
||||
}
|
||||
|
||||
findOutputByType(type: ISlotType): SubgraphInput | undefined {
|
||||
return findFreeSlotOfType(this.slots, type, slot => slot.linkIds.length > 0)?.slot
|
||||
}
|
||||
|
||||
// #endregion Legacy LGraphNode compatibility
|
||||
|
||||
_disconnectNodeInput(node: LGraphNode, input: INodeInputSlot, link: LLink | undefined): void {
|
||||
const { subgraph } = this
|
||||
|
||||
// Break floating links
|
||||
if (input._floatingLinks?.size) {
|
||||
for (const link of input._floatingLinks) {
|
||||
subgraph.removeFloatingLink(link)
|
||||
}
|
||||
}
|
||||
|
||||
input.link = null
|
||||
subgraph.setDirtyCanvas(false, true)
|
||||
|
||||
if (!link) return
|
||||
|
||||
const subgraphInputIndex = link.origin_slot
|
||||
link.disconnect(subgraph, "output")
|
||||
subgraph._version++
|
||||
|
||||
const subgraphInput = this.slots.at(subgraphInputIndex)
|
||||
if (!subgraphInput) {
|
||||
console.debug("disconnectNodeInput: subgraphInput not found", this, subgraphInputIndex)
|
||||
return
|
||||
}
|
||||
|
||||
// search in the inputs list for this link
|
||||
const index = subgraphInput.linkIds.indexOf(link.id)
|
||||
if (index !== -1) {
|
||||
subgraphInput.linkIds.splice(index, 1)
|
||||
} else {
|
||||
console.debug("disconnectNodeInput: link ID not found in subgraphInput linkIds", link.id)
|
||||
}
|
||||
|
||||
node.onConnectionsChange?.(
|
||||
NodeSlotType.OUTPUT,
|
||||
index,
|
||||
false,
|
||||
link,
|
||||
subgraphInput,
|
||||
)
|
||||
}
|
||||
|
||||
override drawProtected(ctx: CanvasRenderingContext2D, colorContext: DefaultConnectionColors): void {
|
||||
const { roundedRadius } = SubgraphIONodeBase
|
||||
const transform = ctx.getTransform()
|
||||
|
||||
const [x, y, width, height] = this.boundingRect
|
||||
ctx.translate(x, y)
|
||||
|
||||
// Draw top rounded part
|
||||
ctx.strokeStyle = this.sideStrokeStyle
|
||||
ctx.lineWidth = this.sideLineWidth
|
||||
ctx.beginPath()
|
||||
ctx.arc(width - roundedRadius, roundedRadius, roundedRadius, Math.PI * 1.5, 0)
|
||||
|
||||
// Straight line to bottom
|
||||
ctx.moveTo(width, roundedRadius)
|
||||
ctx.lineTo(width, height - roundedRadius)
|
||||
|
||||
// Bottom rounded part
|
||||
ctx.arc(width - roundedRadius, height - roundedRadius, roundedRadius, 0, Math.PI * 0.5)
|
||||
ctx.stroke()
|
||||
|
||||
// Restore context
|
||||
ctx.setTransform(transform)
|
||||
|
||||
this.drawSlots(ctx, colorContext)
|
||||
}
|
||||
}
|
||||
|
||||
148
src/subgraph/SubgraphNode.ts
Normal file
148
src/subgraph/SubgraphNode.ts
Normal file
@@ -0,0 +1,148 @@
|
||||
import type { ISubgraphInput } from "@/interfaces"
|
||||
import type { BaseLGraph, LGraph } from "@/LGraph"
|
||||
import type { INodeInputSlot, ISlotType, NodeId } from "@/litegraph"
|
||||
import type { GraphOrSubgraph, Subgraph } from "@/subgraph/Subgraph"
|
||||
import type { ExportedSubgraphInstance } from "@/types/serialisation"
|
||||
import type { UUID } from "@/utils/uuid"
|
||||
|
||||
import { RecursionError } from "@/infrastructure/RecursionError"
|
||||
import { LGraphNode } from "@/LGraphNode"
|
||||
import { LLink, type ResolvedConnection } from "@/LLink"
|
||||
import { NodeInputSlot } from "@/node/NodeInputSlot"
|
||||
import { NodeOutputSlot } from "@/node/NodeOutputSlot"
|
||||
|
||||
import { type ExecutableLGraphNode, ExecutableNodeDTO } from "./ExecutableNodeDTO"
|
||||
|
||||
/**
|
||||
* An instance of a {@link Subgraph}, displayed as a node on the containing (parent) graph.
|
||||
*/
|
||||
export class SubgraphNode extends LGraphNode implements BaseLGraph {
|
||||
override readonly type: UUID
|
||||
override readonly isVirtualNode = true as const
|
||||
|
||||
get rootGraph(): LGraph {
|
||||
return this.graph.rootGraph
|
||||
}
|
||||
|
||||
override get displayType(): string {
|
||||
return "Subgraph node"
|
||||
}
|
||||
|
||||
override isSubgraphNode(): this is SubgraphNode {
|
||||
return true
|
||||
}
|
||||
|
||||
constructor(
|
||||
/** The (sub)graph that contains this subgraph instance. */
|
||||
override readonly graph: GraphOrSubgraph,
|
||||
/** The definition of this subgraph; how its nodes are configured, etc. */
|
||||
readonly subgraph: Subgraph,
|
||||
instanceData: ExportedSubgraphInstance,
|
||||
) {
|
||||
super(subgraph.name, subgraph.id)
|
||||
|
||||
this.type = subgraph.id
|
||||
this.configure(instanceData)
|
||||
}
|
||||
|
||||
override configure(info: ExportedSubgraphInstance): void {
|
||||
this.inputs.length = 0
|
||||
this.inputs.push(
|
||||
...this.subgraph.inputNode.slots.map(
|
||||
slot => new NodeInputSlot({ name: slot.name, localized_name: slot.localized_name, label: slot.label, type: slot.type, link: null }, this),
|
||||
),
|
||||
)
|
||||
|
||||
this.outputs.length = 0
|
||||
this.outputs.push(
|
||||
...this.subgraph.outputNode.slots.map(
|
||||
slot => new NodeOutputSlot({ name: slot.name, localized_name: slot.localized_name, label: slot.label, type: slot.type, links: null }, this),
|
||||
),
|
||||
)
|
||||
|
||||
super.configure(info)
|
||||
}
|
||||
|
||||
/**
|
||||
* Ensures the subgraph slot is in the params before adding the input as normal.
|
||||
* @param name The name of the input slot.
|
||||
* @param type The type of the input slot.
|
||||
* @param inputProperties Properties that are directly assigned to the created input. Default: a new, empty object.
|
||||
* @returns The new input slot.
|
||||
* @remarks Assertion is required to instantiate empty generic POJO.
|
||||
*/
|
||||
override addInput<TInput extends Partial<ISubgraphInput>>(name: string, type: ISlotType, inputProperties: TInput = {} as TInput): INodeInputSlot & TInput {
|
||||
// Bypasses type narrowing on this.inputs
|
||||
return super.addInput(name, type, inputProperties)
|
||||
}
|
||||
|
||||
override getInputLink(slot: number): LLink | null {
|
||||
// Output side: the link from inside the subgraph
|
||||
const innerLink = this.subgraph.outputNode.slots[slot].getLinks().at(0)
|
||||
if (!innerLink) {
|
||||
console.warn(`SubgraphNode.getInputLink: no inner link found for slot ${slot}`)
|
||||
return null
|
||||
}
|
||||
|
||||
const newLink = LLink.create(innerLink)
|
||||
newLink.origin_id = `${this.id}:${innerLink.origin_id}`
|
||||
newLink.origin_slot = innerLink.origin_slot
|
||||
|
||||
return newLink
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the internal links connected to the given input slot inside the subgraph, and resolves the nodes / slots.
|
||||
* @param slot The slot index
|
||||
* @returns The resolved connections, or undefined if no input node is found.
|
||||
* @remarks This is used to resolve the input links when dragging a link from a subgraph input slot.
|
||||
*/
|
||||
resolveSubgraphInputLinks(slot: number): ResolvedConnection[] {
|
||||
const inputSlot = this.subgraph.inputNode.slots[slot]
|
||||
const innerLinks = inputSlot.getLinks()
|
||||
if (innerLinks.length === 0) {
|
||||
console.debug(`[SubgraphNode.resolveSubgraphInputLinks] No inner links found for input slot [${slot}] ${inputSlot.name}`, this)
|
||||
return []
|
||||
}
|
||||
return innerLinks.map(link => link.resolve(this.subgraph))
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the internal link connected to the given output slot inside the subgraph, and resolves the nodes / slots.
|
||||
* @param slot The slot index
|
||||
* @returns The output node if found, otherwise undefined.
|
||||
*/
|
||||
resolveSubgraphOutputLink(slot: number): ResolvedConnection | undefined {
|
||||
const outputSlot = this.subgraph.outputNode.slots[slot]
|
||||
const innerLink = outputSlot.getLinks().at(0)
|
||||
if (innerLink) return innerLink.resolve(this.subgraph)
|
||||
|
||||
console.debug(`[SubgraphNode.resolveSubgraphOutputLink] No inner link found for output slot [${slot}] ${outputSlot.name}`, this)
|
||||
}
|
||||
|
||||
/** @internal Used to flatten the subgraph before execution. Recursive; call with no args. */
|
||||
getInnerNodes(
|
||||
/** The list of nodes to add to. */
|
||||
nodes: ExecutableLGraphNode[] = [],
|
||||
/** The set of visited nodes. */
|
||||
visited = new WeakSet<SubgraphNode>(),
|
||||
/** The path of subgraph node IDs. */
|
||||
subgraphNodePath: readonly NodeId[] = [],
|
||||
): ExecutableLGraphNode[] {
|
||||
if (visited.has(this)) throw new RecursionError("while flattening subgraph")
|
||||
visited.add(this)
|
||||
|
||||
const subgraphInstanceIdPath = [...subgraphNodePath, this.id]
|
||||
|
||||
for (const node of this.subgraph.nodes) {
|
||||
if ("getInnerNodes" in node) {
|
||||
node.getInnerNodes(nodes, visited, subgraphInstanceIdPath)
|
||||
} else {
|
||||
// Create minimal DTOs rather than cloning the node
|
||||
const aVeryRealNode = new ExecutableNodeDTO(node, subgraphInstanceIdPath, this)
|
||||
nodes.push(aVeryRealNode)
|
||||
}
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,99 @@
|
||||
import type { Point, ReadOnlyRect } from "@/interfaces"
|
||||
import type { SubgraphOutputNode } from "./SubgraphOutputNode"
|
||||
import type { INodeOutputSlot, Point, ReadOnlyRect } from "@/interfaces"
|
||||
import type { LGraphNode } from "@/LGraphNode"
|
||||
import type { RerouteId } from "@/Reroute"
|
||||
|
||||
import { LLink } from "@/LLink"
|
||||
import { NodeSlotType } from "@/types/globalEnums"
|
||||
import { removeFromArray } from "@/utils/collections"
|
||||
|
||||
import { SubgraphSlot } from "./SubgraphSlotBase"
|
||||
|
||||
/**
|
||||
* An output "slot" from a subgraph to a parent graph.
|
||||
*
|
||||
* IMPORTANT: A subgraph "output" is both an output AND an input. It creates an extra link connection point between
|
||||
* a parent graph and a subgraph, so is conceptually similar to a reroute.
|
||||
*
|
||||
* This can be a little confusing, but is easier to visualise when imagining editing a subgraph.
|
||||
* You have "Subgraph Outputs", because they go from inside the subgraph and out, but links to them come from "node outputs".
|
||||
*
|
||||
* Functionally, however, when editing a subgraph, that "subgraph output" is the "target" or "input side" of a link.
|
||||
*/
|
||||
export class SubgraphOutput extends SubgraphSlot {
|
||||
declare parent: SubgraphOutputNode
|
||||
|
||||
override connect(slot: INodeOutputSlot, node: LGraphNode, afterRerouteId?: RerouteId): LLink | undefined {
|
||||
const { subgraph } = this.parent
|
||||
|
||||
// Allow nodes to block connection
|
||||
const outputIndex = node.outputs.indexOf(slot)
|
||||
if (outputIndex === -1) throw new Error("Slot is not an output of the given node")
|
||||
|
||||
if (node.onConnectOutput?.(outputIndex, this.type, this, this.parent, -1) === false) return
|
||||
|
||||
// Link should not be present, but just in case, disconnect it
|
||||
const existingLink = this.getLinks().at(0)
|
||||
if (existingLink != null) {
|
||||
subgraph.beforeChange()
|
||||
|
||||
existingLink.disconnect(subgraph, "input")
|
||||
const resolved = existingLink.resolve(subgraph)
|
||||
const links = resolved.output?.links
|
||||
if (links) removeFromArray(links, existingLink.id)
|
||||
}
|
||||
|
||||
const link = new LLink(
|
||||
++subgraph.state.lastLinkId,
|
||||
slot.type,
|
||||
node.id,
|
||||
outputIndex,
|
||||
this.parent.id,
|
||||
this.parent.slots.indexOf(this),
|
||||
afterRerouteId,
|
||||
)
|
||||
|
||||
// Add to graph links list
|
||||
subgraph._links.set(link.id, link)
|
||||
|
||||
// Set link ID in each slot
|
||||
this.linkIds[0] = link.id
|
||||
slot.links ??= []
|
||||
slot.links.push(link.id)
|
||||
|
||||
// Reroutes
|
||||
const reroutes = LLink.getReroutes(subgraph, link)
|
||||
for (const reroute of reroutes) {
|
||||
reroute.linkIds.add(link.id)
|
||||
if (reroute.floating) delete reroute.floating
|
||||
reroute._dragging = undefined
|
||||
}
|
||||
|
||||
// If this is the terminus of a floating link, remove it
|
||||
const lastReroute = reroutes.at(-1)
|
||||
if (lastReroute) {
|
||||
for (const linkId of lastReroute.floatingLinkIds) {
|
||||
const link = subgraph.floatingLinks.get(linkId)
|
||||
if (link?.parentId === lastReroute.id) {
|
||||
subgraph.removeFloatingLink(link)
|
||||
}
|
||||
}
|
||||
}
|
||||
subgraph._version++
|
||||
|
||||
node.onConnectionsChange?.(
|
||||
NodeSlotType.OUTPUT,
|
||||
outputIndex,
|
||||
true,
|
||||
link,
|
||||
slot,
|
||||
)
|
||||
|
||||
subgraph.afterChange()
|
||||
|
||||
return link
|
||||
}
|
||||
|
||||
get labelPos(): Point {
|
||||
const [x, y, , height] = this.boundingRect
|
||||
return [x + height, y + height * 0.5]
|
||||
|
||||
@@ -1,12 +1,119 @@
|
||||
import type { Positionable } from "@/interfaces"
|
||||
import type { NodeId } from "@/LGraphNode"
|
||||
import type { SubgraphOutput } from "./SubgraphOutput"
|
||||
import type { LinkConnector } from "@/canvas/LinkConnector"
|
||||
import type { CanvasPointer } from "@/CanvasPointer"
|
||||
import type { DefaultConnectionColors, ISlotType, Positionable } from "@/interfaces"
|
||||
import type { INodeOutputSlot } from "@/interfaces"
|
||||
import type { LGraphNode, NodeId } from "@/LGraphNode"
|
||||
import type { LLink } from "@/LLink"
|
||||
import type { RerouteId } from "@/Reroute"
|
||||
import type { CanvasPointerEvent } from "@/types/events"
|
||||
import type { NodeLike } from "@/types/NodeLike"
|
||||
import type { SubgraphIO } from "@/types/serialisation"
|
||||
|
||||
import { SUBGRAPH_OUTPUT_ID } from "@/constants"
|
||||
import { Rectangle } from "@/infrastructure/Rectangle"
|
||||
import { findFreeSlotOfType } from "@/utils/collections"
|
||||
|
||||
import { EmptySubgraphOutput } from "./EmptySubgraphOutput"
|
||||
import { SubgraphIONodeBase } from "./SubgraphIONodeBase"
|
||||
|
||||
export class SubgraphOutputNode extends SubgraphIONodeBase implements Positionable {
|
||||
readonly id: NodeId = -20
|
||||
export class SubgraphOutputNode extends SubgraphIONodeBase<SubgraphOutput> implements Positionable {
|
||||
readonly id: NodeId = SUBGRAPH_OUTPUT_ID
|
||||
|
||||
readonly emptySlot: EmptySubgraphOutput = new EmptySubgraphOutput(this)
|
||||
|
||||
get slots() {
|
||||
return this.subgraph.outputs
|
||||
}
|
||||
|
||||
override get allSlots(): SubgraphOutput[] {
|
||||
return [...this.slots, this.emptySlot]
|
||||
}
|
||||
|
||||
get slotAnchorX() {
|
||||
const [x] = this.boundingRect
|
||||
return x + SubgraphIONodeBase.roundedRadius
|
||||
}
|
||||
|
||||
override onPointerDown(e: CanvasPointerEvent, pointer: CanvasPointer, linkConnector: LinkConnector): void {
|
||||
// Left-click handling for dragging connections
|
||||
if (e.button === 0) {
|
||||
for (const slot of this.allSlots) {
|
||||
const slotBounds = Rectangle.fromCentre(slot.pos, slot.boundingRect.height)
|
||||
|
||||
if (slotBounds.containsXy(e.canvasX, e.canvasY)) {
|
||||
pointer.onDragStart = () => {
|
||||
linkConnector.dragNewFromSubgraphOutput(this.subgraph, this, slot)
|
||||
}
|
||||
pointer.onDragEnd = (eUp) => {
|
||||
linkConnector.dropLinks(this.subgraph, eUp)
|
||||
}
|
||||
pointer.finally = () => {
|
||||
linkConnector.reset(true)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check for right-click
|
||||
} else if (e.button === 2) {
|
||||
const slot = this.getSlotInPosition(e.canvasX, e.canvasY)
|
||||
if (slot) this.showSlotContextMenu(slot, e)
|
||||
}
|
||||
}
|
||||
|
||||
/** @inheritdoc */
|
||||
override renameSlot(slot: SubgraphOutput, name: string): void {
|
||||
this.subgraph.renameOutput(slot, name)
|
||||
}
|
||||
|
||||
/** @inheritdoc */
|
||||
override removeSlot(slot: SubgraphOutput): void {
|
||||
this.subgraph.removeOutput(slot)
|
||||
}
|
||||
|
||||
canConnectTo(outputNode: NodeLike, fromSlot: SubgraphOutput, output: INodeOutputSlot | SubgraphIO): boolean {
|
||||
return outputNode.canConnectTo(this, fromSlot, output)
|
||||
}
|
||||
|
||||
connectByTypeOutput(
|
||||
slot: number,
|
||||
target_node: LGraphNode,
|
||||
target_slotType: ISlotType,
|
||||
optsIn?: { afterRerouteId?: RerouteId },
|
||||
): LLink | undefined {
|
||||
const outputSlot = target_node.findOutputByType(target_slotType)
|
||||
if (!outputSlot) return
|
||||
|
||||
return this.slots[slot].connect(outputSlot.slot, target_node, optsIn?.afterRerouteId)
|
||||
}
|
||||
|
||||
findInputByType(type: ISlotType): SubgraphOutput | undefined {
|
||||
return findFreeSlotOfType(this.slots, type, slot => slot.linkIds.length > 0)?.slot
|
||||
}
|
||||
|
||||
override drawProtected(ctx: CanvasRenderingContext2D, colorContext: DefaultConnectionColors): void {
|
||||
const { roundedRadius } = SubgraphIONodeBase
|
||||
const transform = ctx.getTransform()
|
||||
|
||||
const [x, y, , height] = this.boundingRect
|
||||
ctx.translate(x, y)
|
||||
|
||||
// Draw bottom rounded part
|
||||
ctx.strokeStyle = this.sideStrokeStyle
|
||||
ctx.lineWidth = this.sideLineWidth
|
||||
ctx.beginPath()
|
||||
ctx.arc(roundedRadius, roundedRadius, roundedRadius, Math.PI, Math.PI * 1.5)
|
||||
|
||||
// Straight line to bottom
|
||||
ctx.moveTo(0, roundedRadius)
|
||||
ctx.lineTo(0, height - roundedRadius)
|
||||
|
||||
// Bottom rounded part
|
||||
ctx.arc(roundedRadius, height - roundedRadius, roundedRadius, Math.PI, Math.PI * 0.5, true)
|
||||
ctx.stroke()
|
||||
|
||||
// Restore context
|
||||
ctx.setTransform(transform)
|
||||
|
||||
this.drawSlots(ctx, colorContext)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,27 +1,43 @@
|
||||
import type { SubgraphIONodeBase } from "./SubgraphIONodeBase"
|
||||
import type { Point, ReadOnlyRect, Rect } from "@/interfaces"
|
||||
import type { LinkId } from "@/LLink"
|
||||
import type { SubgraphInputNode } from "./SubgraphInputNode"
|
||||
import type { SubgraphOutputNode } from "./SubgraphOutputNode"
|
||||
import type { DefaultConnectionColors, Hoverable, INodeInputSlot, INodeOutputSlot, Point, ReadOnlyRect, ReadOnlySize } from "@/interfaces"
|
||||
import type { LGraphNode } from "@/LGraphNode"
|
||||
import type { LinkId, LLink } from "@/LLink"
|
||||
import type { RerouteId } from "@/Reroute"
|
||||
import type { CanvasPointerEvent } from "@/types/events"
|
||||
import type { Serialisable, SubgraphIO } from "@/types/serialisation"
|
||||
|
||||
import { SlotShape } from "@/draw"
|
||||
import { ConstrainedSize } from "@/infrastructure/ConstrainedSize"
|
||||
import { Rectangle } from "@/infrastructure/Rectangle"
|
||||
import { LGraphCanvas } from "@/LGraphCanvas"
|
||||
import { LiteGraph } from "@/litegraph"
|
||||
import { SlotBase } from "@/node/SlotBase"
|
||||
import { createUuidv4, type UUID } from "@/utils/uuid"
|
||||
|
||||
export interface SubgraphSlotDrawOptions {
|
||||
ctx: CanvasRenderingContext2D
|
||||
colorContext: DefaultConnectionColors
|
||||
lowQuality?: boolean
|
||||
}
|
||||
|
||||
/** Shared base class for the slots used on Subgraph . */
|
||||
export abstract class SubgraphSlot extends SlotBase implements SubgraphIO, Serialisable<SubgraphIO> {
|
||||
export abstract class SubgraphSlot extends SlotBase implements SubgraphIO, Hoverable, Serialisable<SubgraphIO> {
|
||||
static get defaultHeight() {
|
||||
return LiteGraph.NODE_SLOT_HEIGHT
|
||||
}
|
||||
|
||||
readonly #pos: Point = new Float32Array(2)
|
||||
|
||||
readonly measurement: ConstrainedSize = new ConstrainedSize(SubgraphSlot.defaultHeight, SubgraphSlot.defaultHeight)
|
||||
|
||||
readonly id: UUID
|
||||
readonly parent: SubgraphIONodeBase
|
||||
readonly parent: SubgraphInputNode | SubgraphOutputNode
|
||||
override type: string
|
||||
|
||||
readonly linkIds: LinkId[] = []
|
||||
|
||||
override readonly boundingRect: Rect = [0, 0, 0, SubgraphSlot.defaultHeight]
|
||||
override readonly boundingRect: Rectangle = new Rectangle(0, 0, 0, SubgraphSlot.defaultHeight)
|
||||
|
||||
override get pos() {
|
||||
return this.#pos
|
||||
@@ -46,8 +62,8 @@ export abstract class SubgraphSlot extends SlotBase implements SubgraphIO, Seria
|
||||
|
||||
abstract get labelPos(): Point
|
||||
|
||||
constructor(slot: SubgraphIO, parent: SubgraphIONodeBase) {
|
||||
super(slot.name, slot.type, slot.boundingRect)
|
||||
constructor(slot: SubgraphIO, parent: SubgraphInputNode | SubgraphOutputNode) {
|
||||
super(slot.name, slot.type)
|
||||
|
||||
Object.assign(this, slot)
|
||||
this.id = slot.id ?? createUuidv4()
|
||||
@@ -55,10 +71,111 @@ export abstract class SubgraphSlot extends SlotBase implements SubgraphIO, Seria
|
||||
this.parent = parent
|
||||
}
|
||||
|
||||
isPointerOver: boolean = false
|
||||
|
||||
containsPoint(point: Point): boolean {
|
||||
return this.boundingRect.containsPoint(point)
|
||||
}
|
||||
|
||||
onPointerMove(e: CanvasPointerEvent): void {
|
||||
this.isPointerOver = this.boundingRect.containsXy(e.canvasX, e.canvasY)
|
||||
}
|
||||
|
||||
getLinks(): LLink[] {
|
||||
const links: LLink[] = []
|
||||
const { subgraph } = this.parent
|
||||
|
||||
for (const id of this.linkIds) {
|
||||
const link = subgraph.getLink(id)
|
||||
if (link) links.push(link)
|
||||
}
|
||||
return links
|
||||
}
|
||||
|
||||
decrementSlots(inputsOrOutputs: "inputs" | "outputs"): void {
|
||||
const { links } = this.parent.subgraph
|
||||
const linkProperty = inputsOrOutputs === "inputs" ? "origin_slot" : "target_slot"
|
||||
|
||||
for (const linkId of this.linkIds) {
|
||||
const link = links.get(linkId)
|
||||
if (link) link[linkProperty]--
|
||||
else console.warn("decrementSlots: link ID not found", linkId)
|
||||
}
|
||||
}
|
||||
|
||||
measure(): ReadOnlySize {
|
||||
const width = LGraphCanvas._measureText?.(this.displayName) ?? 0
|
||||
|
||||
const { defaultHeight } = SubgraphSlot
|
||||
this.measurement.setValues(width + defaultHeight, defaultHeight)
|
||||
return this.measurement.toSize()
|
||||
}
|
||||
|
||||
abstract arrange(rect: ReadOnlyRect): void
|
||||
|
||||
abstract connect(
|
||||
slot: INodeInputSlot | INodeOutputSlot,
|
||||
node: LGraphNode,
|
||||
afterRerouteId?: RerouteId,
|
||||
): LLink | undefined
|
||||
|
||||
/**
|
||||
* Disconnects all links connected to this slot.
|
||||
*/
|
||||
disconnect(): void {
|
||||
const { subgraph } = this.parent
|
||||
|
||||
for (const linkId of this.linkIds) {
|
||||
subgraph.removeLink(linkId)
|
||||
}
|
||||
|
||||
this.linkIds.length = 0
|
||||
}
|
||||
|
||||
/** @remarks Leaves the context dirty. */
|
||||
drawLabel(ctx: CanvasRenderingContext2D): void {
|
||||
if (!this.displayName) return
|
||||
|
||||
const [x, y] = this.labelPos
|
||||
ctx.fillStyle = this.isPointerOver ? "white" : "#AAA"
|
||||
|
||||
ctx.fillText(this.displayName, x, y)
|
||||
}
|
||||
|
||||
/** @remarks Leaves the context dirty. */
|
||||
draw({ ctx, colorContext, lowQuality }: SubgraphSlotDrawOptions): void {
|
||||
// Assertion: SlotShape is a subset of RenderShape
|
||||
const shape = this.shape as unknown as SlotShape
|
||||
const { isPointerOver, pos: [x, y] } = this
|
||||
|
||||
ctx.beginPath()
|
||||
|
||||
// Default rendering for circle, hollow circle.
|
||||
const color = this.renderingColor(colorContext)
|
||||
if (lowQuality) {
|
||||
ctx.fillStyle = color
|
||||
|
||||
ctx.rect(x - 4, y - 4, 8, 8)
|
||||
ctx.fill()
|
||||
} else if (shape === SlotShape.HollowCircle) {
|
||||
ctx.lineWidth = 3
|
||||
ctx.strokeStyle = color
|
||||
|
||||
const radius = isPointerOver ? 4 : 3
|
||||
ctx.arc(x, y, radius, 0, Math.PI * 2)
|
||||
ctx.stroke()
|
||||
} else {
|
||||
// Normal circle
|
||||
ctx.fillStyle = color
|
||||
|
||||
const radius = isPointerOver ? 5 : 4
|
||||
ctx.arc(x, y, radius, 0, Math.PI * 2)
|
||||
ctx.fill()
|
||||
}
|
||||
}
|
||||
|
||||
asSerialisable(): SubgraphIO {
|
||||
const { id, name, type, linkIds, localized_name, label, dir, shape, color_off, color_on, pos, boundingRect } = this
|
||||
return { id, name, type, linkIds, localized_name, label, dir, shape, color_off, color_on, pos, boundingRect }
|
||||
const { id, name, type, linkIds, localized_name, label, dir, shape, color_off, color_on, pos } = this
|
||||
return { id, name, type, linkIds, localized_name, label, dir, shape, color_off, color_on, pos }
|
||||
}
|
||||
}
|
||||
|
||||
338
src/subgraph/subgraphUtils.ts
Normal file
338
src/subgraph/subgraphUtils.ts
Normal file
@@ -0,0 +1,338 @@
|
||||
import type { INodeOutputSlot, Positionable } from "@/interfaces"
|
||||
import type { LGraph } from "@/LGraph"
|
||||
import type { ISerialisedNode, SerialisableLLink, SubgraphIO } from "@/types/serialisation"
|
||||
|
||||
import { SUBGRAPH_INPUT_ID, SUBGRAPH_OUTPUT_ID } from "@/constants"
|
||||
import { LGraphGroup } from "@/LGraphGroup"
|
||||
import { LGraphNode } from "@/LGraphNode"
|
||||
import { createUuidv4, LiteGraph } from "@/litegraph"
|
||||
import { LLink, type ResolvedConnection } from "@/LLink"
|
||||
import { Reroute } from "@/Reroute"
|
||||
import { nextUniqueName } from "@/strings"
|
||||
|
||||
import { SubgraphInputNode } from "./SubgraphInputNode"
|
||||
import { SubgraphOutputNode } from "./SubgraphOutputNode"
|
||||
|
||||
export interface FilteredItems {
|
||||
nodes: Set<LGraphNode>
|
||||
reroutes: Set<Reroute>
|
||||
groups: Set<LGraphGroup>
|
||||
subgraphInputNodes: Set<SubgraphInputNode>
|
||||
subgraphOutputNodes: Set<SubgraphOutputNode>
|
||||
unknown: Set<Positionable>
|
||||
}
|
||||
|
||||
export function splitPositionables(items: Iterable<Positionable>): FilteredItems {
|
||||
const nodes = new Set<LGraphNode>()
|
||||
const reroutes = new Set<Reroute>()
|
||||
const groups = new Set<LGraphGroup>()
|
||||
const subgraphInputNodes = new Set<SubgraphInputNode>()
|
||||
const subgraphOutputNodes = new Set<SubgraphOutputNode>()
|
||||
|
||||
const unknown = new Set<Positionable>()
|
||||
|
||||
for (const item of items) {
|
||||
switch (true) {
|
||||
case item instanceof LGraphNode:
|
||||
nodes.add(item)
|
||||
break
|
||||
case item instanceof LGraphGroup:
|
||||
groups.add(item)
|
||||
break
|
||||
case item instanceof Reroute:
|
||||
reroutes.add(item)
|
||||
break
|
||||
case item instanceof SubgraphInputNode:
|
||||
subgraphInputNodes.add(item)
|
||||
break
|
||||
case item instanceof SubgraphOutputNode:
|
||||
subgraphOutputNodes.add(item)
|
||||
break
|
||||
default:
|
||||
unknown.add(item)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
nodes,
|
||||
reroutes,
|
||||
groups,
|
||||
subgraphInputNodes,
|
||||
subgraphOutputNodes,
|
||||
unknown,
|
||||
}
|
||||
}
|
||||
|
||||
interface BoundaryLinks {
|
||||
boundaryLinks: LLink[]
|
||||
boundaryFloatingLinks: LLink[]
|
||||
internalLinks: LLink[]
|
||||
boundaryInputLinks: LLink[]
|
||||
boundaryOutputLinks: LLink[]
|
||||
}
|
||||
|
||||
export function getBoundaryLinks(graph: LGraph, items: Set<Positionable>): BoundaryLinks {
|
||||
const internalLinks: LLink[] = []
|
||||
const boundaryLinks: LLink[] = []
|
||||
const boundaryInputLinks: LLink[] = []
|
||||
const boundaryOutputLinks: LLink[] = []
|
||||
const boundaryFloatingLinks: LLink[] = []
|
||||
const visited = new WeakSet<Positionable>()
|
||||
|
||||
for (const item of items) {
|
||||
if (visited.has(item)) continue
|
||||
visited.add(item)
|
||||
|
||||
// Nodes
|
||||
if (item instanceof LGraphNode) {
|
||||
const node = item
|
||||
|
||||
// Inputs
|
||||
if (node.inputs) {
|
||||
for (const input of node.inputs) {
|
||||
addFloatingLinks(input._floatingLinks)
|
||||
|
||||
if (input.link == null) continue
|
||||
|
||||
const resolved = LLink.resolve(input.link, graph)
|
||||
if (!resolved) {
|
||||
console.debug(`Failed to resolve link ID [${input.link}]`)
|
||||
continue
|
||||
}
|
||||
|
||||
// Output end of this link is outside the items set
|
||||
const { link, outputNode } = resolved
|
||||
if (outputNode) {
|
||||
if (!items.has(outputNode)) {
|
||||
boundaryInputLinks.push(link)
|
||||
} else {
|
||||
internalLinks.push(link)
|
||||
}
|
||||
} else if (link.origin_id === SUBGRAPH_INPUT_ID) {
|
||||
// Subgraph input node - always boundary
|
||||
boundaryInputLinks.push(link)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Outputs
|
||||
if (node.outputs) {
|
||||
for (const output of node.outputs) {
|
||||
addFloatingLinks(output._floatingLinks)
|
||||
|
||||
if (!output.links) continue
|
||||
|
||||
const many = LLink.resolveMany(output.links, graph)
|
||||
for (const { link, inputNode } of many) {
|
||||
if (
|
||||
// Subgraph output node
|
||||
link.target_id === SUBGRAPH_OUTPUT_ID ||
|
||||
// Input end of this link is outside the items set
|
||||
(inputNode && !items.has(inputNode))
|
||||
) {
|
||||
boundaryOutputLinks.push(link)
|
||||
}
|
||||
// Internal links are discovered on input side.
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (item instanceof Reroute) {
|
||||
// Reroutes
|
||||
const reroute = item
|
||||
|
||||
// TODO: This reroute should be on one side of the boundary. We should mark the reroute that is on each side of the boundary.
|
||||
// TODO: This could occur any number of times on a link; each time should be marked as a separate boundary.
|
||||
// TODO: e.g. A link with 3 reroutes, the first and last reroute are in `items`, but the middle reroute is not. This will be two "in" and two "out" boundaries.
|
||||
const results = LLink.resolveMany(reroute.linkIds, graph)
|
||||
for (const { link } of results) {
|
||||
const reroutes = LLink.getReroutes(graph, link)
|
||||
const reroutesOutside = reroutes.filter(reroute => !items.has(reroute))
|
||||
|
||||
// for (const reroute of reroutes) {
|
||||
// // TODO: Do the checks here.
|
||||
// }
|
||||
|
||||
const { inputNode, outputNode } = link.resolve(graph)
|
||||
|
||||
if (
|
||||
reroutesOutside.length ||
|
||||
(inputNode && !items.has(inputNode)) ||
|
||||
(outputNode && !items.has(outputNode))
|
||||
) {
|
||||
boundaryLinks.push(link)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { boundaryLinks, boundaryFloatingLinks, internalLinks, boundaryInputLinks, boundaryOutputLinks }
|
||||
|
||||
/**
|
||||
* Adds any floating links that cross the boundary.
|
||||
* @param floatingLinks The floating links to check
|
||||
*/
|
||||
function addFloatingLinks(floatingLinks: Set<LLink> | undefined): void {
|
||||
if (!floatingLinks) return
|
||||
|
||||
for (const link of floatingLinks) {
|
||||
const crossesBoundary = LLink
|
||||
.getReroutes(graph, link)
|
||||
.some(reroute => !items.has(reroute))
|
||||
|
||||
if (crossesBoundary) boundaryFloatingLinks.push(link)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function multiClone(nodes: Iterable<LGraphNode>): ISerialisedNode[] {
|
||||
const clonedNodes: ISerialisedNode[] = []
|
||||
|
||||
// Selectively clone - keep IDs & links
|
||||
for (const node of nodes) {
|
||||
const newNode = LiteGraph.createNode(node.type)
|
||||
if (!newNode) {
|
||||
console.warn("Failed to create node", node.type)
|
||||
continue
|
||||
}
|
||||
|
||||
// Must be cloned; litegraph "serialize" is mostly shallow clone
|
||||
const data = LiteGraph.cloneObject(node.serialize())
|
||||
newNode.configure(data)
|
||||
|
||||
clonedNodes.push(newNode.serialize())
|
||||
}
|
||||
|
||||
return clonedNodes
|
||||
}
|
||||
|
||||
/**
|
||||
* Groups resolved connections by output object. If the output is nullish, the connection will be in its own group.
|
||||
* @param resolvedConnections The resolved connections to group
|
||||
* @returns A map of grouped connections.
|
||||
*/
|
||||
export function groupResolvedByOutput(
|
||||
resolvedConnections: ResolvedConnection[],
|
||||
): Map<SubgraphIO | INodeOutputSlot | object, ResolvedConnection[]> {
|
||||
const groupedByOutput: ReturnType<typeof groupResolvedByOutput> = new Map()
|
||||
|
||||
for (const resolved of resolvedConnections) {
|
||||
// Force no group (unique object) if output is undefined; corruption or an error has occurred
|
||||
const groupBy = resolved.subgraphInput ?? resolved.output ?? {}
|
||||
const group = groupedByOutput.get(groupBy)
|
||||
if (group) {
|
||||
group.push(resolved)
|
||||
} else {
|
||||
groupedByOutput.set(groupBy, [resolved])
|
||||
}
|
||||
}
|
||||
|
||||
return groupedByOutput
|
||||
}
|
||||
|
||||
export function mapSubgraphInputsAndLinks(resolvedInputLinks: ResolvedConnection[], links: SerialisableLLink[]): SubgraphIO[] {
|
||||
// Group matching links
|
||||
const groupedByOutput = groupResolvedByOutput(resolvedInputLinks)
|
||||
|
||||
// Create one input for each output (outside subgraph)
|
||||
const inputs: SubgraphIO[] = []
|
||||
|
||||
for (const [, connections] of groupedByOutput) {
|
||||
const inputLinks: SerialisableLLink[] = []
|
||||
|
||||
// Create serialised links for all links (will be recreated in subgraph)
|
||||
for (const resolved of connections) {
|
||||
const { link, input } = resolved
|
||||
if (!input) continue
|
||||
|
||||
const linkData = link.asSerialisable()
|
||||
linkData.origin_id = SUBGRAPH_INPUT_ID
|
||||
linkData.origin_slot = inputs.length
|
||||
links.push(linkData)
|
||||
inputLinks.push(linkData)
|
||||
}
|
||||
|
||||
// Use first input link
|
||||
const { input } = connections[0]
|
||||
if (!input) continue
|
||||
|
||||
// Subgraph input slot
|
||||
const { color_off, color_on, dir, hasErrors, label, localized_name, name, shape, type } = input
|
||||
const uniqueName = nextUniqueName(name, inputs.map(input => input.name))
|
||||
const uniqueLocalizedName = localized_name ? nextUniqueName(localized_name, inputs.map(input => input.localized_name ?? "")) : undefined
|
||||
|
||||
const inputData: SubgraphIO = {
|
||||
id: createUuidv4(),
|
||||
type: String(type),
|
||||
linkIds: inputLinks.map(link => link.id),
|
||||
name: uniqueName,
|
||||
color_off,
|
||||
color_on,
|
||||
dir,
|
||||
label,
|
||||
localized_name: uniqueLocalizedName,
|
||||
hasErrors,
|
||||
shape,
|
||||
}
|
||||
|
||||
inputs.push(inputData)
|
||||
}
|
||||
|
||||
return inputs
|
||||
}
|
||||
|
||||
/**
|
||||
* Clones the output slots, and updates existing links, when converting items to a subgraph.
|
||||
* @param resolvedOutputLinks The resolved output links.
|
||||
* @param links The links to add to the subgraph.
|
||||
* @returns The subgraph output slots.
|
||||
*/
|
||||
export function mapSubgraphOutputsAndLinks(resolvedOutputLinks: ResolvedConnection[], links: SerialisableLLink[]): SubgraphIO[] {
|
||||
// Group matching links
|
||||
const groupedByOutput = groupResolvedByOutput(resolvedOutputLinks)
|
||||
|
||||
const outputs: SubgraphIO[] = []
|
||||
|
||||
for (const [, connections] of groupedByOutput) {
|
||||
const outputLinks: SerialisableLLink[] = []
|
||||
|
||||
// Create serialised links for all links (will be recreated in subgraph)
|
||||
for (const resolved of connections) {
|
||||
const { link, output } = resolved
|
||||
if (!output) continue
|
||||
|
||||
// Link
|
||||
const linkData = link.asSerialisable()
|
||||
linkData.target_id = SUBGRAPH_OUTPUT_ID
|
||||
linkData.target_slot = outputs.length
|
||||
links.push(linkData)
|
||||
outputLinks.push(linkData)
|
||||
}
|
||||
|
||||
// Use first output link
|
||||
const { output } = connections[0]
|
||||
if (!output) continue
|
||||
|
||||
// Subgraph output slot
|
||||
const { color_off, color_on, dir, hasErrors, label, localized_name, name, shape, type } = output
|
||||
const uniqueName = nextUniqueName(name, outputs.map(output => output.name))
|
||||
const uniqueLocalizedName = localized_name ? nextUniqueName(localized_name, outputs.map(output => output.localized_name ?? "")) : undefined
|
||||
|
||||
const outputData = {
|
||||
id: createUuidv4(),
|
||||
type: String(type),
|
||||
linkIds: outputLinks.map(link => link.id),
|
||||
name: uniqueName,
|
||||
color_off,
|
||||
color_on,
|
||||
dir,
|
||||
label,
|
||||
localized_name: uniqueLocalizedName,
|
||||
hasErrors,
|
||||
shape,
|
||||
} satisfies SubgraphIO
|
||||
|
||||
outputs.push(structuredClone(outputData))
|
||||
}
|
||||
return outputs
|
||||
}
|
||||
Reference in New Issue
Block a user