diff --git a/src/LGraphCanvas.ts b/src/LGraphCanvas.ts index d8e5ac807..c1290cf54 100644 --- a/src/LGraphCanvas.ts +++ b/src/LGraphCanvas.ts @@ -4672,9 +4672,10 @@ export class LGraphCanvas implements ConnectionColorContext { // render inputs and outputs if (!node.collapsed) { - const max_y = node.drawSlots(ctx, { - colorContext: this, + node.layoutSlots() + node.drawSlots(ctx, { connectingLink: this.connecting_links?.[0], + colorContext: this, editorAlpha: this.editor_alpha, lowQuality: this.low_quality, }) @@ -4682,6 +4683,11 @@ export class LGraphCanvas implements ConnectionColorContext { ctx.textAlign = "left" ctx.globalAlpha = 1 + const slotsBounds = createBounds( + node.slots.map(slot => slot._layoutElement), + /** padding= */ 0, + ) + const max_y = slotsBounds ? slotsBounds[1] + slotsBounds[3] : 0 this.drawNodeWidgets(node, max_y, ctx) } else if (this.render_collapsed_slots) { node.drawCollapsedSlots(ctx) diff --git a/src/LGraphNode.ts b/src/LGraphNode.ts index 2b6523ba8..0f4a7268f 100644 --- a/src/LGraphNode.ts +++ b/src/LGraphNode.ts @@ -8,6 +8,7 @@ import type { INodeFlags, INodeInputSlot, INodeOutputSlot, + INodeSlot, IOptionalSlotData, IPinnable, ISlotType, @@ -34,9 +35,10 @@ import { BadgePosition, LGraphBadge } from "./LGraphBadge" import { type LGraphNodeConstructor, LiteGraph } from "./litegraph" import { isInRectangle, isInRect, snapPoint } from "./measure" import { LLink } from "./LLink" -import { ConnectionColorContext, NodeInputSlot, NodeOutputSlot } from "./NodeSlot" +import { ConnectionColorContext, isINodeInputSlot, NodeInputSlot, NodeOutputSlot, serializeSlot, toNodeSlotClass } from "./NodeSlot" import { WIDGET_TYPE_MAP } from "./widgets/widgetMap" import { toClass } from "./utils/type" +import { LayoutElement } from "./utils/layout" export type NodeId = number | string export interface INodePropertyInfo { @@ -666,15 +668,8 @@ export class LGraphNode implements Positionable, IPinnable { if (this.constructor === LGraphNode && this.last_serialization) return this.last_serialization - if (this.inputs) o.inputs = this.inputs - - if (this.outputs) { - // clear outputs last data (because data in connections is never serialized but stored inside the outputs info) - for (let i = 0; i < this.outputs.length; i++) { - delete this.outputs[i]._data - } - o.outputs = this.outputs - } + if (this.inputs) o.inputs = this.inputs.map(serializeSlot) + if (this.outputs) o.outputs = this.outputs.map(serializeSlot) if (this.title && this.title != this.constructor.title) o.title = this.title @@ -3180,78 +3175,84 @@ export class LGraphNode implements Positionable, IPinnable { return LiteGraph.NODE_TEXT_HIGHLIGHT_COLOR ?? LiteGraph.NODE_SELECTED_TITLE_COLOR ?? LiteGraph.NODE_TEXT_COLOR } + get slots(): INodeSlot[] { + return [...this.inputs, ...this.outputs] + } + + layoutSlot(slot: INodeSlot, options: { + slotIndex: number + }): void { + const { slotIndex } = options + const isInput = isINodeInputSlot(slot) + const pos = this.getConnectionPos(isInput, slotIndex) + + slot._layoutElement = new LayoutElement({ + value: slot, + boundingRect: [ + pos[0] - this.pos[0] - LiteGraph.NODE_SLOT_HEIGHT * 0.5, + pos[1] - this.pos[1] - LiteGraph.NODE_SLOT_HEIGHT * 0.5, + LiteGraph.NODE_SLOT_HEIGHT, + LiteGraph.NODE_SLOT_HEIGHT, + ], + }) + } + + layoutSlots(): void { + for (const [i, slot] of this.inputs.entries()) { + this.layoutSlot(slot, { + slotIndex: i, + }) + } + for (const [i, slot] of this.outputs.entries()) { + this.layoutSlot(slot, { + slotIndex: i, + }) + } + } + + #getMouseOverSlot(slot: INodeSlot): INodeSlot | null { + const isInput = isINodeInputSlot(slot) + const mouseOverId = this.mouseOver?.[isInput ? "inputId" : "outputId"] ?? -1 + if (mouseOverId === -1) { + return null + } + return isInput ? this.inputs[mouseOverId] : this.outputs[mouseOverId] + } + + #isMouseOverSlot(slot: INodeSlot): boolean { + return this.#getMouseOverSlot(slot) === slot + } + /** * Draws the node's input and output slots. - * @returns The maximum y-coordinate of the slots. - * TODO: Calculate the bounding box of the slots and return it instead of the maximum y-coordinate. */ drawSlots(ctx: CanvasRenderingContext2D, options: { - colorContext: ConnectionColorContext connectingLink: ConnectingLink | null + colorContext: ConnectionColorContext editorAlpha: number lowQuality: boolean - }): number { - const { colorContext, connectingLink, editorAlpha, lowQuality } = options - let max_y = 0 - - // input connection slots - // Reuse slot_pos to avoid creating a new Float32Array on each iteration - const slot_pos = new Float32Array(2) - for (const [i, input] of (this.inputs ?? []).entries()) { - const slot = toClass(NodeInputSlot, input) + }) { + const { connectingLink, colorContext, editorAlpha, lowQuality } = options + for (const slot of this.slots) { // change opacity of incompatible slots when dragging a connection - const isValid = slot.isValidTarget(connectingLink) - const highlight = isValid && this.mouseOver?.inputId === i - const label_color = highlight + const layoutElement = slot._layoutElement + const slotInstance = toNodeSlotClass(slot) + const isValid = slotInstance.isValidTarget(connectingLink) + const highlight = isValid && this.#isMouseOverSlot(slot) + const labelColor = highlight ? this.highlightColor : LiteGraph.NODE_TEXT_COLOR ctx.globalAlpha = isValid ? editorAlpha : 0.4 * editorAlpha - const pos = this.getConnectionPos(true, i, /* out= */slot_pos) - pos[0] -= this.pos[0] - pos[1] -= this.pos[1] - - max_y = Math.max(max_y, pos[1] + LiteGraph.NODE_SLOT_HEIGHT * 0.5) - - slot.draw(ctx, { - pos, + slotInstance.draw(ctx, { + pos: layoutElement.center, colorContext, - labelColor: label_color, + labelColor, lowQuality, renderText: !lowQuality, highlight, }) } - - // output connection slots - for (const [i, output] of (this.outputs ?? []).entries()) { - const slot = toClass(NodeOutputSlot, output) - - // change opacity of incompatible slots when dragging a connection - const isValid = slot.isValidTarget(connectingLink) - const highlight = isValid && this.mouseOver?.outputId === i - const label_color = highlight - ? this.highlightColor - : LiteGraph.NODE_TEXT_COLOR - ctx.globalAlpha = isValid ? editorAlpha : 0.4 * editorAlpha - - const pos = this.getConnectionPos(false, i, /* out= */slot_pos) - pos[0] -= this.pos[0] - pos[1] -= this.pos[1] - - max_y = Math.max(max_y, pos[1] + LiteGraph.NODE_SLOT_HEIGHT * 0.5) - - slot.draw(ctx, { - pos, - colorContext, - labelColor: label_color, - lowQuality, - renderText: !lowQuality, - highlight, - }) - } - - return max_y } } diff --git a/src/NodeSlot.ts b/src/NodeSlot.ts index 941c56d7a..241c02ae3 100644 --- a/src/NodeSlot.ts +++ b/src/NodeSlot.ts @@ -27,6 +27,24 @@ interface IDrawOptions { highlight?: boolean } +export function serializeSlot(slot: T): T { + const serialized = { ...slot } + delete serialized._layoutElement + if ("_data" in serialized) { + delete serialized._data + } + return serialized +} + +export function toNodeSlotClass(slot: INodeSlot): NodeSlot { + if (isINodeInputSlot(slot)) { + return new NodeInputSlot(slot) + } else if (isINodeOutputSlot(slot)) { + return new NodeOutputSlot(slot) + } + throw new Error("Invalid slot type") +} + export abstract class NodeSlot implements INodeSlot { name: string localized_name?: string @@ -222,6 +240,10 @@ export abstract class NodeSlot implements INodeSlot { } } +export function isINodeInputSlot(slot: INodeSlot): slot is INodeInputSlot { + return "link" in slot +} + export class NodeInputSlot extends NodeSlot implements INodeInputSlot { link: LinkId | null @@ -254,6 +276,10 @@ export class NodeInputSlot extends NodeSlot implements INodeInputSlot { } } +export function isINodeOutputSlot(slot: INodeSlot): slot is INodeOutputSlot { + return "links" in slot +} + export class NodeOutputSlot extends NodeSlot implements INodeOutputSlot { links: LinkId[] | null _data?: unknown diff --git a/src/interfaces.ts b/src/interfaces.ts index 4adfdbeec..35cbf064c 100644 --- a/src/interfaces.ts +++ b/src/interfaces.ts @@ -4,6 +4,7 @@ import type { LinkDirection, RenderShape } from "./types/globalEnums" import type { LinkId, LLink } from "./LLink" import type { Reroute, RerouteId } from "./Reroute" import type { IWidget } from "./types/widgets" +import type { LayoutElement } from "./utils/layout" export type Dictionary = { [key: string]: T } @@ -225,6 +226,12 @@ export interface INodeSlot { * for more information. */ widget?: IWidget + + /** + * A layout element that is used internally to position the slot. + * Set by {@link LGraphNode.#layoutSlots}. + */ + _layoutElement?: LayoutElement } export interface INodeFlags { @@ -238,12 +245,14 @@ export interface INodeFlags { export interface INodeInputSlot extends INodeSlot { link: LinkId | null + _layoutElement?: LayoutElement } export interface INodeOutputSlot extends INodeSlot { links: LinkId[] | null _data?: unknown slot_index?: number + _layoutElement?: LayoutElement } /** Links */ diff --git a/src/measure.ts b/src/measure.ts index 6117e7952..e2016536a 100644 --- a/src/measure.ts +++ b/src/measure.ts @@ -1,7 +1,6 @@ // @ts-strict-ignore import type { Point, - Positionable, ReadOnlyPoint, ReadOnlyRect, Rect, @@ -331,7 +330,7 @@ export function findPointOnCurve( } export function createBounds( - objects: Iterable, + objects: Iterable<{ boundingRect: ReadOnlyRect }>, padding: number = 10, ): ReadOnlyRect | null { const bounds = new Float32Array([Infinity, Infinity, -Infinity, -Infinity]) diff --git a/src/utils/layout.ts b/src/utils/layout.ts new file mode 100644 index 000000000..8655d6bd6 --- /dev/null +++ b/src/utils/layout.ts @@ -0,0 +1,21 @@ +import { Point, ReadOnlyRect } from "@/interfaces" + +export class LayoutElement { + public readonly value: T + public readonly boundingRect: ReadOnlyRect + + constructor(o: { + value: T + boundingRect: ReadOnlyRect + }) { + this.value = o.value + this.boundingRect = o.boundingRect + } + + get center(): Point { + return [ + this.boundingRect[0] + this.boundingRect[2] / 2, + this.boundingRect[1] + this.boundingRect[3] / 2, + ] + } +}