From 6fa2e8e3ca0dbf865156e8804f707b31f55e8e46 Mon Sep 17 00:00:00 2001 From: Benjamin Lu Date: Fri, 1 Aug 2025 18:38:57 -0400 Subject: [PATCH] Add slot compatibility checking for subgraph slots (#1182) --- src/LGraphCanvas.ts | 2 +- src/node/NodeInputSlot.ts | 15 +- src/node/NodeOutputSlot.ts | 15 +- src/node/NodeSlot.ts | 4 +- src/subgraph/Subgraph.ts | 8 +- src/subgraph/SubgraphIONodeBase.ts | 13 +- src/subgraph/SubgraphInput.ts | 22 +- src/subgraph/SubgraphInputNode.ts | 7 +- src/subgraph/SubgraphOutput.ts | 25 +- src/subgraph/SubgraphOutputNode.ts | 8 +- src/subgraph/SubgraphSlotBase.ts | 49 ++- src/subgraph/subgraphUtils.ts | 34 ++- test/subgraph/SubgraphSlotConnections.test.ts | 280 ++++++++++++++++++ .../SubgraphSlotVisualFeedback.test.ts | 181 +++++++++++ 14 files changed, 624 insertions(+), 39 deletions(-) create mode 100644 test/subgraph/SubgraphSlotConnections.test.ts create mode 100644 test/subgraph/SubgraphSlotVisualFeedback.test.ts diff --git a/src/LGraphCanvas.ts b/src/LGraphCanvas.ts index a644c3273..6fdd8ca24 100644 --- a/src/LGraphCanvas.ts +++ b/src/LGraphCanvas.ts @@ -4212,7 +4212,7 @@ export class LGraphCanvas implements CustomEventDispatcher } // Draw subgraph IO nodes - this.subgraph?.draw(ctx, this.colourGetter) + this.subgraph?.draw(ctx, this.colourGetter, this.linkConnector.renderLinks[0]?.fromSlot, this.editor_alpha) // on top (debug) if (this.render_execution_order) { diff --git a/src/node/NodeInputSlot.ts b/src/node/NodeInputSlot.ts index bc83d6651..d741dd4cc 100644 --- a/src/node/NodeInputSlot.ts +++ b/src/node/NodeInputSlot.ts @@ -1,11 +1,14 @@ import type { INodeInputSlot, INodeOutputSlot, OptionalProps, ReadOnlyPoint } from "@/interfaces" import type { LGraphNode } from "@/LGraphNode" import type { LinkId } from "@/LLink" +import type { SubgraphInput } from "@/subgraph/SubgraphInput" +import type { SubgraphOutput } from "@/subgraph/SubgraphOutput" import type { IBaseWidget } from "@/types/widgets" import { LabelPosition } from "@/draw" import { LiteGraph } from "@/litegraph" import { type IDrawOptions, NodeSlot } from "@/node/NodeSlot" +import { isSubgraphInput } from "@/subgraph/subgraphUtils" export class NodeInputSlot extends NodeSlot implements INodeInputSlot { link: LinkId | null @@ -38,8 +41,16 @@ export class NodeInputSlot extends NodeSlot implements INodeInputSlot { return this.link != null } - override isValidTarget(fromSlot: INodeInputSlot | INodeOutputSlot): boolean { - return "links" in fromSlot && LiteGraph.isValidConnection(this.type, fromSlot.type) + override isValidTarget(fromSlot: INodeInputSlot | INodeOutputSlot | SubgraphInput | SubgraphOutput): boolean { + if ("links" in fromSlot) { + return LiteGraph.isValidConnection(fromSlot.type, this.type) + } + + if (isSubgraphInput(fromSlot)) { + return LiteGraph.isValidConnection(fromSlot.type, this.type) + } + + return false } override draw(ctx: CanvasRenderingContext2D, options: Omit) { diff --git a/src/node/NodeOutputSlot.ts b/src/node/NodeOutputSlot.ts index 1afd427f6..6cf20e573 100644 --- a/src/node/NodeOutputSlot.ts +++ b/src/node/NodeOutputSlot.ts @@ -1,10 +1,13 @@ import type { INodeInputSlot, INodeOutputSlot, OptionalProps, ReadOnlyPoint } from "@/interfaces" import type { LGraphNode } from "@/LGraphNode" import type { LinkId } from "@/LLink" +import type { SubgraphInput } from "@/subgraph/SubgraphInput" +import type { SubgraphOutput } from "@/subgraph/SubgraphOutput" import { LabelPosition } from "@/draw" import { LiteGraph } from "@/litegraph" import { type IDrawOptions, NodeSlot } from "@/node/NodeSlot" +import { isSubgraphOutput } from "@/subgraph/subgraphUtils" export class NodeOutputSlot extends NodeSlot implements INodeOutputSlot { #node: LGraphNode @@ -32,8 +35,16 @@ export class NodeOutputSlot extends NodeSlot implements INodeOutputSlot { this.#node = node } - override isValidTarget(fromSlot: INodeInputSlot | INodeOutputSlot): boolean { - return "link" in fromSlot && LiteGraph.isValidConnection(this.type, fromSlot.type) + override isValidTarget(fromSlot: INodeInputSlot | INodeOutputSlot | SubgraphInput | SubgraphOutput): boolean { + if ("link" in fromSlot) { + return LiteGraph.isValidConnection(this.type, fromSlot.type) + } + + if (isSubgraphOutput(fromSlot)) { + return LiteGraph.isValidConnection(this.type, fromSlot.type) + } + + return false } override get isConnected(): boolean { diff --git a/src/node/NodeSlot.ts b/src/node/NodeSlot.ts index 849bd5676..46d9d2f76 100644 --- a/src/node/NodeSlot.ts +++ b/src/node/NodeSlot.ts @@ -1,5 +1,7 @@ import type { CanvasColour, DefaultConnectionColors, INodeInputSlot, INodeOutputSlot, INodeSlot, ISubgraphInput, OptionalProps, Point, ReadOnlyPoint } from "@/interfaces" import type { LGraphNode } from "@/LGraphNode" +import type { SubgraphInput } from "@/subgraph/SubgraphInput" +import type { SubgraphOutput } from "@/subgraph/SubgraphOutput" import { LabelPosition, SlotShape, SlotType } from "@/draw" import { LiteGraph, Rectangle } from "@/litegraph" @@ -68,7 +70,7 @@ export abstract class NodeSlot extends SlotBase implements INodeSlot { * Whether this slot is a valid target for a dragging link. * @param fromSlot The slot that the link is being connected from. */ - abstract isValidTarget(fromSlot: INodeInputSlot | INodeOutputSlot): boolean + abstract isValidTarget(fromSlot: INodeInputSlot | INodeOutputSlot | SubgraphInput | SubgraphOutput): boolean /** * The label to display in the UI. diff --git a/src/subgraph/Subgraph.ts b/src/subgraph/Subgraph.ts index 3801bd31a..68facc3a9 100644 --- a/src/subgraph/Subgraph.ts +++ b/src/subgraph/Subgraph.ts @@ -1,5 +1,5 @@ import type { SubgraphEventMap } from "@/infrastructure/SubgraphEventMap" -import type { DefaultConnectionColors } from "@/interfaces" +import type { DefaultConnectionColors, INodeInputSlot, INodeOutputSlot } from "@/interfaces" import type { LGraphCanvas } from "@/LGraphCanvas" import type { ExportedSubgraph, ExposedWidget, ISerialisedGraph, Serialisable, SerialisableGraph } from "@/types/serialisation" @@ -206,9 +206,9 @@ export class Subgraph extends LGraph implements BaseLGraph, Serialisable impleme ) } - override drawProtected(ctx: CanvasRenderingContext2D, colorContext: DefaultConnectionColors): void { + override drawProtected(ctx: CanvasRenderingContext2D, colorContext: DefaultConnectionColors, fromSlot?: INodeInputSlot | INodeOutputSlot | SubgraphInput | SubgraphOutput, editorAlpha?: number): void { const { roundedRadius } = SubgraphIONodeBase const transform = ctx.getTransform() @@ -194,6 +195,6 @@ export class SubgraphInputNode extends SubgraphIONodeBase impleme // Restore context ctx.setTransform(transform) - this.drawSlots(ctx, colorContext) + this.drawSlots(ctx, colorContext, fromSlot, editorAlpha) } } diff --git a/src/subgraph/SubgraphOutput.ts b/src/subgraph/SubgraphOutput.ts index 1d4224512..1a67d75ab 100644 --- a/src/subgraph/SubgraphOutput.ts +++ b/src/subgraph/SubgraphOutput.ts @@ -1,13 +1,16 @@ +import type { SubgraphInput } from "./SubgraphInput" import type { SubgraphOutputNode } from "./SubgraphOutputNode" -import type { INodeOutputSlot, Point, ReadOnlyRect } from "@/interfaces" +import type { INodeInputSlot, INodeOutputSlot, Point, ReadOnlyRect } from "@/interfaces" import type { LGraphNode } from "@/LGraphNode" import type { RerouteId } from "@/Reroute" +import { LiteGraph } from "@/litegraph" import { LLink } from "@/LLink" import { NodeSlotType } from "@/types/globalEnums" import { removeFromArray } from "@/utils/collections" import { SubgraphSlot } from "./SubgraphSlotBase" +import { isNodeSlot, isSubgraphInput } from "./subgraphUtils" /** * An output "slot" from a subgraph to a parent graph. @@ -26,6 +29,9 @@ export class SubgraphOutput extends SubgraphSlot { override connect(slot: INodeOutputSlot, node: LGraphNode, afterRerouteId?: RerouteId): LLink | undefined { const { subgraph } = this.parent + // Validate type compatibility + if (!LiteGraph.isValidConnection(slot.type, this.type)) return + // Allow nodes to block connection const outputIndex = node.outputs.indexOf(slot) if (outputIndex === -1) throw new Error("Slot is not an output of the given node") @@ -111,4 +117,21 @@ export class SubgraphOutput extends SubgraphSlot { pos[0] = left + height * 0.5 pos[1] = top + height * 0.5 } + + /** + * Checks if this slot is a valid target for a connection from the given slot. + * For SubgraphOutput (which acts as an input inside the subgraph), + * the fromSlot should be an output slot. + */ + override isValidTarget(fromSlot: INodeInputSlot | INodeOutputSlot | SubgraphInput | SubgraphOutput): boolean { + if (isNodeSlot(fromSlot)) { + return "links" in fromSlot && LiteGraph.isValidConnection(fromSlot.type, this.type) + } + + if (isSubgraphInput(fromSlot)) { + return LiteGraph.isValidConnection(fromSlot.type, this.type) + } + + return false + } } diff --git a/src/subgraph/SubgraphOutputNode.ts b/src/subgraph/SubgraphOutputNode.ts index 1328e8e40..60786dbe2 100644 --- a/src/subgraph/SubgraphOutputNode.ts +++ b/src/subgraph/SubgraphOutputNode.ts @@ -1,8 +1,8 @@ +import type { SubgraphInput } from "./SubgraphInput" import type { SubgraphOutput } from "./SubgraphOutput" import type { LinkConnector } from "@/canvas/LinkConnector" import type { CanvasPointer } from "@/CanvasPointer" -import type { DefaultConnectionColors, ISlotType, Positionable } from "@/interfaces" -import type { INodeOutputSlot } from "@/interfaces" +import type { DefaultConnectionColors, INodeInputSlot, INodeOutputSlot, ISlotType, Positionable } from "@/interfaces" import type { LGraphNode, NodeId } from "@/LGraphNode" import type { LLink } from "@/LLink" import type { RerouteId } from "@/Reroute" @@ -90,7 +90,7 @@ export class SubgraphOutputNode extends SubgraphIONodeBase imple return findFreeSlotOfType(this.slots, type, slot => slot.linkIds.length > 0)?.slot } - override drawProtected(ctx: CanvasRenderingContext2D, colorContext: DefaultConnectionColors): void { + override drawProtected(ctx: CanvasRenderingContext2D, colorContext: DefaultConnectionColors, fromSlot?: INodeInputSlot | INodeOutputSlot | SubgraphInput | SubgraphOutput, editorAlpha?: number): void { const { roundedRadius } = SubgraphIONodeBase const transform = ctx.getTransform() @@ -114,6 +114,6 @@ export class SubgraphOutputNode extends SubgraphIONodeBase imple // Restore context ctx.setTransform(transform) - this.drawSlots(ctx, colorContext) + this.drawSlots(ctx, colorContext, fromSlot, editorAlpha) } } diff --git a/src/subgraph/SubgraphSlotBase.ts b/src/subgraph/SubgraphSlotBase.ts index 2b1b3a130..ed49e611b 100644 --- a/src/subgraph/SubgraphSlotBase.ts +++ b/src/subgraph/SubgraphSlotBase.ts @@ -1,4 +1,6 @@ +import type { SubgraphInput } from "./SubgraphInput" import type { SubgraphInputNode } from "./SubgraphInputNode" +import type { SubgraphOutput } from "./SubgraphOutput" import type { SubgraphOutputNode } from "./SubgraphOutputNode" import type { DefaultConnectionColors, Hoverable, INodeInputSlot, INodeOutputSlot, Point, ReadOnlyRect, ReadOnlySize } from "@/interfaces" import type { LGraphNode } from "@/LGraphNode" @@ -19,6 +21,8 @@ export interface SubgraphSlotDrawOptions { ctx: CanvasRenderingContext2D colorContext: DefaultConnectionColors lowQuality?: boolean + fromSlot?: INodeInputSlot | INodeOutputSlot | SubgraphInput | SubgraphOutput + editorAlpha?: number } /** Shared base class for the slots used on Subgraph . */ @@ -132,22 +136,32 @@ export abstract class SubgraphSlot extends SlotBase implements SubgraphIO, Hover this.linkIds.length = 0 } - /** @remarks Leaves the context dirty. */ - drawLabel(ctx: CanvasRenderingContext2D): void { - if (!this.displayName) return - - const [x, y] = this.labelPos - ctx.fillStyle = this.isPointerOver ? "white" : (LiteGraph.NODE_TEXT_COLOR || "#AAA") - - ctx.fillText(this.displayName, x, y) - } + /** + * Checks if this slot is a valid target for a connection from the given slot. + * @param fromSlot The slot that is being dragged to connect to this slot. + * @returns true if the connection is valid, false otherwise. + */ + abstract isValidTarget(fromSlot: INodeInputSlot | INodeOutputSlot | SubgraphInput | SubgraphOutput): boolean /** @remarks Leaves the context dirty. */ - draw({ ctx, colorContext, lowQuality }: SubgraphSlotDrawOptions): void { + draw({ ctx, colorContext, lowQuality, fromSlot, editorAlpha = 1 }: SubgraphSlotDrawOptions): void { // Assertion: SlotShape is a subset of RenderShape const shape = this.shape as unknown as SlotShape const { isPointerOver, pos: [x, y] } = this + // Check if this slot is a valid target for the current dragging connection + const isValidTarget = fromSlot ? this.isValidTarget(fromSlot) : true + const isValid = !fromSlot || isValidTarget + + // Only highlight if the slot is valid AND mouse is over it + const highlight = isValid && isPointerOver + + // Save current alpha + const previousAlpha = ctx.globalAlpha + + // Set opacity based on validity when dragging a connection + ctx.globalAlpha = isValid ? editorAlpha : 0.4 * editorAlpha + ctx.beginPath() // Default rendering for circle, hollow circle. @@ -161,17 +175,28 @@ export abstract class SubgraphSlot extends SlotBase implements SubgraphIO, Hover ctx.lineWidth = 3 ctx.strokeStyle = color - const radius = isPointerOver ? 4 : 3 + const radius = highlight ? 4 : 3 ctx.arc(x, y, radius, 0, Math.PI * 2) ctx.stroke() } else { // Normal circle ctx.fillStyle = color - const radius = isPointerOver ? 5 : 4 + const radius = highlight ? 5 : 4 ctx.arc(x, y, radius, 0, Math.PI * 2) ctx.fill() } + + // Draw label with current opacity + if (this.displayName) { + const [labelX, labelY] = this.labelPos + // Also apply highlight logic to text color + ctx.fillStyle = highlight ? "white" : (LiteGraph.NODE_TEXT_COLOR || "#AAA") + ctx.fillText(this.displayName, labelX, labelY) + } + + // Restore alpha + ctx.globalAlpha = previousAlpha } asSerialisable(): SubgraphIO { diff --git a/src/subgraph/subgraphUtils.ts b/src/subgraph/subgraphUtils.ts index eaa6bb7e8..078c9167e 100644 --- a/src/subgraph/subgraphUtils.ts +++ b/src/subgraph/subgraphUtils.ts @@ -1,4 +1,6 @@ -import type { INodeOutputSlot, Positionable } from "@/interfaces" +import type { SubgraphInput } from "./SubgraphInput" +import type { SubgraphOutput } from "./SubgraphOutput" +import type { INodeInputSlot, INodeOutputSlot, Positionable } from "@/interfaces" import type { LGraph } from "@/LGraph" import type { ISerialisedNode, SerialisableLLink, SubgraphIO } from "@/types/serialisation" @@ -336,3 +338,33 @@ export function mapSubgraphOutputsAndLinks(resolvedOutputLinks: ResolvedConnecti } return outputs } + +/** + * Type guard to check if a slot is a SubgraphInput. + * @param slot The slot to check + * @returns true if the slot is a SubgraphInput + */ +export function isSubgraphInput(slot: unknown): slot is SubgraphInput { + return slot != null && typeof slot === "object" && "parent" in slot && + slot.parent instanceof SubgraphInputNode +} + +/** + * Type guard to check if a slot is a SubgraphOutput. + * @param slot The slot to check + * @returns true if the slot is a SubgraphOutput + */ +export function isSubgraphOutput(slot: unknown): slot is SubgraphOutput { + return slot != null && typeof slot === "object" && "parent" in slot && + slot.parent instanceof SubgraphOutputNode +} + +/** + * Type guard to check if a slot is a regular node slot (INodeInputSlot or INodeOutputSlot). + * @param slot The slot to check + * @returns true if the slot is a regular node slot + */ +export function isNodeSlot(slot: unknown): slot is INodeInputSlot | INodeOutputSlot { + return slot != null && typeof slot === "object" && + ("link" in slot || "links" in slot) +} diff --git a/test/subgraph/SubgraphSlotConnections.test.ts b/test/subgraph/SubgraphSlotConnections.test.ts new file mode 100644 index 000000000..1caa4bb60 --- /dev/null +++ b/test/subgraph/SubgraphSlotConnections.test.ts @@ -0,0 +1,280 @@ +import { describe, expect, it } from "vitest" + +import { LGraphNode } from "@/litegraph" +import { NodeInputSlot } from "@/node/NodeInputSlot" +import { NodeOutputSlot } from "@/node/NodeOutputSlot" +import { isSubgraphInput, isSubgraphOutput } from "@/subgraph/subgraphUtils" + +import { createTestSubgraph, createTestSubgraphNode } from "./fixtures/subgraphHelpers" + +describe("Subgraph slot connections", () => { + describe("SubgraphInput connections", () => { + it("should connect to compatible regular input slots", () => { + const subgraph = createTestSubgraph({ + inputs: [{ name: "test_input", type: "number" }], + }) + + const subgraphInput = subgraph.inputs[0] + + const node = new LGraphNode("TestNode") + node.addInput("compatible_input", "number") + node.addInput("incompatible_input", "string") + subgraph.add(node) + + const compatibleSlot = node.inputs[0] as NodeInputSlot + const incompatibleSlot = node.inputs[1] as NodeInputSlot + + expect(compatibleSlot.isValidTarget(subgraphInput)).toBe(true) + expect(incompatibleSlot.isValidTarget(subgraphInput)).toBe(false) + }) + + // "not implemented" yet, but the test passes in terms of type checking + // it("should connect to compatible SubgraphOutput", () => { + // const subgraph = createTestSubgraph({ + // inputs: [{ name: "test_input", type: "number" }], + // outputs: [{ name: "test_output", type: "number" }], + // }) + + // const subgraphInput = subgraph.inputs[0] + // const subgraphOutput = subgraph.outputs[0] + + // expect(subgraphOutput.isValidTarget(subgraphInput)).toBe(true) + // }) + + it("should not connect to another SubgraphInput", () => { + const subgraph = createTestSubgraph({ + inputs: [ + { name: "input1", type: "number" }, + { name: "input2", type: "number" }, + ], + }) + + const subgraphInput1 = subgraph.inputs[0] + const subgraphInput2 = subgraph.inputs[1] + + expect(subgraphInput2.isValidTarget(subgraphInput1)).toBe(false) + }) + + it("should not connect to output slots", () => { + const subgraph = createTestSubgraph({ + inputs: [{ name: "test_input", type: "number" }], + }) + + const subgraphInput = subgraph.inputs[0] + + const node = new LGraphNode("TestNode") + node.addOutput("test_output", "number") + subgraph.add(node) + const outputSlot = node.outputs[0] as NodeOutputSlot + + expect(outputSlot.isValidTarget(subgraphInput)).toBe(false) + }) + }) + + describe("SubgraphOutput connections", () => { + it("should connect from compatible regular output slots", () => { + const subgraph = createTestSubgraph() + const node = new LGraphNode("TestNode") + node.addOutput("out", "number") + subgraph.add(node) + + const subgraphOutput = subgraph.addOutput("result", "number") + const nodeOutput = node.outputs[0] + + expect(subgraphOutput.isValidTarget(nodeOutput)).toBe(true) + }) + + it("should connect from SubgraphInput", () => { + const subgraph = createTestSubgraph() + + const subgraphInput = subgraph.addInput("value", "number") + const subgraphOutput = subgraph.addOutput("result", "number") + + expect(subgraphOutput.isValidTarget(subgraphInput)).toBe(true) + }) + + it("should not connect to another SubgraphOutput", () => { + const subgraph = createTestSubgraph() + + const subgraphOutput1 = subgraph.addOutput("result1", "number") + const subgraphOutput2 = subgraph.addOutput("result2", "number") + + expect(subgraphOutput1.isValidTarget(subgraphOutput2)).toBe(false) + }) + }) + + describe("Type compatibility", () => { + it("should respect type compatibility for SubgraphInput connections", () => { + const subgraph = createTestSubgraph({ + inputs: [{ name: "number_input", type: "number" }], + }) + + const subgraphInput = subgraph.inputs[0] + + const node = new LGraphNode("TestNode") + node.addInput("number_slot", "number") + node.addInput("string_slot", "string") + node.addInput("any_slot", "*") + node.addInput("boolean_slot", "boolean") + subgraph.add(node) + + const numberSlot = node.inputs[0] as NodeInputSlot + const stringSlot = node.inputs[1] as NodeInputSlot + const anySlot = node.inputs[2] as NodeInputSlot + const booleanSlot = node.inputs[3] as NodeInputSlot + + expect(numberSlot.isValidTarget(subgraphInput)).toBe(true) + expect(stringSlot.isValidTarget(subgraphInput)).toBe(false) + expect(anySlot.isValidTarget(subgraphInput)).toBe(true) + expect(booleanSlot.isValidTarget(subgraphInput)).toBe(false) + }) + + it("should respect type compatibility for SubgraphOutput connections", () => { + const subgraph = createTestSubgraph() + const node = new LGraphNode("TestNode") + node.addOutput("out", "string") + subgraph.add(node) + + const subgraphOutput = subgraph.addOutput("result", "number") + const nodeOutput = node.outputs[0] + + expect(subgraphOutput.isValidTarget(nodeOutput)).toBe(false) + }) + + it("should handle wildcard SubgraphInput", () => { + const subgraph = createTestSubgraph({ + inputs: [{ name: "any_input", type: "*" }], + }) + + const subgraphInput = subgraph.inputs[0] + + const node = new LGraphNode("TestNode") + node.addInput("number_slot", "number") + subgraph.add(node) + + const numberSlot = node.inputs[0] as NodeInputSlot + + expect(numberSlot.isValidTarget(subgraphInput)).toBe(true) + }) + }) + + describe("Type guards", () => { + it("should correctly identify SubgraphInput", () => { + const subgraph = createTestSubgraph() + const subgraphInput = subgraph.addInput("value", "number") + const node = new LGraphNode("TestNode") + node.addInput("in", "number") + + expect(isSubgraphInput(subgraphInput)).toBe(true) + expect(isSubgraphInput(node.inputs[0])).toBe(false) + expect(isSubgraphInput(null)).toBe(false) + // eslint-disable-next-line unicorn/no-useless-undefined + expect(isSubgraphInput(undefined)).toBe(false) + expect(isSubgraphInput({})).toBe(false) + }) + + it("should correctly identify SubgraphOutput", () => { + const subgraph = createTestSubgraph() + const subgraphOutput = subgraph.addOutput("result", "number") + const node = new LGraphNode("TestNode") + node.addOutput("out", "number") + + expect(isSubgraphOutput(subgraphOutput)).toBe(true) + expect(isSubgraphOutput(node.outputs[0])).toBe(false) + expect(isSubgraphOutput(null)).toBe(false) + // eslint-disable-next-line unicorn/no-useless-undefined + expect(isSubgraphOutput(undefined)).toBe(false) + expect(isSubgraphOutput({})).toBe(false) + }) + }) + + describe("Nested subgraphs", () => { + it("should handle dragging from SubgraphInput in nested subgraphs", () => { + const parentSubgraph = createTestSubgraph({ + inputs: [{ name: "parent_input", type: "number" }], + outputs: [{ name: "parent_output", type: "number" }], + }) + + const nestedSubgraph = createTestSubgraph({ + inputs: [{ name: "nested_input", type: "number" }], + outputs: [{ name: "nested_output", type: "number" }], + }) + + const nestedSubgraphNode = createTestSubgraphNode(nestedSubgraph) + parentSubgraph.add(nestedSubgraphNode) + + const regularNode = new LGraphNode("TestNode") + regularNode.addInput("test_input", "number") + nestedSubgraph.add(regularNode) + + const nestedSubgraphInput = nestedSubgraph.inputs[0] + const regularNodeSlot = regularNode.inputs[0] as NodeInputSlot + + expect(regularNodeSlot.isValidTarget(nestedSubgraphInput)).toBe(true) + }) + + it("should handle multiple levels of nesting", () => { + const level1 = createTestSubgraph({ + inputs: [{ name: "level1_input", type: "string" }], + }) + + const level2 = createTestSubgraph({ + inputs: [{ name: "level2_input", type: "string" }], + }) + + const level3 = createTestSubgraph({ + inputs: [{ name: "level3_input", type: "string" }], + outputs: [{ name: "level3_output", type: "string" }], + }) + + const level2Node = createTestSubgraphNode(level2) + level1.add(level2Node) + + const level3Node = createTestSubgraphNode(level3) + level2.add(level3Node) + + const deepNode = new LGraphNode("DeepNode") + deepNode.addInput("deep_input", "string") + level3.add(deepNode) + + const level3Input = level3.inputs[0] + const deepNodeSlot = deepNode.inputs[0] as NodeInputSlot + + expect(deepNodeSlot.isValidTarget(level3Input)).toBe(true) + + const level3Output = level3.outputs[0] + expect(level3Output.isValidTarget(level3Input)).toBe(true) + }) + + it("should maintain type checking across nesting levels", () => { + const outer = createTestSubgraph({ + inputs: [{ name: "outer_number", type: "number" }], + }) + + const inner = createTestSubgraph({ + inputs: [ + { name: "inner_number", type: "number" }, + { name: "inner_string", type: "string" }, + ], + }) + + const innerNode = createTestSubgraphNode(inner) + outer.add(innerNode) + + const node = new LGraphNode("TestNode") + node.addInput("number_slot", "number") + node.addInput("string_slot", "string") + inner.add(node) + + const innerNumberInput = inner.inputs[0] + const innerStringInput = inner.inputs[1] + const numberSlot = node.inputs[0] as NodeInputSlot + const stringSlot = node.inputs[1] as NodeInputSlot + + expect(numberSlot.isValidTarget(innerNumberInput)).toBe(true) + expect(numberSlot.isValidTarget(innerStringInput)).toBe(false) + expect(stringSlot.isValidTarget(innerNumberInput)).toBe(false) + expect(stringSlot.isValidTarget(innerStringInput)).toBe(true) + }) + }) +}) diff --git a/test/subgraph/SubgraphSlotVisualFeedback.test.ts b/test/subgraph/SubgraphSlotVisualFeedback.test.ts new file mode 100644 index 000000000..2871edacd --- /dev/null +++ b/test/subgraph/SubgraphSlotVisualFeedback.test.ts @@ -0,0 +1,181 @@ +import { beforeEach, describe, expect, it, vi } from "vitest" + +import { LGraphNode } from "@/litegraph" + +import { createTestSubgraph } from "./fixtures/subgraphHelpers" + +describe("SubgraphSlot visual feedback", () => { + let mockCtx: CanvasRenderingContext2D + let mockColorContext: any + let globalAlphaValues: number[] + + beforeEach(() => { + // Clear the array before each test + globalAlphaValues = [] + + // Create a mock canvas context that tracks all globalAlpha values + const mockContext = { + _globalAlpha: 1, + get globalAlpha() { + return this._globalAlpha + }, + set globalAlpha(value: number) { + this._globalAlpha = value + globalAlphaValues.push(value) + }, + fillStyle: "", + strokeStyle: "", + lineWidth: 1, + beginPath: vi.fn(), + arc: vi.fn(), + fill: vi.fn(), + stroke: vi.fn(), + rect: vi.fn(), + fillText: vi.fn(), + } + mockCtx = mockContext as unknown as CanvasRenderingContext2D + + // Create a mock color context + mockColorContext = { + defaultInputColor: "#FF0000", + defaultOutputColor: "#00FF00", + getConnectedColor: vi.fn().mockReturnValue("#0000FF"), + getDisconnectedColor: vi.fn().mockReturnValue("#AAAAAA"), + } + }) + + it("should render SubgraphInput slots with full opacity when dragging from compatible slot", () => { + const subgraph = createTestSubgraph() + const node = new LGraphNode("TestNode") + node.addInput("in", "number") + subgraph.add(node) + + // Add a subgraph input + const subgraphInput = subgraph.addInput("value", "number") + + // Simulate dragging from the subgraph input (which acts as output inside subgraph) + const nodeInput = node.inputs[0] + + // Draw the slot with a compatible fromSlot + subgraphInput.draw({ + ctx: mockCtx, + colorContext: mockColorContext, + fromSlot: nodeInput, + editorAlpha: 1, + }) + + // Should render with full opacity (not 0.4) + // Check that 0.4 was NOT set during drawing + expect(globalAlphaValues).not.toContain(0.4) + }) + + it("should render SubgraphInput slots with 40% opacity when dragging from another SubgraphInput", () => { + const subgraph = createTestSubgraph() + + // Add two subgraph inputs + const subgraphInput1 = subgraph.addInput("value1", "number") + const subgraphInput2 = subgraph.addInput("value2", "number") + + // Draw subgraphInput2 while dragging from subgraphInput1 (incompatible - both are outputs inside subgraph) + subgraphInput2.draw({ + ctx: mockCtx, + colorContext: mockColorContext, + fromSlot: subgraphInput1, + editorAlpha: 1, + }) + + // Should render with 40% opacity + // Check that 0.4 was set during drawing + expect(globalAlphaValues).toContain(0.4) + }) + + it("should render SubgraphOutput slots with full opacity when dragging from compatible slot", () => { + const subgraph = createTestSubgraph() + const node = new LGraphNode("TestNode") + node.addOutput("out", "number") + subgraph.add(node) + + // Add a subgraph output + const subgraphOutput = subgraph.addOutput("result", "number") + + // Simulate dragging from a node output + const nodeOutput = node.outputs[0] + + // Draw the slot with a compatible fromSlot + subgraphOutput.draw({ + ctx: mockCtx, + colorContext: mockColorContext, + fromSlot: nodeOutput, + editorAlpha: 1, + }) + + // Should render with full opacity (not 0.4) + // Check that 0.4 was NOT set during drawing + expect(globalAlphaValues).not.toContain(0.4) + }) + + it("should render SubgraphOutput slots with 40% opacity when dragging from another SubgraphOutput", () => { + const subgraph = createTestSubgraph() + + // Add two subgraph outputs + const subgraphOutput1 = subgraph.addOutput("result1", "number") + const subgraphOutput2 = subgraph.addOutput("result2", "number") + + // Draw subgraphOutput2 while dragging from subgraphOutput1 (incompatible - both are inputs inside subgraph) + subgraphOutput2.draw({ + ctx: mockCtx, + colorContext: mockColorContext, + fromSlot: subgraphOutput1, + editorAlpha: 1, + }) + + // Should render with 40% opacity + // Check that 0.4 was set during drawing + expect(globalAlphaValues).toContain(0.4) + }) + + // "not implmeneted yet" + // it("should render slots with full opacity when dragging between compatible SubgraphInput and SubgraphOutput", () => { + // const subgraph = createTestSubgraph() + + // // Add subgraph input and output with matching types + // const subgraphInput = subgraph.addInput("value", "number") + // const subgraphOutput = subgraph.addOutput("result", "number") + + // // Draw SubgraphOutput slot while dragging from SubgraphInput + // subgraphOutput.draw({ + // ctx: mockCtx, + // colorContext: mockColorContext, + // fromSlot: subgraphInput, + // editorAlpha: 1, + // }) + + // // Should render with full opacity + // expect(mockCtx.globalAlpha).toBe(1) + // }) + + it("should render slots with 40% opacity when dragging between incompatible types", () => { + const subgraph = createTestSubgraph() + const node = new LGraphNode("TestNode") + node.addOutput("string_output", "string") + subgraph.add(node) + + // Add subgraph output with incompatible type + const subgraphOutput = subgraph.addOutput("result", "number") + + // Get the string output slot from the node + const nodeStringOutput = node.outputs[0] + + // Draw the SubgraphOutput slot while dragging from a node output with incompatible type + subgraphOutput.draw({ + ctx: mockCtx, + colorContext: mockColorContext, + fromSlot: nodeStringOutput, + editorAlpha: 1, + }) + + // Should render with 40% opacity due to type mismatch + // Check that 0.4 was set during drawing + expect(globalAlphaValues).toContain(0.4) + }) +})