Fix widget snap to work with input sockets (#1017)

This commit is contained in:
filtered
2025-05-07 02:02:49 +10:00
committed by GitHub
parent c7c7547454
commit df6e4debb5
5 changed files with 345 additions and 71 deletions

View File

@@ -1,4 +1,7 @@
import { describe, expect } from "vitest"
import type { INodeInputSlot, Point } from "@/interfaces"
import type { ISerialisedNode } from "@/types/serialisation"
import { afterEach, beforeEach, describe, expect, vi } from "vitest"
import { LGraphNode, LiteGraph } from "@/litegraph"
import { LGraph } from "@/litegraph"
@@ -7,41 +10,92 @@ import { NodeOutputSlot } from "@/node/NodeOutputSlot"
import { test } from "./testExtensions"
function getMockISerialisedNode(data: Partial<ISerialisedNode>): ISerialisedNode {
return Object.assign({
id: 0,
flags: {},
type: "TestNode",
pos: [100, 100],
size: [100, 100],
order: 0,
mode: 0,
}, data)
}
describe("LGraphNode", () => {
let node: LGraphNode
let origLiteGraph: typeof LiteGraph
beforeEach(() => {
origLiteGraph = Object.assign({}, LiteGraph)
Object.assign(LiteGraph, {
NODE_TITLE_HEIGHT: 20,
NODE_SLOT_HEIGHT: 15,
NODE_TEXT_SIZE: 14,
DEFAULT_SHADOW_COLOR: "rgba(0,0,0,0.5)",
DEFAULT_GROUP_FONT_SIZE: 24,
isValidConnection: vi.fn().mockReturnValue(true),
})
node = new LGraphNode("Test Node")
node.pos = [100, 200]
node.size = [150, 100] // Example size
// Reset mocks if needed
vi.clearAllMocks()
})
afterEach(() => {
Object.assign(LiteGraph, origLiteGraph)
})
test("should serialize position/size correctly", () => {
const node = new LGraphNode("TestNode")
node.pos = [10, 10]
expect(node.pos).toEqual(new Float32Array([10, 10]))
expect(node.serialize().pos).toEqual([10, 10])
node.pos = [10, 20]
node.size = [30, 40]
const json = node.serialize()
expect(json.pos).toEqual([10, 20])
expect(json.size).toEqual([30, 40])
node.size = [100, 100]
expect(node.size).toEqual(new Float32Array([100, 100]))
expect(node.serialize().size).toEqual([100, 100])
const configureData: ISerialisedNode = {
id: node.id,
type: node.type,
pos: [50, 60],
size: [70, 80],
flags: {},
order: node.order,
mode: node.mode,
inputs: node.inputs?.map(i => ({ name: i.name, type: i.type, link: i.link })),
outputs: node.outputs?.map(o => ({ name: o.name, type: o.type, links: o.links, slot_index: o.slot_index })),
}
node.configure(configureData)
expect(node.pos).toEqual(new Float32Array([50, 60]))
expect(node.size).toEqual(new Float32Array([70, 80]))
})
test("should configure inputs correctly", () => {
const node = new LGraphNode("TestNode")
node.configure({
node.configure(getMockISerialisedNode({
id: 0,
inputs: [{ name: "TestInput", type: "number", link: null }],
})
}))
expect(node.inputs.length).toEqual(1)
expect(node.inputs[0].name).toEqual("TestInput")
expect(node.inputs[0].link).toEqual(null)
expect(node.inputs[0]).instanceOf(NodeInputSlot)
// Should not override existing inputs
node.configure({ id: 1 })
node.configure(getMockISerialisedNode({ id: 1 }))
expect(node.id).toEqual(1)
expect(node.inputs.length).toEqual(1)
})
test("should configure outputs correctly", () => {
const node = new LGraphNode("TestNode")
node.configure({
node.configure(getMockISerialisedNode({
id: 0,
outputs: [{ name: "TestOutput", type: "number", links: [] }],
})
}))
expect(node.outputs.length).toEqual(1)
expect(node.outputs[0].name).toEqual("TestOutput")
expect(node.outputs[0].type).toEqual("number")
@@ -49,7 +103,7 @@ describe("LGraphNode", () => {
expect(node.outputs[0]).instanceOf(NodeOutputSlot)
// Should not override existing outputs
node.configure({ id: 1 })
node.configure(getMockISerialisedNode({ id: 1 }))
expect(node.id).toEqual(1)
expect(node.outputs.length).toEqual(1)
})
@@ -60,14 +114,14 @@ describe("LGraphNode", () => {
const node2 = new LGraphNode("TargetNode")
// Configure nodes with input/output slots
node1.configure({
node1.configure(getMockISerialisedNode({
id: 1,
outputs: [{ name: "Output1", type: "number", links: [] }],
})
node2.configure({
}))
node2.configure(getMockISerialisedNode({
id: 2,
inputs: [{ name: "Input1", type: "number", link: null }],
})
}))
// Create a graph and add nodes to it
const graph = new LGraph()
@@ -108,21 +162,21 @@ describe("LGraphNode", () => {
const targetNode2 = new LGraphNode("TargetNode2")
// Configure nodes with input/output slots
sourceNode.configure({
sourceNode.configure(getMockISerialisedNode({
id: 1,
outputs: [
{ name: "Output1", type: "number", links: [] },
{ name: "Output2", type: "number", links: [] },
],
})
targetNode1.configure({
}))
targetNode1.configure(getMockISerialisedNode({
id: 2,
inputs: [{ name: "Input1", type: "number", link: null }],
})
targetNode2.configure({
}))
targetNode2.configure(getMockISerialisedNode({
id: 3,
inputs: [{ name: "Input1", type: "number", link: null }],
})
}))
// Create a graph and add nodes to it
const graph = new LGraph()
@@ -184,11 +238,11 @@ describe("LGraphNode", () => {
node.boundingRect[1] = 100
node.boundingRect[2] = 100
node.boundingRect[3] = 100
node.configure({
node.configure(getMockISerialisedNode({
id: 1,
inputs: [{ name: "Input1", type: "number", link: null }],
outputs: [{ name: "Output1", type: "number", links: [] }],
})
}))
// Collapse the node
node.flags.collapsed = true
@@ -197,25 +251,25 @@ describe("LGraphNode", () => {
const inputPos = node.getInputPos(0)
const outputPos = node.getOutputPos(0)
expect(inputPos).toEqual([100, 85])
expect(outputPos).toEqual([180, 85])
expect(inputPos).toEqual([100, 90])
expect(outputPos).toEqual([180, 90])
})
test("should return correct positions for input and output slots", () => {
const node = new LGraphNode("TestNode")
node.pos = [100, 100]
node.size = [100, 100]
node.configure({
node.configure(getMockISerialisedNode({
id: 1,
inputs: [{ name: "Input1", type: "number", link: null }],
outputs: [{ name: "Output1", type: "number", links: [] }],
})
}))
const inputPos = node.getInputPos(0)
const outputPos = node.getOutputPos(0)
expect(inputPos).toEqual([110, 114])
expect(outputPos).toEqual([191, 114])
expect(inputPos).toEqual([107.5, 110.5])
expect(outputPos).toEqual([193.5, 110.5])
})
})
@@ -224,11 +278,11 @@ describe("LGraphNode", () => {
const node = new LGraphNode("TestNode")
node.pos = [100, 100]
node.size = [100, 100]
node.configure({
node.configure(getMockISerialisedNode({
id: 1,
inputs: [{ name: "Input1", type: "number", link: null }],
outputs: [{ name: "Output1", type: "number", links: [] }],
})
}))
// Test point far outside node bounds
expect(node.getSlotOnPos([0, 0])).toBeUndefined()
@@ -244,13 +298,13 @@ describe("LGraphNode", () => {
node.boundingRect[1] = 100
node.boundingRect[2] = 200
node.boundingRect[3] = 200
node.configure({
node.configure(getMockISerialisedNode({
id: 1,
inputs: [
{ name: "Input1", type: "number", link: null },
{ name: "Input2", type: "string", link: null },
],
})
}))
// Get position of first input slot
const inputPos = node.getInputPos(0)
@@ -271,13 +325,13 @@ describe("LGraphNode", () => {
node.boundingRect[1] = 100
node.boundingRect[2] = 200
node.boundingRect[3] = 200
node.configure({
node.configure(getMockISerialisedNode({
id: 1,
outputs: [
{ name: "Output1", type: "number", links: [] },
{ name: "Output2", type: "string", links: [] },
],
})
}))
// Get position of first output slot
const outputPos = node.getOutputPos(0)
@@ -299,11 +353,11 @@ describe("LGraphNode", () => {
node.boundingRect[1] = 100
node.boundingRect[2] = 200
node.boundingRect[3] = 200
node.configure({
node.configure(getMockISerialisedNode({
id: 1,
inputs: [{ name: "Input1", type: "number", link: null }],
outputs: [{ name: "Output1", type: "number", links: [] }],
})
}))
// Get positions of first input and output slots
const inputPos = node.getInputPos(0)
@@ -437,20 +491,86 @@ describe("LGraphNode", () => {
expect(node.widgets?.length).toBe(2)
node.widgets![0].serialize = false
node.configure({
node.configure(getMockISerialisedNode({
id: 1,
type: "TestNode",
pos: [100, 100],
size: [100, 100],
flags: {},
properties: {},
order: 0,
mode: 0,
widgets_values: [100],
})
}))
expect(node.widgets![0].value).toBe(1)
expect(node.widgets![1].value).toBe(100)
})
})
describe("getInputSlotPos", () => {
let inputSlot: INodeInputSlot
beforeEach(() => {
inputSlot = { name: "test_in", type: "string", link: null, boundingRect: new Float32Array([0, 0, 0, 0]) }
})
test("should return position based on title height when collapsed", () => {
node.flags.collapsed = true
const expectedPos: Point = [100, 200 - LiteGraph.NODE_TITLE_HEIGHT * 0.5]
expect(node.getInputSlotPos(inputSlot)).toEqual(expectedPos)
})
test("should return position based on input.pos when defined and not collapsed", () => {
node.flags.collapsed = false
inputSlot.pos = [10, 50]
node.inputs = [inputSlot]
const expectedPos: Point = [100 + 10, 200 + 50]
expect(node.getInputSlotPos(inputSlot)).toEqual(expectedPos)
})
test("should return default vertical position when input.pos is undefined and not collapsed", () => {
node.flags.collapsed = false
const inputSlot2 = { name: "test_in_2", type: "number", link: null, boundingRect: new Float32Array([0, 0, 0, 0]) }
node.inputs = [inputSlot, inputSlot2]
const slotIndex = 0
const nodeOffsetY = (node.constructor as any).slot_start_y || 0
const expectedY = 200 + (slotIndex + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + nodeOffsetY
const expectedX = 100 + LiteGraph.NODE_SLOT_HEIGHT * 0.5
expect(node.getInputSlotPos(inputSlot)).toEqual([expectedX, expectedY])
const slotIndex2 = 1
const expectedY2 = 200 + (slotIndex2 + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + nodeOffsetY
expect(node.getInputSlotPos(inputSlot2)).toEqual([expectedX, expectedY2])
})
test("should return default vertical position including slot_start_y when defined", () => {
(node.constructor as any).slot_start_y = 25
node.flags.collapsed = false
node.inputs = [inputSlot]
const slotIndex = 0
const nodeOffsetY = 25
const expectedY = 200 + (slotIndex + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + nodeOffsetY
const expectedX = 100 + LiteGraph.NODE_SLOT_HEIGHT * 0.5
expect(node.getInputSlotPos(inputSlot)).toEqual([expectedX, expectedY])
delete (node.constructor as any).slot_start_y
})
})
describe("getInputPos", () => {
test("should call getInputSlotPos with the correct input slot from inputs array", () => {
const input0: INodeInputSlot = { name: "in0", type: "string", link: null, boundingRect: new Float32Array([0, 0, 0, 0]) }
const input1: INodeInputSlot = { name: "in1", type: "number", link: null, boundingRect: new Float32Array([0, 0, 0, 0]), pos: [5, 45] }
node.inputs = [input0, input1]
const spy = vi.spyOn(node, "getInputSlotPos")
node.getInputPos(1)
expect(spy).toHaveBeenCalledWith(input1)
const expectedPos: Point = [100 + 5, 200 + 45]
expect(node.getInputPos(1)).toEqual(expectedPos)
spy.mockClear()
node.getInputPos(0)
expect(spy).toHaveBeenCalledWith(input0)
const slotIndex = 0
const nodeOffsetY = (node.constructor as any).slot_start_y || 0
const expectedDefaultY = 200 + (slotIndex + 0.7) * LiteGraph.NODE_SLOT_HEIGHT + nodeOffsetY
const expectedDefaultX = 100 + LiteGraph.NODE_SLOT_HEIGHT * 0.5
expect(node.getInputPos(0)).toEqual([expectedDefaultX, expectedDefaultY])
spy.mockRestore()
})
})
})