mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-02-23 00:04:06 +00:00
Fix widget snap to work with input sockets (#1017)
This commit is contained in:
@@ -2675,23 +2675,40 @@ export class LGraphCanvas {
|
||||
if (!firstLink || !linkConnector.isNodeValidDrop(node)) {
|
||||
// No link, or none of the dragged links may be dropped here
|
||||
} else if (linkConnector.state.connectingTo === "input") {
|
||||
if (inputId === -1 && outputId === -1) {
|
||||
if (overWidget) {
|
||||
// Check widgets first - inputId is only valid if over the input socket
|
||||
const slot = node.getSlotFromWidget(overWidget)
|
||||
|
||||
if (slot && linkConnector.isInputValidDrop(node, slot)) {
|
||||
highlightInput = slot
|
||||
highlightPos = node.getInputSlotPos(slot)
|
||||
linkConnector.overWidget = overWidget
|
||||
}
|
||||
}
|
||||
|
||||
// Not over a valid widget - treat drop on invalid widget same as node background
|
||||
if (!linkConnector.overWidget) {
|
||||
if (inputId === -1 && outputId === -1) {
|
||||
// Node background / title under the pointer
|
||||
if (!linkConnector.overWidget) {
|
||||
const result = node.findInputByType(firstLink.fromSlot.type)
|
||||
if (result) {
|
||||
highlightInput = result.slot
|
||||
highlightPos = node.getInputPos(result.index)
|
||||
highlightPos = node.getInputSlotPos(result.slot)
|
||||
}
|
||||
} else if (
|
||||
inputId != -1 &&
|
||||
node.inputs[inputId] &&
|
||||
LiteGraph.isValidConnection(firstLink.fromSlot.type, node.inputs[inputId].type)
|
||||
) {
|
||||
highlightPos = pos
|
||||
// XXX CHECK THIS
|
||||
highlightInput = node.inputs[inputId]
|
||||
}
|
||||
|
||||
if (highlightInput) {
|
||||
const widget = node.getWidgetFromSlot(highlightInput)
|
||||
if (widget) linkConnector.overWidget = widget
|
||||
}
|
||||
} else if (
|
||||
inputId != -1 &&
|
||||
node.inputs[inputId] &&
|
||||
LiteGraph.isValidConnection(firstLink.fromSlot.type, node.inputs[inputId].type)
|
||||
) {
|
||||
highlightPos = pos
|
||||
// XXX CHECK THIS
|
||||
highlightInput = node.inputs[inputId]
|
||||
}
|
||||
} else if (linkConnector.state.connectingTo === "output") {
|
||||
// Connecting from an input to an output
|
||||
|
||||
@@ -2900,20 +2900,29 @@ export class LGraphNode implements Positionable, IPinnable, IColorable {
|
||||
* @returns Position of the input slot
|
||||
*/
|
||||
getInputPos(slot: number): Point {
|
||||
const { pos: [nodeX, nodeY], inputs } = this
|
||||
return this.getInputSlotPos(this.inputs[slot])
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the position of an input slot, in graph co-ordinates.
|
||||
* @param input The actual node input object
|
||||
* @returns Position of the centre of the input slot in graph co-ordinates.
|
||||
*/
|
||||
getInputSlotPos(input: INodeInputSlot): Point {
|
||||
const { pos: [nodeX, nodeY] } = this
|
||||
|
||||
if (this.flags.collapsed) {
|
||||
const halfTitle = LiteGraph.NODE_TITLE_HEIGHT * 0.5
|
||||
return [nodeX, nodeY - halfTitle]
|
||||
}
|
||||
|
||||
const inputPos = inputs?.[slot]?.pos
|
||||
if (inputPos) return [nodeX + inputPos[0], nodeY + inputPos[1]]
|
||||
const { pos } = input
|
||||
if (pos) return [nodeX + pos[0], nodeY + pos[1]]
|
||||
|
||||
// default vertical slots
|
||||
const offsetX = LiteGraph.NODE_SLOT_HEIGHT * 0.5
|
||||
const nodeOffsetY = this.constructor.slot_start_y || 0
|
||||
const slotIndex = this.#defaultVerticalInputs.indexOf(this.inputs[slot])
|
||||
const slotIndex = this.#defaultVerticalInputs.indexOf(input)
|
||||
const slotY = (slotIndex + 0.7) * LiteGraph.NODE_SLOT_HEIGHT
|
||||
|
||||
return [nodeX + offsetX, nodeY + slotY + nodeOffsetY]
|
||||
@@ -3481,8 +3490,8 @@ export class LGraphNode implements Positionable, IPinnable, IColorable {
|
||||
/**
|
||||
* Returns the input slot that is associated with the given widget.
|
||||
*/
|
||||
getSlotFromWidget(widget: IWidget): INodeInputSlot | undefined {
|
||||
return this.inputs.find(slot => isWidgetInputSlot(slot) && slot.widget.name === widget.name)
|
||||
getSlotFromWidget(widget: IWidget | undefined): INodeInputSlot | undefined {
|
||||
if (widget) return this.inputs.find(slot => isWidgetInputSlot(slot) && slot.widget.name === widget.name)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -405,18 +405,11 @@ export class LinkConnector {
|
||||
// To input
|
||||
} else if (connectingTo === "input") {
|
||||
const input = node.getInputOnPos([canvasX, canvasY])
|
||||
const inputOrSocket = input ?? node.getSlotFromWidget(this.overWidget)
|
||||
|
||||
// Input slot
|
||||
if (input) {
|
||||
this.#dropOnInput(node, input)
|
||||
} else if (this.overWidget && renderLinks[0] instanceof ToInputRenderLink) {
|
||||
// Widget
|
||||
this.events.dispatch("dropped-on-widget", {
|
||||
link: renderLinks[0],
|
||||
node,
|
||||
widget: this.overWidget,
|
||||
})
|
||||
this.overWidget = undefined
|
||||
if (inputOrSocket) {
|
||||
this.#dropOnInput(node, inputOrSocket)
|
||||
} else {
|
||||
// Node background / title
|
||||
this.connectToNode(node, event)
|
||||
@@ -570,6 +563,10 @@ export class LinkConnector {
|
||||
}
|
||||
}
|
||||
|
||||
isInputValidDrop(node: LGraphNode, input: INodeInputSlot): boolean {
|
||||
return this.renderLinks.some(link => link.canConnectToInput(node, input))
|
||||
}
|
||||
|
||||
isNodeValidDrop(node: LGraphNode): boolean {
|
||||
if (this.state.connectingTo === "output") {
|
||||
return node.outputs.some(output => this.renderLinks.some(link => link.canConnectToOutput(node, output)))
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import { describe, expect } from "vitest"
|
||||
import type { INodeInputSlot, Point } from "@/interfaces"
|
||||
import type { ISerialisedNode } from "@/types/serialisation"
|
||||
|
||||
import { afterEach, beforeEach, describe, expect, vi } from "vitest"
|
||||
|
||||
import { LGraphNode, LiteGraph } from "@/litegraph"
|
||||
import { LGraph } from "@/litegraph"
|
||||
@@ -7,41 +10,92 @@ import { NodeOutputSlot } from "@/node/NodeOutputSlot"
|
||||
|
||||
import { test } from "./testExtensions"
|
||||
|
||||
function getMockISerialisedNode(data: Partial<ISerialisedNode>): ISerialisedNode {
|
||||
return Object.assign({
|
||||
id: 0,
|
||||
flags: {},
|
||||
type: "TestNode",
|
||||
pos: [100, 100],
|
||||
size: [100, 100],
|
||||
order: 0,
|
||||
mode: 0,
|
||||
}, data)
|
||||
}
|
||||
|
||||
describe("LGraphNode", () => {
|
||||
let node: LGraphNode
|
||||
let origLiteGraph: typeof LiteGraph
|
||||
|
||||
beforeEach(() => {
|
||||
origLiteGraph = Object.assign({}, LiteGraph)
|
||||
|
||||
Object.assign(LiteGraph, {
|
||||
NODE_TITLE_HEIGHT: 20,
|
||||
NODE_SLOT_HEIGHT: 15,
|
||||
NODE_TEXT_SIZE: 14,
|
||||
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
|
||||
DEFAULT_GROUP_FONT_SIZE: 24,
|
||||
isValidConnection: vi.fn().mockReturnValue(true),
|
||||
})
|
||||
node = new LGraphNode("Test Node")
|
||||
node.pos = [100, 200]
|
||||
node.size = [150, 100] // Example size
|
||||
|
||||
// Reset mocks if needed
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
Object.assign(LiteGraph, origLiteGraph)
|
||||
})
|
||||
|
||||
test("should serialize position/size correctly", () => {
|
||||
const node = new LGraphNode("TestNode")
|
||||
node.pos = [10, 10]
|
||||
expect(node.pos).toEqual(new Float32Array([10, 10]))
|
||||
expect(node.serialize().pos).toEqual([10, 10])
|
||||
node.pos = [10, 20]
|
||||
node.size = [30, 40]
|
||||
const json = node.serialize()
|
||||
expect(json.pos).toEqual([10, 20])
|
||||
expect(json.size).toEqual([30, 40])
|
||||
|
||||
node.size = [100, 100]
|
||||
expect(node.size).toEqual(new Float32Array([100, 100]))
|
||||
expect(node.serialize().size).toEqual([100, 100])
|
||||
const configureData: ISerialisedNode = {
|
||||
id: node.id,
|
||||
type: node.type,
|
||||
pos: [50, 60],
|
||||
size: [70, 80],
|
||||
flags: {},
|
||||
order: node.order,
|
||||
mode: node.mode,
|
||||
inputs: node.inputs?.map(i => ({ name: i.name, type: i.type, link: i.link })),
|
||||
outputs: node.outputs?.map(o => ({ name: o.name, type: o.type, links: o.links, slot_index: o.slot_index })),
|
||||
}
|
||||
node.configure(configureData)
|
||||
expect(node.pos).toEqual(new Float32Array([50, 60]))
|
||||
expect(node.size).toEqual(new Float32Array([70, 80]))
|
||||
})
|
||||
|
||||
test("should configure inputs correctly", () => {
|
||||
const node = new LGraphNode("TestNode")
|
||||
node.configure({
|
||||
node.configure(getMockISerialisedNode({
|
||||
id: 0,
|
||||
inputs: [{ name: "TestInput", type: "number", link: null }],
|
||||
})
|
||||
}))
|
||||
expect(node.inputs.length).toEqual(1)
|
||||
expect(node.inputs[0].name).toEqual("TestInput")
|
||||
expect(node.inputs[0].link).toEqual(null)
|
||||
expect(node.inputs[0]).instanceOf(NodeInputSlot)
|
||||
|
||||
// Should not override existing inputs
|
||||
node.configure({ id: 1 })
|
||||
node.configure(getMockISerialisedNode({ id: 1 }))
|
||||
expect(node.id).toEqual(1)
|
||||
expect(node.inputs.length).toEqual(1)
|
||||
})
|
||||
|
||||
test("should configure outputs correctly", () => {
|
||||
const node = new LGraphNode("TestNode")
|
||||
node.configure({
|
||||
node.configure(getMockISerialisedNode({
|
||||
id: 0,
|
||||
outputs: [{ name: "TestOutput", type: "number", links: [] }],
|
||||
})
|
||||
}))
|
||||
expect(node.outputs.length).toEqual(1)
|
||||
expect(node.outputs[0].name).toEqual("TestOutput")
|
||||
expect(node.outputs[0].type).toEqual("number")
|
||||
@@ -49,7 +103,7 @@ describe("LGraphNode", () => {
|
||||
expect(node.outputs[0]).instanceOf(NodeOutputSlot)
|
||||
|
||||
// Should not override existing outputs
|
||||
node.configure({ id: 1 })
|
||||
node.configure(getMockISerialisedNode({ id: 1 }))
|
||||
expect(node.id).toEqual(1)
|
||||
expect(node.outputs.length).toEqual(1)
|
||||
})
|
||||
@@ -60,14 +114,14 @@ describe("LGraphNode", () => {
|
||||
const node2 = new LGraphNode("TargetNode")
|
||||
|
||||
// Configure nodes with input/output slots
|
||||
node1.configure({
|
||||
node1.configure(getMockISerialisedNode({
|
||||
id: 1,
|
||||
outputs: [{ name: "Output1", type: "number", links: [] }],
|
||||
})
|
||||
node2.configure({
|
||||
}))
|
||||
node2.configure(getMockISerialisedNode({
|
||||
id: 2,
|
||||
inputs: [{ name: "Input1", type: "number", link: null }],
|
||||
})
|
||||
}))
|
||||
|
||||
// Create a graph and add nodes to it
|
||||
const graph = new LGraph()
|
||||
@@ -108,21 +162,21 @@ describe("LGraphNode", () => {
|
||||
const targetNode2 = new LGraphNode("TargetNode2")
|
||||
|
||||
// Configure nodes with input/output slots
|
||||
sourceNode.configure({
|
||||
sourceNode.configure(getMockISerialisedNode({
|
||||
id: 1,
|
||||
outputs: [
|
||||
{ name: "Output1", type: "number", links: [] },
|
||||
{ name: "Output2", type: "number", links: [] },
|
||||
],
|
||||
})
|
||||
targetNode1.configure({
|
||||
}))
|
||||
targetNode1.configure(getMockISerialisedNode({
|
||||
id: 2,
|
||||
inputs: [{ name: "Input1", type: "number", link: null }],
|
||||
})
|
||||
targetNode2.configure({
|
||||
}))
|
||||
targetNode2.configure(getMockISerialisedNode({
|
||||
id: 3,
|
||||
inputs: [{ name: "Input1", type: "number", link: null }],
|
||||
})
|
||||
}))
|
||||
|
||||
// Create a graph and add nodes to it
|
||||
const graph = new LGraph()
|
||||
@@ -184,11 +238,11 @@ describe("LGraphNode", () => {
|
||||
node.boundingRect[1] = 100
|
||||
node.boundingRect[2] = 100
|
||||
node.boundingRect[3] = 100
|
||||
node.configure({
|
||||
node.configure(getMockISerialisedNode({
|
||||
id: 1,
|
||||
inputs: [{ name: "Input1", type: "number", link: null }],
|
||||
outputs: [{ name: "Output1", type: "number", links: [] }],
|
||||
})
|
||||
}))
|
||||
|
||||
// Collapse the node
|
||||
node.flags.collapsed = true
|
||||
@@ -197,25 +251,25 @@ describe("LGraphNode", () => {
|
||||
const inputPos = node.getInputPos(0)
|
||||
const outputPos = node.getOutputPos(0)
|
||||
|
||||
expect(inputPos).toEqual([100, 85])
|
||||
expect(outputPos).toEqual([180, 85])
|
||||
expect(inputPos).toEqual([100, 90])
|
||||
expect(outputPos).toEqual([180, 90])
|
||||
})
|
||||
|
||||
test("should return correct positions for input and output slots", () => {
|
||||
const node = new LGraphNode("TestNode")
|
||||
node.pos = [100, 100]
|
||||
node.size = [100, 100]
|
||||
node.configure({
|
||||
node.configure(getMockISerialisedNode({
|
||||
id: 1,
|
||||
inputs: [{ name: "Input1", type: "number", link: null }],
|
||||
outputs: [{ name: "Output1", type: "number", links: [] }],
|
||||
})
|
||||
}))
|
||||
|
||||
const inputPos = node.getInputPos(0)
|
||||
const outputPos = node.getOutputPos(0)
|
||||
|
||||
expect(inputPos).toEqual([110, 114])
|
||||
expect(outputPos).toEqual([191, 114])
|
||||
expect(inputPos).toEqual([107.5, 110.5])
|
||||
expect(outputPos).toEqual([193.5, 110.5])
|
||||
})
|
||||
})
|
||||
|
||||
@@ -224,11 +278,11 @@ describe("LGraphNode", () => {
|
||||
const node = new LGraphNode("TestNode")
|
||||
node.pos = [100, 100]
|
||||
node.size = [100, 100]
|
||||
node.configure({
|
||||
node.configure(getMockISerialisedNode({
|
||||
id: 1,
|
||||
inputs: [{ name: "Input1", type: "number", link: null }],
|
||||
outputs: [{ name: "Output1", type: "number", links: [] }],
|
||||
})
|
||||
}))
|
||||
|
||||
// Test point far outside node bounds
|
||||
expect(node.getSlotOnPos([0, 0])).toBeUndefined()
|
||||
@@ -244,13 +298,13 @@ describe("LGraphNode", () => {
|
||||
node.boundingRect[1] = 100
|
||||
node.boundingRect[2] = 200
|
||||
node.boundingRect[3] = 200
|
||||
node.configure({
|
||||
node.configure(getMockISerialisedNode({
|
||||
id: 1,
|
||||
inputs: [
|
||||
{ name: "Input1", type: "number", link: null },
|
||||
{ name: "Input2", type: "string", link: null },
|
||||
],
|
||||
})
|
||||
}))
|
||||
|
||||
// Get position of first input slot
|
||||
const inputPos = node.getInputPos(0)
|
||||
@@ -271,13 +325,13 @@ describe("LGraphNode", () => {
|
||||
node.boundingRect[1] = 100
|
||||
node.boundingRect[2] = 200
|
||||
node.boundingRect[3] = 200
|
||||
node.configure({
|
||||
node.configure(getMockISerialisedNode({
|
||||
id: 1,
|
||||
outputs: [
|
||||
{ name: "Output1", type: "number", links: [] },
|
||||
{ name: "Output2", type: "string", links: [] },
|
||||
],
|
||||
})
|
||||
}))
|
||||
|
||||
// Get position of first output slot
|
||||
const outputPos = node.getOutputPos(0)
|
||||
@@ -299,11 +353,11 @@ describe("LGraphNode", () => {
|
||||
node.boundingRect[1] = 100
|
||||
node.boundingRect[2] = 200
|
||||
node.boundingRect[3] = 200
|
||||
node.configure({
|
||||
node.configure(getMockISerialisedNode({
|
||||
id: 1,
|
||||
inputs: [{ name: "Input1", type: "number", link: null }],
|
||||
outputs: [{ name: "Output1", type: "number", links: [] }],
|
||||
})
|
||||
}))
|
||||
|
||||
// Get positions of first input and output slots
|
||||
const inputPos = node.getInputPos(0)
|
||||
@@ -437,20 +491,86 @@ describe("LGraphNode", () => {
|
||||
expect(node.widgets?.length).toBe(2)
|
||||
|
||||
node.widgets![0].serialize = false
|
||||
node.configure({
|
||||
node.configure(getMockISerialisedNode({
|
||||
id: 1,
|
||||
type: "TestNode",
|
||||
pos: [100, 100],
|
||||
size: [100, 100],
|
||||
flags: {},
|
||||
properties: {},
|
||||
order: 0,
|
||||
mode: 0,
|
||||
widgets_values: [100],
|
||||
})
|
||||
}))
|
||||
|
||||
expect(node.widgets![0].value).toBe(1)
|
||||
expect(node.widgets![1].value).toBe(100)
|
||||
})
|
||||
})
|
||||
|
||||
describe("getInputSlotPos", () => {
|
||||
let inputSlot: INodeInputSlot
|
||||
|
||||
beforeEach(() => {
|
||||
inputSlot = { name: "test_in", type: "string", link: null, boundingRect: new Float32Array([0, 0, 0, 0]) }
|
||||
})
|
||||
test("should return position based on title height when collapsed", () => {
|
||||
node.flags.collapsed = true
|
||||
const expectedPos: Point = [100, 200 - LiteGraph.NODE_TITLE_HEIGHT * 0.5]
|
||||
expect(node.getInputSlotPos(inputSlot)).toEqual(expectedPos)
|
||||
})
|
||||
|
||||
test("should return position based on input.pos when defined and not collapsed", () => {
|
||||
node.flags.collapsed = false
|
||||
inputSlot.pos = [10, 50]
|
||||
node.inputs = [inputSlot]
|
||||
const expectedPos: Point = [100 + 10, 200 + 50]
|
||||
expect(node.getInputSlotPos(inputSlot)).toEqual(expectedPos)
|
||||
})
|
||||
|
||||
test("should return default vertical position when input.pos is undefined and not collapsed", () => {
|
||||
node.flags.collapsed = false
|
||||
const inputSlot2 = { name: "test_in_2", type: "number", link: null, boundingRect: new Float32Array([0, 0, 0, 0]) }
|
||||
node.inputs = [inputSlot, inputSlot2]
|
||||
const slotIndex = 0
|
||||
const nodeOffsetY = (node.constructor as any).slot_start_y || 0
|
||||
const expectedY = 200 + (slotIndex + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + nodeOffsetY
|
||||
const expectedX = 100 + LiteGraph.NODE_SLOT_HEIGHT * 0.5
|
||||
expect(node.getInputSlotPos(inputSlot)).toEqual([expectedX, expectedY])
|
||||
const slotIndex2 = 1
|
||||
const expectedY2 = 200 + (slotIndex2 + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + nodeOffsetY
|
||||
expect(node.getInputSlotPos(inputSlot2)).toEqual([expectedX, expectedY2])
|
||||
})
|
||||
|
||||
test("should return default vertical position including slot_start_y when defined", () => {
|
||||
(node.constructor as any).slot_start_y = 25
|
||||
node.flags.collapsed = false
|
||||
node.inputs = [inputSlot]
|
||||
const slotIndex = 0
|
||||
const nodeOffsetY = 25
|
||||
const expectedY = 200 + (slotIndex + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + nodeOffsetY
|
||||
const expectedX = 100 + LiteGraph.NODE_SLOT_HEIGHT * 0.5
|
||||
expect(node.getInputSlotPos(inputSlot)).toEqual([expectedX, expectedY])
|
||||
delete (node.constructor as any).slot_start_y
|
||||
})
|
||||
})
|
||||
|
||||
describe("getInputPos", () => {
|
||||
test("should call getInputSlotPos with the correct input slot from inputs array", () => {
|
||||
const input0: INodeInputSlot = { name: "in0", type: "string", link: null, boundingRect: new Float32Array([0, 0, 0, 0]) }
|
||||
const input1: INodeInputSlot = { name: "in1", type: "number", link: null, boundingRect: new Float32Array([0, 0, 0, 0]), pos: [5, 45] }
|
||||
node.inputs = [input0, input1]
|
||||
const spy = vi.spyOn(node, "getInputSlotPos")
|
||||
node.getInputPos(1)
|
||||
expect(spy).toHaveBeenCalledWith(input1)
|
||||
const expectedPos: Point = [100 + 5, 200 + 45]
|
||||
expect(node.getInputPos(1)).toEqual(expectedPos)
|
||||
spy.mockClear()
|
||||
node.getInputPos(0)
|
||||
expect(spy).toHaveBeenCalledWith(input0)
|
||||
const slotIndex = 0
|
||||
const nodeOffsetY = (node.constructor as any).slot_start_y || 0
|
||||
const expectedDefaultY = 200 + (slotIndex + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + nodeOffsetY
|
||||
const expectedDefaultX = 100 + LiteGraph.NODE_SLOT_HEIGHT * 0.5
|
||||
expect(node.getInputPos(0)).toEqual([expectedDefaultX, expectedDefaultY])
|
||||
spy.mockRestore()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
131
test/canvas/LinkConnector.test.ts
Normal file
131
test/canvas/LinkConnector.test.ts
Normal file
@@ -0,0 +1,131 @@
|
||||
import type { INodeInputSlot, LGraphNode } from "@/litegraph"
|
||||
|
||||
import { beforeEach, describe, expect, test, vi } from "vitest"
|
||||
|
||||
// We don't strictly need RenderLink interface import for the mock
|
||||
import { LinkConnector } from "@/canvas/LinkConnector"
|
||||
|
||||
// Mocks
|
||||
const mockSetConnectingLinks = vi.fn()
|
||||
|
||||
// Mock a structure that has the needed method
|
||||
function mockRenderLinkImpl(canConnect: boolean) {
|
||||
return {
|
||||
canConnectToInput: vi.fn().mockReturnValue(canConnect),
|
||||
// Add other properties if they become necessary for tests
|
||||
}
|
||||
}
|
||||
|
||||
const mockNode = {} as LGraphNode
|
||||
const mockInput = {} as INodeInputSlot
|
||||
|
||||
describe("LinkConnector", () => {
|
||||
let connector: LinkConnector
|
||||
|
||||
beforeEach(() => {
|
||||
connector = new LinkConnector(mockSetConnectingLinks)
|
||||
// Clear the array directly before each test
|
||||
connector.renderLinks.length = 0
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe("isInputValidDrop", () => {
|
||||
test("should return false if there are no render links", () => {
|
||||
expect(connector.isInputValidDrop(mockNode, mockInput)).toBe(false)
|
||||
})
|
||||
|
||||
test("should return true if at least one render link can connect", () => {
|
||||
const link1 = mockRenderLinkImpl(false)
|
||||
const link2 = mockRenderLinkImpl(true)
|
||||
// Cast to any to satisfy the push requirement, as we only need the canConnectToInput method
|
||||
connector.renderLinks.push(link1 as any, link2 as any)
|
||||
expect(connector.isInputValidDrop(mockNode, mockInput)).toBe(true)
|
||||
expect(link1.canConnectToInput).toHaveBeenCalledWith(mockNode, mockInput)
|
||||
expect(link2.canConnectToInput).toHaveBeenCalledWith(mockNode, mockInput)
|
||||
})
|
||||
|
||||
test("should return false if no render links can connect", () => {
|
||||
const link1 = mockRenderLinkImpl(false)
|
||||
const link2 = mockRenderLinkImpl(false)
|
||||
connector.renderLinks.push(link1 as any, link2 as any)
|
||||
expect(connector.isInputValidDrop(mockNode, mockInput)).toBe(false)
|
||||
expect(link1.canConnectToInput).toHaveBeenCalledWith(mockNode, mockInput)
|
||||
expect(link2.canConnectToInput).toHaveBeenCalledWith(mockNode, mockInput)
|
||||
})
|
||||
|
||||
test("should call canConnectToInput on each render link until one returns true", () => {
|
||||
const link1 = mockRenderLinkImpl(false)
|
||||
const link2 = mockRenderLinkImpl(true) // This one can connect
|
||||
const link3 = mockRenderLinkImpl(false)
|
||||
connector.renderLinks.push(link1 as any, link2 as any, link3 as any)
|
||||
|
||||
expect(connector.isInputValidDrop(mockNode, mockInput)).toBe(true)
|
||||
|
||||
expect(link1.canConnectToInput).toHaveBeenCalledTimes(1)
|
||||
expect(link2.canConnectToInput).toHaveBeenCalledTimes(1) // Stops here
|
||||
expect(link3.canConnectToInput).not.toHaveBeenCalled() // Should not be called
|
||||
})
|
||||
})
|
||||
|
||||
describe("listenUntilReset", () => {
|
||||
test("should add listener for the specified event and for reset", () => {
|
||||
const listener = vi.fn()
|
||||
const addEventListenerSpy = vi.spyOn(connector.events, "addEventListener")
|
||||
|
||||
connector.listenUntilReset("before-drop-links", listener)
|
||||
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith("before-drop-links", listener, undefined)
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith("reset", expect.any(Function), { once: true })
|
||||
})
|
||||
|
||||
test("should call the listener when the event is dispatched before reset", () => {
|
||||
const listener = vi.fn()
|
||||
const eventData = { renderLinks: [], event: {} as any } // Mock event data
|
||||
connector.listenUntilReset("before-drop-links", listener)
|
||||
|
||||
connector.events.dispatch("before-drop-links", eventData)
|
||||
|
||||
expect(listener).toHaveBeenCalledTimes(1)
|
||||
expect(listener).toHaveBeenCalledWith(new CustomEvent("before-drop-links"))
|
||||
})
|
||||
|
||||
test("should remove the listener when reset is dispatched", () => {
|
||||
const listener = vi.fn()
|
||||
const removeEventListenerSpy = vi.spyOn(connector.events, "removeEventListener")
|
||||
|
||||
connector.listenUntilReset("before-drop-links", listener)
|
||||
|
||||
// Simulate the reset event being dispatched
|
||||
connector.events.dispatch("reset", false)
|
||||
|
||||
// Check if removeEventListener was called correctly for the original listener
|
||||
expect(removeEventListenerSpy).toHaveBeenCalledWith("before-drop-links", listener)
|
||||
})
|
||||
|
||||
test("should not call the listener after reset is dispatched", () => {
|
||||
const listener = vi.fn()
|
||||
const eventData = { renderLinks: [], event: {} as any }
|
||||
connector.listenUntilReset("before-drop-links", listener)
|
||||
|
||||
// Dispatch reset first
|
||||
connector.events.dispatch("reset", false)
|
||||
|
||||
// Then dispatch the original event
|
||||
connector.events.dispatch("before-drop-links", eventData)
|
||||
|
||||
expect(listener).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
test("should pass options to addEventListener", () => {
|
||||
const listener = vi.fn()
|
||||
const options = { once: true }
|
||||
const addEventListenerSpy = vi.spyOn(connector.events, "addEventListener")
|
||||
|
||||
connector.listenUntilReset("after-drop-links", listener, options)
|
||||
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith("after-drop-links", listener, options)
|
||||
// Still adds the reset listener
|
||||
expect(addEventListenerSpy).toHaveBeenCalledWith("reset", expect.any(Function), { once: true })
|
||||
})
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user