diff --git a/src/LGraphCanvas.ts b/src/LGraphCanvas.ts index f998815a6..228a80996 100644 --- a/src/LGraphCanvas.ts +++ b/src/LGraphCanvas.ts @@ -2675,23 +2675,40 @@ export class LGraphCanvas { if (!firstLink || !linkConnector.isNodeValidDrop(node)) { // No link, or none of the dragged links may be dropped here } else if (linkConnector.state.connectingTo === "input") { - if (inputId === -1 && outputId === -1) { + if (overWidget) { + // Check widgets first - inputId is only valid if over the input socket + const slot = node.getSlotFromWidget(overWidget) + + if (slot && linkConnector.isInputValidDrop(node, slot)) { + highlightInput = slot + highlightPos = node.getInputSlotPos(slot) + linkConnector.overWidget = overWidget + } + } + + // Not over a valid widget - treat drop on invalid widget same as node background + if (!linkConnector.overWidget) { + if (inputId === -1 && outputId === -1) { // Node background / title under the pointer - if (!linkConnector.overWidget) { const result = node.findInputByType(firstLink.fromSlot.type) if (result) { highlightInput = result.slot - highlightPos = node.getInputPos(result.index) + highlightPos = node.getInputSlotPos(result.slot) } + } else if ( + inputId != -1 && + node.inputs[inputId] && + LiteGraph.isValidConnection(firstLink.fromSlot.type, node.inputs[inputId].type) + ) { + highlightPos = pos + // XXX CHECK THIS + highlightInput = node.inputs[inputId] + } + + if (highlightInput) { + const widget = node.getWidgetFromSlot(highlightInput) + if (widget) linkConnector.overWidget = widget } - } else if ( - inputId != -1 && - node.inputs[inputId] && - LiteGraph.isValidConnection(firstLink.fromSlot.type, node.inputs[inputId].type) - ) { - highlightPos = pos - // XXX CHECK THIS - highlightInput = node.inputs[inputId] } } else if (linkConnector.state.connectingTo === "output") { // Connecting from an input to an output diff --git a/src/LGraphNode.ts b/src/LGraphNode.ts index 1371c036e..517242e67 100644 --- a/src/LGraphNode.ts +++ b/src/LGraphNode.ts @@ -2900,20 +2900,29 @@ export class LGraphNode implements Positionable, IPinnable, IColorable { * @returns Position of the input slot */ getInputPos(slot: number): Point { - const { pos: [nodeX, nodeY], inputs } = this + return this.getInputSlotPos(this.inputs[slot]) + } + + /** + * Gets the position of an input slot, in graph co-ordinates. + * @param input The actual node input object + * @returns Position of the centre of the input slot in graph co-ordinates. + */ + getInputSlotPos(input: INodeInputSlot): Point { + const { pos: [nodeX, nodeY] } = this if (this.flags.collapsed) { const halfTitle = LiteGraph.NODE_TITLE_HEIGHT * 0.5 return [nodeX, nodeY - halfTitle] } - const inputPos = inputs?.[slot]?.pos - if (inputPos) return [nodeX + inputPos[0], nodeY + inputPos[1]] + const { pos } = input + if (pos) return [nodeX + pos[0], nodeY + pos[1]] // default vertical slots const offsetX = LiteGraph.NODE_SLOT_HEIGHT * 0.5 const nodeOffsetY = this.constructor.slot_start_y || 0 - const slotIndex = this.#defaultVerticalInputs.indexOf(this.inputs[slot]) + const slotIndex = this.#defaultVerticalInputs.indexOf(input) const slotY = (slotIndex + 0.7) * LiteGraph.NODE_SLOT_HEIGHT return [nodeX + offsetX, nodeY + slotY + nodeOffsetY] @@ -3481,8 +3490,8 @@ export class LGraphNode implements Positionable, IPinnable, IColorable { /** * Returns the input slot that is associated with the given widget. */ - getSlotFromWidget(widget: IWidget): INodeInputSlot | undefined { - return this.inputs.find(slot => isWidgetInputSlot(slot) && slot.widget.name === widget.name) + getSlotFromWidget(widget: IWidget | undefined): INodeInputSlot | undefined { + if (widget) return this.inputs.find(slot => isWidgetInputSlot(slot) && slot.widget.name === widget.name) } /** diff --git a/src/canvas/LinkConnector.ts b/src/canvas/LinkConnector.ts index 94c42a78e..12f39e74c 100644 --- a/src/canvas/LinkConnector.ts +++ b/src/canvas/LinkConnector.ts @@ -405,18 +405,11 @@ export class LinkConnector { // To input } else if (connectingTo === "input") { const input = node.getInputOnPos([canvasX, canvasY]) + const inputOrSocket = input ?? node.getSlotFromWidget(this.overWidget) // Input slot - if (input) { - this.#dropOnInput(node, input) - } else if (this.overWidget && renderLinks[0] instanceof ToInputRenderLink) { - // Widget - this.events.dispatch("dropped-on-widget", { - link: renderLinks[0], - node, - widget: this.overWidget, - }) - this.overWidget = undefined + if (inputOrSocket) { + this.#dropOnInput(node, inputOrSocket) } else { // Node background / title this.connectToNode(node, event) @@ -570,6 +563,10 @@ export class LinkConnector { } } + isInputValidDrop(node: LGraphNode, input: INodeInputSlot): boolean { + return this.renderLinks.some(link => link.canConnectToInput(node, input)) + } + isNodeValidDrop(node: LGraphNode): boolean { if (this.state.connectingTo === "output") { return node.outputs.some(output => this.renderLinks.some(link => link.canConnectToOutput(node, output))) diff --git a/test/LGraphNode.test.ts b/test/LGraphNode.test.ts index b75701dbf..9d8265aa0 100644 --- a/test/LGraphNode.test.ts +++ b/test/LGraphNode.test.ts @@ -1,4 +1,7 @@ -import { describe, expect } from "vitest" +import type { INodeInputSlot, Point } from "@/interfaces" +import type { ISerialisedNode } from "@/types/serialisation" + +import { afterEach, beforeEach, describe, expect, vi } from "vitest" import { LGraphNode, LiteGraph } from "@/litegraph" import { LGraph } from "@/litegraph" @@ -7,41 +10,92 @@ import { NodeOutputSlot } from "@/node/NodeOutputSlot" import { test } from "./testExtensions" +function getMockISerialisedNode(data: Partial): ISerialisedNode { + return Object.assign({ + id: 0, + flags: {}, + type: "TestNode", + pos: [100, 100], + size: [100, 100], + order: 0, + mode: 0, + }, data) +} + describe("LGraphNode", () => { + let node: LGraphNode + let origLiteGraph: typeof LiteGraph + + beforeEach(() => { + origLiteGraph = Object.assign({}, LiteGraph) + + Object.assign(LiteGraph, { + NODE_TITLE_HEIGHT: 20, + NODE_SLOT_HEIGHT: 15, + NODE_TEXT_SIZE: 14, + DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)", + DEFAULT_GROUP_FONT_SIZE: 24, + isValidConnection: vi.fn().mockReturnValue(true), + }) + node = new LGraphNode("Test Node") + node.pos = [100, 200] + node.size = [150, 100] // Example size + + // Reset mocks if needed + vi.clearAllMocks() + }) + + afterEach(() => { + Object.assign(LiteGraph, origLiteGraph) + }) + test("should serialize position/size correctly", () => { const node = new LGraphNode("TestNode") - node.pos = [10, 10] - expect(node.pos).toEqual(new Float32Array([10, 10])) - expect(node.serialize().pos).toEqual([10, 10]) + node.pos = [10, 20] + node.size = [30, 40] + const json = node.serialize() + expect(json.pos).toEqual([10, 20]) + expect(json.size).toEqual([30, 40]) - node.size = [100, 100] - expect(node.size).toEqual(new Float32Array([100, 100])) - expect(node.serialize().size).toEqual([100, 100]) + const configureData: ISerialisedNode = { + id: node.id, + type: node.type, + pos: [50, 60], + size: [70, 80], + flags: {}, + order: node.order, + mode: node.mode, + inputs: node.inputs?.map(i => ({ name: i.name, type: i.type, link: i.link })), + outputs: node.outputs?.map(o => ({ name: o.name, type: o.type, links: o.links, slot_index: o.slot_index })), + } + node.configure(configureData) + expect(node.pos).toEqual(new Float32Array([50, 60])) + expect(node.size).toEqual(new Float32Array([70, 80])) }) test("should configure inputs correctly", () => { const node = new LGraphNode("TestNode") - node.configure({ + node.configure(getMockISerialisedNode({ id: 0, inputs: [{ name: "TestInput", type: "number", link: null }], - }) + })) expect(node.inputs.length).toEqual(1) expect(node.inputs[0].name).toEqual("TestInput") expect(node.inputs[0].link).toEqual(null) expect(node.inputs[0]).instanceOf(NodeInputSlot) // Should not override existing inputs - node.configure({ id: 1 }) + node.configure(getMockISerialisedNode({ id: 1 })) expect(node.id).toEqual(1) expect(node.inputs.length).toEqual(1) }) test("should configure outputs correctly", () => { const node = new LGraphNode("TestNode") - node.configure({ + node.configure(getMockISerialisedNode({ id: 0, outputs: [{ name: "TestOutput", type: "number", links: [] }], - }) + })) expect(node.outputs.length).toEqual(1) expect(node.outputs[0].name).toEqual("TestOutput") expect(node.outputs[0].type).toEqual("number") @@ -49,7 +103,7 @@ describe("LGraphNode", () => { expect(node.outputs[0]).instanceOf(NodeOutputSlot) // Should not override existing outputs - node.configure({ id: 1 }) + node.configure(getMockISerialisedNode({ id: 1 })) expect(node.id).toEqual(1) expect(node.outputs.length).toEqual(1) }) @@ -60,14 +114,14 @@ describe("LGraphNode", () => { const node2 = new LGraphNode("TargetNode") // Configure nodes with input/output slots - node1.configure({ + node1.configure(getMockISerialisedNode({ id: 1, outputs: [{ name: "Output1", type: "number", links: [] }], - }) - node2.configure({ + })) + node2.configure(getMockISerialisedNode({ id: 2, inputs: [{ name: "Input1", type: "number", link: null }], - }) + })) // Create a graph and add nodes to it const graph = new LGraph() @@ -108,21 +162,21 @@ describe("LGraphNode", () => { const targetNode2 = new LGraphNode("TargetNode2") // Configure nodes with input/output slots - sourceNode.configure({ + sourceNode.configure(getMockISerialisedNode({ id: 1, outputs: [ { name: "Output1", type: "number", links: [] }, { name: "Output2", type: "number", links: [] }, ], - }) - targetNode1.configure({ + })) + targetNode1.configure(getMockISerialisedNode({ id: 2, inputs: [{ name: "Input1", type: "number", link: null }], - }) - targetNode2.configure({ + })) + targetNode2.configure(getMockISerialisedNode({ id: 3, inputs: [{ name: "Input1", type: "number", link: null }], - }) + })) // Create a graph and add nodes to it const graph = new LGraph() @@ -184,11 +238,11 @@ describe("LGraphNode", () => { node.boundingRect[1] = 100 node.boundingRect[2] = 100 node.boundingRect[3] = 100 - node.configure({ + node.configure(getMockISerialisedNode({ id: 1, inputs: [{ name: "Input1", type: "number", link: null }], outputs: [{ name: "Output1", type: "number", links: [] }], - }) + })) // Collapse the node node.flags.collapsed = true @@ -197,25 +251,25 @@ describe("LGraphNode", () => { const inputPos = node.getInputPos(0) const outputPos = node.getOutputPos(0) - expect(inputPos).toEqual([100, 85]) - expect(outputPos).toEqual([180, 85]) + expect(inputPos).toEqual([100, 90]) + expect(outputPos).toEqual([180, 90]) }) test("should return correct positions for input and output slots", () => { const node = new LGraphNode("TestNode") node.pos = [100, 100] node.size = [100, 100] - node.configure({ + node.configure(getMockISerialisedNode({ id: 1, inputs: [{ name: "Input1", type: "number", link: null }], outputs: [{ name: "Output1", type: "number", links: [] }], - }) + })) const inputPos = node.getInputPos(0) const outputPos = node.getOutputPos(0) - expect(inputPos).toEqual([110, 114]) - expect(outputPos).toEqual([191, 114]) + expect(inputPos).toEqual([107.5, 110.5]) + expect(outputPos).toEqual([193.5, 110.5]) }) }) @@ -224,11 +278,11 @@ describe("LGraphNode", () => { const node = new LGraphNode("TestNode") node.pos = [100, 100] node.size = [100, 100] - node.configure({ + node.configure(getMockISerialisedNode({ id: 1, inputs: [{ name: "Input1", type: "number", link: null }], outputs: [{ name: "Output1", type: "number", links: [] }], - }) + })) // Test point far outside node bounds expect(node.getSlotOnPos([0, 0])).toBeUndefined() @@ -244,13 +298,13 @@ describe("LGraphNode", () => { node.boundingRect[1] = 100 node.boundingRect[2] = 200 node.boundingRect[3] = 200 - node.configure({ + node.configure(getMockISerialisedNode({ id: 1, inputs: [ { name: "Input1", type: "number", link: null }, { name: "Input2", type: "string", link: null }, ], - }) + })) // Get position of first input slot const inputPos = node.getInputPos(0) @@ -271,13 +325,13 @@ describe("LGraphNode", () => { node.boundingRect[1] = 100 node.boundingRect[2] = 200 node.boundingRect[3] = 200 - node.configure({ + node.configure(getMockISerialisedNode({ id: 1, outputs: [ { name: "Output1", type: "number", links: [] }, { name: "Output2", type: "string", links: [] }, ], - }) + })) // Get position of first output slot const outputPos = node.getOutputPos(0) @@ -299,11 +353,11 @@ describe("LGraphNode", () => { node.boundingRect[1] = 100 node.boundingRect[2] = 200 node.boundingRect[3] = 200 - node.configure({ + node.configure(getMockISerialisedNode({ id: 1, inputs: [{ name: "Input1", type: "number", link: null }], outputs: [{ name: "Output1", type: "number", links: [] }], - }) + })) // Get positions of first input and output slots const inputPos = node.getInputPos(0) @@ -437,20 +491,86 @@ describe("LGraphNode", () => { expect(node.widgets?.length).toBe(2) node.widgets![0].serialize = false - node.configure({ + node.configure(getMockISerialisedNode({ id: 1, type: "TestNode", pos: [100, 100], size: [100, 100], - flags: {}, properties: {}, - order: 0, - mode: 0, widgets_values: [100], - }) + })) expect(node.widgets![0].value).toBe(1) expect(node.widgets![1].value).toBe(100) }) }) + + describe("getInputSlotPos", () => { + let inputSlot: INodeInputSlot + + beforeEach(() => { + inputSlot = { name: "test_in", type: "string", link: null, boundingRect: new Float32Array([0, 0, 0, 0]) } + }) + test("should return position based on title height when collapsed", () => { + node.flags.collapsed = true + const expectedPos: Point = [100, 200 - LiteGraph.NODE_TITLE_HEIGHT * 0.5] + expect(node.getInputSlotPos(inputSlot)).toEqual(expectedPos) + }) + + test("should return position based on input.pos when defined and not collapsed", () => { + node.flags.collapsed = false + inputSlot.pos = [10, 50] + node.inputs = [inputSlot] + const expectedPos: Point = [100 + 10, 200 + 50] + expect(node.getInputSlotPos(inputSlot)).toEqual(expectedPos) + }) + + test("should return default vertical position when input.pos is undefined and not collapsed", () => { + node.flags.collapsed = false + const inputSlot2 = { name: "test_in_2", type: "number", link: null, boundingRect: new Float32Array([0, 0, 0, 0]) } + node.inputs = [inputSlot, inputSlot2] + const slotIndex = 0 + const nodeOffsetY = (node.constructor as any).slot_start_y || 0 + const expectedY = 200 + (slotIndex + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + nodeOffsetY + const expectedX = 100 + LiteGraph.NODE_SLOT_HEIGHT * 0.5 + expect(node.getInputSlotPos(inputSlot)).toEqual([expectedX, expectedY]) + const slotIndex2 = 1 + const expectedY2 = 200 + (slotIndex2 + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + nodeOffsetY + expect(node.getInputSlotPos(inputSlot2)).toEqual([expectedX, expectedY2]) + }) + + test("should return default vertical position including slot_start_y when defined", () => { + (node.constructor as any).slot_start_y = 25 + node.flags.collapsed = false + node.inputs = [inputSlot] + const slotIndex = 0 + const nodeOffsetY = 25 + const expectedY = 200 + (slotIndex + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + nodeOffsetY + const expectedX = 100 + LiteGraph.NODE_SLOT_HEIGHT * 0.5 + expect(node.getInputSlotPos(inputSlot)).toEqual([expectedX, expectedY]) + delete (node.constructor as any).slot_start_y + }) + }) + + describe("getInputPos", () => { + test("should call getInputSlotPos with the correct input slot from inputs array", () => { + const input0: INodeInputSlot = { name: "in0", type: "string", link: null, boundingRect: new Float32Array([0, 0, 0, 0]) } + const input1: INodeInputSlot = { name: "in1", type: "number", link: null, boundingRect: new Float32Array([0, 0, 0, 0]), pos: [5, 45] } + node.inputs = [input0, input1] + const spy = vi.spyOn(node, "getInputSlotPos") + node.getInputPos(1) + expect(spy).toHaveBeenCalledWith(input1) + const expectedPos: Point = [100 + 5, 200 + 45] + expect(node.getInputPos(1)).toEqual(expectedPos) + spy.mockClear() + node.getInputPos(0) + expect(spy).toHaveBeenCalledWith(input0) + const slotIndex = 0 + const nodeOffsetY = (node.constructor as any).slot_start_y || 0 + const expectedDefaultY = 200 + (slotIndex + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + nodeOffsetY + const expectedDefaultX = 100 + LiteGraph.NODE_SLOT_HEIGHT * 0.5 + expect(node.getInputPos(0)).toEqual([expectedDefaultX, expectedDefaultY]) + spy.mockRestore() + }) + }) }) diff --git a/test/canvas/LinkConnector.test.ts b/test/canvas/LinkConnector.test.ts new file mode 100644 index 000000000..48cfe88c0 --- /dev/null +++ b/test/canvas/LinkConnector.test.ts @@ -0,0 +1,131 @@ +import type { INodeInputSlot, LGraphNode } from "@/litegraph" + +import { beforeEach, describe, expect, test, vi } from "vitest" + +// We don't strictly need RenderLink interface import for the mock +import { LinkConnector } from "@/canvas/LinkConnector" + +// Mocks +const mockSetConnectingLinks = vi.fn() + +// Mock a structure that has the needed method +function mockRenderLinkImpl(canConnect: boolean) { + return { + canConnectToInput: vi.fn().mockReturnValue(canConnect), + // Add other properties if they become necessary for tests + } +} + +const mockNode = {} as LGraphNode +const mockInput = {} as INodeInputSlot + +describe("LinkConnector", () => { + let connector: LinkConnector + + beforeEach(() => { + connector = new LinkConnector(mockSetConnectingLinks) + // Clear the array directly before each test + connector.renderLinks.length = 0 + vi.clearAllMocks() + }) + + describe("isInputValidDrop", () => { + test("should return false if there are no render links", () => { + expect(connector.isInputValidDrop(mockNode, mockInput)).toBe(false) + }) + + test("should return true if at least one render link can connect", () => { + const link1 = mockRenderLinkImpl(false) + const link2 = mockRenderLinkImpl(true) + // Cast to any to satisfy the push requirement, as we only need the canConnectToInput method + connector.renderLinks.push(link1 as any, link2 as any) + expect(connector.isInputValidDrop(mockNode, mockInput)).toBe(true) + expect(link1.canConnectToInput).toHaveBeenCalledWith(mockNode, mockInput) + expect(link2.canConnectToInput).toHaveBeenCalledWith(mockNode, mockInput) + }) + + test("should return false if no render links can connect", () => { + const link1 = mockRenderLinkImpl(false) + const link2 = mockRenderLinkImpl(false) + connector.renderLinks.push(link1 as any, link2 as any) + expect(connector.isInputValidDrop(mockNode, mockInput)).toBe(false) + expect(link1.canConnectToInput).toHaveBeenCalledWith(mockNode, mockInput) + expect(link2.canConnectToInput).toHaveBeenCalledWith(mockNode, mockInput) + }) + + test("should call canConnectToInput on each render link until one returns true", () => { + const link1 = mockRenderLinkImpl(false) + const link2 = mockRenderLinkImpl(true) // This one can connect + const link3 = mockRenderLinkImpl(false) + connector.renderLinks.push(link1 as any, link2 as any, link3 as any) + + expect(connector.isInputValidDrop(mockNode, mockInput)).toBe(true) + + expect(link1.canConnectToInput).toHaveBeenCalledTimes(1) + expect(link2.canConnectToInput).toHaveBeenCalledTimes(1) // Stops here + expect(link3.canConnectToInput).not.toHaveBeenCalled() // Should not be called + }) + }) + + describe("listenUntilReset", () => { + test("should add listener for the specified event and for reset", () => { + const listener = vi.fn() + const addEventListenerSpy = vi.spyOn(connector.events, "addEventListener") + + connector.listenUntilReset("before-drop-links", listener) + + expect(addEventListenerSpy).toHaveBeenCalledWith("before-drop-links", listener, undefined) + expect(addEventListenerSpy).toHaveBeenCalledWith("reset", expect.any(Function), { once: true }) + }) + + test("should call the listener when the event is dispatched before reset", () => { + const listener = vi.fn() + const eventData = { renderLinks: [], event: {} as any } // Mock event data + connector.listenUntilReset("before-drop-links", listener) + + connector.events.dispatch("before-drop-links", eventData) + + expect(listener).toHaveBeenCalledTimes(1) + expect(listener).toHaveBeenCalledWith(new CustomEvent("before-drop-links")) + }) + + test("should remove the listener when reset is dispatched", () => { + const listener = vi.fn() + const removeEventListenerSpy = vi.spyOn(connector.events, "removeEventListener") + + connector.listenUntilReset("before-drop-links", listener) + + // Simulate the reset event being dispatched + connector.events.dispatch("reset", false) + + // Check if removeEventListener was called correctly for the original listener + expect(removeEventListenerSpy).toHaveBeenCalledWith("before-drop-links", listener) + }) + + test("should not call the listener after reset is dispatched", () => { + const listener = vi.fn() + const eventData = { renderLinks: [], event: {} as any } + connector.listenUntilReset("before-drop-links", listener) + + // Dispatch reset first + connector.events.dispatch("reset", false) + + // Then dispatch the original event + connector.events.dispatch("before-drop-links", eventData) + + expect(listener).not.toHaveBeenCalled() + }) + + test("should pass options to addEventListener", () => { + const listener = vi.fn() + const options = { once: true } + const addEventListenerSpy = vi.spyOn(connector.events, "addEventListener") + + connector.listenUntilReset("after-drop-links", listener, options) + + expect(addEventListenerSpy).toHaveBeenCalledWith("after-drop-links", listener, options) + // Still adds the reset listener + expect(addEventListenerSpy).toHaveBeenCalledWith("reset", expect.any(Function), { once: true }) + }) + }) +})