diff --git a/test/LGraph.test.ts b/test/LGraph.test.ts index 25d935f94..39b1cee70 100644 --- a/test/LGraph.test.ts +++ b/test/LGraph.test.ts @@ -76,6 +76,40 @@ describe("LGraph", () => { expect(graph.reroutes.size).toBe(1) expect(graph.reroutes.values().next().value!.floating).not.toBeUndefined() }) + + test("Reroutes and branches should be retained when the input node is removed", ({ expect, floatingBranchGraph: graph }) => { + expect(graph.nodes.length).toBe(3) + graph.remove(graph.nodes[2]) + expect(graph.nodes.length).toBe(2) + expect(graph.links.size).toBe(1) + expect(graph.floatingLinks.size).toBe(1) + expect(graph.reroutes.size).toBe(4) + graph.remove(graph.nodes[1]) + expect(graph.nodes.length).toBe(1) + expect(graph.links.size).toBe(0) + expect(graph.floatingLinks.size).toBe(2) + expect(graph.reroutes.size).toBe(4) + }) + + test("Floating reroutes should be removed when neither input nor output is connected", ({ expect, floatingBranchGraph: graph }) => { + // Remove output node + graph.remove(graph.nodes[0]) + expect(graph.nodes.length).toBe(2) + expect(graph.links.size).toBe(0) + expect(graph.floatingLinks.size).toBe(2) + // The original floating reroute should be removed + expect(graph.reroutes.size).toBe(3) + graph.remove(graph.nodes[0]) + expect(graph.nodes.length).toBe(1) + expect(graph.links.size).toBe(0) + expect(graph.floatingLinks.size).toBe(1) + expect(graph.reroutes.size).toBe(3) + graph.remove(graph.nodes[0]) + expect(graph.nodes.length).toBe(0) + expect(graph.links.size).toBe(0) + expect(graph.floatingLinks.size).toBe(0) + expect(graph.reroutes.size).toBe(0) + }) }) }) diff --git a/test/LGraphNode.test.ts b/test/LGraphNode.test.ts index 4c4859372..1fabdeaf0 100644 --- a/test/LGraphNode.test.ts +++ b/test/LGraphNode.test.ts @@ -1,6 +1,7 @@ import { describe, expect } from "vitest" import { LGraphNode } from "@/litegraph" +import { LGraph } from "@/litegraph" import { NodeInputSlot, NodeOutputSlot } from "@/NodeSlot" import { test } from "./testExtensions" @@ -51,4 +52,266 @@ describe("LGraphNode", () => { expect(node.id).toEqual(1) expect(node.outputs.length).toEqual(1) }) + + describe("Disconnect I/O Slots", () => { + test("should disconnect input correctly", () => { + const node1 = new LGraphNode("SourceNode") + const node2 = new LGraphNode("TargetNode") + + // Configure nodes with input/output slots + node1.configure({ + id: 1, + outputs: [{ name: "Output1", type: "number", links: [] }], + }) + node2.configure({ + id: 2, + inputs: [{ name: "Input1", type: "number", link: null }], + }) + + // Create a graph and add nodes to it + const graph = new LGraph() + graph.add(node1) + graph.add(node2) + + // Connect the nodes + const link = node1.connect(0, node2, 0) + expect(link).not.toBeNull() + expect(node2.inputs[0].link).toBe(link?.id) + expect(node1.outputs[0].links).toContain(link?.id) + + // Test disconnecting by slot number + const disconnected = node2.disconnectInput(0) + expect(disconnected).toBe(true) + expect(node2.inputs[0].link).toBeNull() + expect(node1.outputs[0].links?.length).toBe(0) + expect(graph._links.has(link?.id ?? -1)).toBe(false) + + // Test disconnecting by slot name + node1.connect(0, node2, 0) + const disconnectedByName = node2.disconnectInput("Input1") + expect(disconnectedByName).toBe(true) + expect(node2.inputs[0].link).toBeNull() + + // Test disconnecting non-existent slot + const invalidDisconnect = node2.disconnectInput(999) + expect(invalidDisconnect).toBe(false) + + // Test disconnecting already disconnected input + const alreadyDisconnected = node2.disconnectInput(0) + expect(alreadyDisconnected).toBe(true) + }) + + test("should disconnect output correctly", () => { + const sourceNode = new LGraphNode("SourceNode") + const targetNode1 = new LGraphNode("TargetNode1") + const targetNode2 = new LGraphNode("TargetNode2") + + // Configure nodes with input/output slots + sourceNode.configure({ + id: 1, + outputs: [ + { name: "Output1", type: "number", links: [] }, + { name: "Output2", type: "number", links: [] }, + ], + }) + targetNode1.configure({ + id: 2, + inputs: [{ name: "Input1", type: "number", link: null }], + }) + targetNode2.configure({ + id: 3, + inputs: [{ name: "Input1", type: "number", link: null }], + }) + + // Create a graph and add nodes to it + const graph = new LGraph() + graph.add(sourceNode) + graph.add(targetNode1) + graph.add(targetNode2) + + // Connect multiple nodes to the same output + const link1 = sourceNode.connect(0, targetNode1, 0) + const link2 = sourceNode.connect(0, targetNode2, 0) + expect(link1).not.toBeNull() + expect(link2).not.toBeNull() + expect(sourceNode.outputs[0].links?.length).toBe(2) + + // Test disconnecting specific target node + const disconnectedSpecific = sourceNode.disconnectOutput(0, targetNode1) + expect(disconnectedSpecific).toBe(true) + expect(targetNode1.inputs[0].link).toBeNull() + expect(sourceNode.outputs[0].links?.length).toBe(1) + expect(graph._links.has(link1?.id ?? -1)).toBe(false) + expect(graph._links.has(link2?.id ?? -1)).toBe(true) + + // Test disconnecting by slot name + const link3 = sourceNode.connect(1, targetNode1, 0) + expect(link3).not.toBeNull() + const disconnectedByName = sourceNode.disconnectOutput("Output2", targetNode1) + expect(disconnectedByName).toBe(true) + expect(targetNode1.inputs[0].link).toBeNull() + expect(sourceNode.outputs[1].links?.length).toBe(0) + + // Test disconnecting all connections from an output + const link4 = sourceNode.connect(0, targetNode1, 0) + expect(link4).not.toBeNull() + expect(sourceNode.outputs[0].links?.length).toBe(2) + const disconnectedAll = sourceNode.disconnectOutput(0) + expect(disconnectedAll).toBe(true) + expect(sourceNode.outputs[0].links).toBeNull() + expect(targetNode1.inputs[0].link).toBeNull() + expect(targetNode2.inputs[0].link).toBeNull() + expect(graph._links.has(link2?.id ?? -1)).toBe(false) + expect(graph._links.has(link4?.id ?? -1)).toBe(false) + + // Test disconnecting non-existent slot + const invalidDisconnect = sourceNode.disconnectOutput(999) + expect(invalidDisconnect).toBe(false) + + // Test disconnecting already disconnected output + const alreadyDisconnected = sourceNode.disconnectOutput(0) + expect(alreadyDisconnected).toBe(false) + }) + }) + + describe("getInputPos and getOutputPos", () => { + test("should handle collapsed nodes correctly", () => { + const node = new LGraphNode("TestNode") as unknown as Omit & { boundingRect: Float32Array } + node.pos = [100, 100] + node.size = [100, 100] + node.boundingRect[0] = 100 + node.boundingRect[1] = 100 + node.boundingRect[2] = 100 + node.boundingRect[3] = 100 + node.configure({ + id: 1, + inputs: [{ name: "Input1", type: "number", link: null }], + outputs: [{ name: "Output1", type: "number", links: [] }], + }) + + // Collapse the node + node.flags.collapsed = true + + // Get positions in collapsed state + const inputPos = node.getInputPos(0) + const outputPos = node.getOutputPos(0) + + expect(inputPos).toEqual([100, 85]) + expect(outputPos).toEqual([180, 85]) + }) + + test("should return correct positions for input and output slots", () => { + const node = new LGraphNode("TestNode") + node.pos = [100, 100] + node.size = [100, 100] + node.configure({ + id: 1, + inputs: [{ name: "Input1", type: "number", link: null }], + outputs: [{ name: "Output1", type: "number", links: [] }], + }) + + const inputPos = node.getInputPos(0) + const outputPos = node.getOutputPos(0) + + expect(inputPos).toEqual([110, 114]) + expect(outputPos).toEqual([191, 114]) + }) + }) + + describe("getSlotOnPos", () => { + test("should return undefined when point is outside node bounds", () => { + const node = new LGraphNode("TestNode") + node.pos = [100, 100] + node.size = [100, 100] + node.configure({ + id: 1, + inputs: [{ name: "Input1", type: "number", link: null }], + outputs: [{ name: "Output1", type: "number", links: [] }], + }) + + // Test point far outside node bounds + expect(node.getSlotOnPos([0, 0])).toBeUndefined() + // Test point just outside node bounds + expect(node.getSlotOnPos([99, 99])).toBeUndefined() + }) + + test("should detect input slots correctly", () => { + const node = new LGraphNode("TestNode") as unknown as Omit & { boundingRect: Float32Array } + node.pos = [100, 100] + node.size = [100, 100] + node.boundingRect[0] = 100 + node.boundingRect[1] = 100 + node.boundingRect[2] = 200 + node.boundingRect[3] = 200 + node.configure({ + id: 1, + inputs: [ + { name: "Input1", type: "number", link: null }, + { name: "Input2", type: "string", link: null }, + ], + }) + + // Get position of first input slot + const inputPos = node.getInputPos(0) + // Test point directly on input slot + const slot = node.getSlotOnPos(inputPos) + expect(slot).toBeDefined() + expect(slot?.name).toBe("Input1") + + // Test point near but not on input slot + expect(node.getSlotOnPos([inputPos[0] - 15, inputPos[1]])).toBeUndefined() + }) + + test("should detect output slots correctly", () => { + const node = new LGraphNode("TestNode") as unknown as Omit & { boundingRect: Float32Array } + node.pos = [100, 100] + node.size = [100, 100] + node.boundingRect[0] = 100 + node.boundingRect[1] = 100 + node.boundingRect[2] = 200 + node.boundingRect[3] = 200 + node.configure({ + id: 1, + outputs: [ + { name: "Output1", type: "number", links: [] }, + { name: "Output2", type: "string", links: [] }, + ], + }) + + // Get position of first output slot + const outputPos = node.getOutputPos(0) + // Test point directly on output slot + const slot = node.getSlotOnPos(outputPos) + expect(slot).toBeDefined() + expect(slot?.name).toBe("Output1") + + // Test point near but not on output slot + const gotslot = node.getSlotOnPos([outputPos[0] + 30, outputPos[1]]) + expect(gotslot).toBeUndefined() + }) + + test("should prioritize input slots over output slots", () => { + const node = new LGraphNode("TestNode") as unknown as Omit & { boundingRect: Float32Array } + node.pos = [100, 100] + node.size = [100, 100] + node.boundingRect[0] = 100 + node.boundingRect[1] = 100 + node.boundingRect[2] = 200 + node.boundingRect[3] = 200 + node.configure({ + id: 1, + inputs: [{ name: "Input1", type: "number", link: null }], + outputs: [{ name: "Output1", type: "number", links: [] }], + }) + + // Get positions of first input and output slots + const inputPos = node.getInputPos(0) + + // Test point that could theoretically hit both slots + // Should return the input slot due to priority + const slot = node.getSlotOnPos(inputPos) + expect(slot).toBeDefined() + expect(slot?.name).toBe("Input1") + }) + }) }) diff --git a/test/assets/floatingBranch.json b/test/assets/floatingBranch.json new file mode 100644 index 000000000..0764d73bf --- /dev/null +++ b/test/assets/floatingBranch.json @@ -0,0 +1,123 @@ +{ + "id": "e5ffd5e1-1c01-45ac-90dd-b7d83a206b0f", + "revision": 0, + "last_node_id": 3, + "last_link_id": 3, + "nodes": [ + { + "id": 1, + "type": "InvertMask", + "pos": [100, 130], + "size": [140, 26], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [ + { + "localized_name": "mask", + "name": "mask", + "type": "MASK", + "link": null + } + ], + "outputs": [ + { + "localized_name": "MASK", + "name": "MASK", + "type": "MASK", + "links": [2, 3] + } + ], + "properties": { "Node name for S&R": "InvertMask" }, + "widgets_values": [] + }, + { + "id": 3, + "type": "InvertMask", + "pos": [400, 220], + "size": [140, 26], + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [ + { "localized_name": "mask", "name": "mask", "type": "MASK", "link": 3 } + ], + "outputs": [ + { + "localized_name": "MASK", + "name": "MASK", + "type": "MASK", + "links": null + } + ], + "properties": { "Node name for S&R": "InvertMask" }, + "widgets_values": [] + }, + { + "id": 2, + "type": "InvertMask", + "pos": [400, 130], + "size": [140, 26], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [ + { "localized_name": "mask", "name": "mask", "type": "MASK", "link": 2 } + ], + "outputs": [ + { + "localized_name": "MASK", + "name": "MASK", + "type": "MASK", + "links": null + } + ], + "properties": { "Node name for S&R": "InvertMask" }, + "widgets_values": [] + } + ], + "links": [ + [2, 1, 0, 2, 0, "MASK"], + [3, 1, 0, 3, 0, "MASK"] + ], + "floatingLinks": [ + { + "id": 6, + "origin_id": 1, + "origin_slot": 0, + "target_id": -1, + "target_slot": -1, + "type": "MASK", + "parentId": 1 + } + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 1.2100000000000002, + "offset": [319.8264462809916, 109.2148760330578] + }, + "linkExtensions": [ + { "id": 2, "parentId": 3 }, + { "id": 3, "parentId": 3 } + ], + "reroutes": [ + { + "id": 1, + "parentId": 2, + "pos": [350, 110], + "linkIds": [], + "floating": { "slotType": "output" } + }, + { "id": 2, "parentId": 4, "pos": [310, 150], "linkIds": [2, 3] }, + { "id": 3, "parentId": 2, "pos": [360, 170], "linkIds": [2, 3] }, + { + "id": 4, + "pos": [271.9090881347656, 146.9834747314453], + "linkIds": [2, 3] + } + ] + }, + "version": 0.4 +} diff --git a/test/testExtensions.ts b/test/testExtensions.ts index 3953b08b6..8f627b3a4 100644 --- a/test/testExtensions.ts +++ b/test/testExtensions.ts @@ -5,6 +5,7 @@ import { test as baseTest } from "vitest" import { LGraph } from "@/LGraph" import { LiteGraph } from "@/litegraph" +import floatingBranch from "./assets/floatingBranch.json" import floatingLink from "./assets/floatingLink.json" import linkedNodes from "./assets/linkedNodes.json" import { basicSerialisableGraph, minimalSerialisableGraph, oldSchemaGraph } from "./assets/testGraphs" @@ -15,6 +16,8 @@ interface LitegraphFixtures { oldSchemaGraph: ISerialisedGraph floatingLinkGraph: ISerialisedGraph linkedNodesGraph: ISerialisedGraph + floatingBranchSerialisedGraph: ISerialisedGraph + floatingBranchGraph: LGraph } /** These fixtures alter global state, and are difficult to reset. Relies on a single test per-file to reset state. */ @@ -35,6 +38,12 @@ export const test = baseTest.extend({ oldSchemaGraph: structuredClone(oldSchemaGraph), floatingLinkGraph: structuredClone(floatingLink as unknown as ISerialisedGraph), linkedNodesGraph: structuredClone(linkedNodes as unknown as ISerialisedGraph), + floatingBranchSerialisedGraph: structuredClone(floatingBranch as unknown as ISerialisedGraph), + floatingBranchGraph: async ({ floatingBranchSerialisedGraph }, use) => { + const cloned = structuredClone(floatingBranchSerialisedGraph) + const graph = new LGraph(cloned) + await use(graph) + }, }) /** Test that use {@link DirtyFixtures}. One test per file. */