import {
  start,
  makeNodeDef,
  checkBeforeAndAfterReload,
  assertNotNullOrUndefined,
  createDefaultWorkflow
} from '../utils'
import lg from '../utils/litegraph'

// NOTE(review): the generic arguments in the JSDoc types below were missing
// (e.g. `ReturnType["ez"]`, bare `InstanceType`) and have been reconstructed.
// Confirm that '../utils/ezgraph' exports `Ez`, `EzGraph` and `EzInput`.
/**
 * @typedef { import("../utils/ezgraph") } Ez
 * @typedef { ReturnType<Ez["Ez"]["graph"]>["ez"] } EzNodeFactory
 */

/**
 * Creates a PrimitiveNode, connects its first output to `input`, and uses
 * `checkBeforeAndAfterReload` to assert the connection target, the type of
 * the primitive's value widget, and the total number of widgets.
 *
 * @param { EzNodeFactory } ez Node factory used to create the primitive.
 * @param { InstanceType<Ez["EzGraph"]> } graph Graph the nodes live in.
 * @param { InstanceType<Ez["EzInput"]> } input Converted widget input to connect to.
 * @param { string } widgetType Expected `type` of the primitive's value widget.
 * @param { number } controlWidgetCount Number of extra control widgets expected
 *   besides the value widget (0 means none are checked).
 * @returns The primitive as last resolved through `graph.find` in the check callback.
 */
async function connectPrimitiveAndReload(
  ez,
  graph,
  input,
  widgetType,
  controlWidgetCount = 0
) {
  // Connect to primitive and ensure it's still connected after
  let primitive = ez.PrimitiveNode()
  primitive.outputs[0].connectTo(input)

  await checkBeforeAndAfterReload(graph, async () => {
    // Re-resolve the node on every invocation of the callback
    primitive = graph.find(primitive)
    const { connections } = primitive.outputs[0]
    expect(connections).toHaveLength(1)
    expect(connections[0].targetNode.id).toBe(input.node.node.id)

    // Ensure widget is correct type
    const valueWidget = primitive.widgets.value
    expect(valueWidget.widget.type).toBe(widgetType)

    // Check if control_after_generate should be added
    if (controlWidgetCount) {
      const controlWidget = primitive.widgets.control_after_generate
      expect(controlWidget.widget.type).toBe('combo')
      if (widgetType === 'combo') {
        const filterWidget = primitive.widgets.control_filter_list
        expect(filterWidget.widget.type).toBe('string')
      }
    }

    // Ensure we don't have other widgets
    expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount)
  })

  return primitive
}

describe('widget inputs', () => {
  beforeEach(() => {
    lg.setup(global)
  })

  afterEach(() => {
    lg.teardown(global)
  })

  // One entry per widget kind:
  //   name    - input name on the mock node (also the expected widget type
  //             when `widget` is omitted)
  //   type    - input type passed to makeNodeDef
  //   opt     - optional input options
  //   widget  - expected primitive value-widget type
  //   control - expected number of extra control widgets on the primitive
  const conversionCases = [
    { name: 'int', type: 'INT', widget: 'number', control: 1 },
    { name: 'float', type: 'FLOAT', widget: 'number', control: 1 },
    { name: 'text', type: 'STRING' },
    { name: 'customtext', type: 'STRING', opt: { multiline: true } },
    { name: 'toggle', type: 'BOOLEAN' },
    { name: 'combo', type: ['a', 'b', 'c'], control: 2 }
  ]

  for (const c of conversionCases) {
    test(`widget conversion + primitive works on ${c.name}`, async () => {
      const { ez, graph } = await start({
        mockNodeDefs: makeNodeDef('TestNode', {
          [c.name]: [c.type, c.opt ?? {}]
        })
      })

      // Create test node and convert to input
      const n = ez.TestNode()
      const w = n.widgets[c.name]
      w.convertToInput()
      expect(w.isConvertedToInput).toBeTruthy()
      const input = w.getConvertedInput()
      expect(input).toBeTruthy()

      await connectPrimitiveAndReload(
        ez,
        graph,
        input,
        c.widget ?? c.name,
        c.control
      )
    })
  }

  test('converted widget works after reload', async () => {
    const { ez, graph } = await start()
    const n = ez.CheckpointLoaderSimple()
    const inputCount = n.inputs.length

    // Convert ckpt name to an input
    n.widgets.ckpt_name.convertToInput()
    expect(n.widgets.ckpt_name.isConvertedToInput).toBeTruthy()
    expect(n.inputs.ckpt_name).toBeTruthy()
    expect(n.inputs.length).toEqual(inputCount + 1)

    // Convert back to widget and ensure input is removed
    n.widgets.ckpt_name.convertToWidget()
    expect(n.widgets.ckpt_name.isConvertedToInput).toBeFalsy()
    expect(n.inputs.ckpt_name).toBeFalsy()
    expect(n.inputs.length).toEqual(inputCount)

    // Convert again and reload the graph to ensure it maintains state
    n.widgets.ckpt_name.convertToInput()
    expect(n.inputs.length).toEqual(inputCount + 1)

    const primitive = await connectPrimitiveAndReload(
      ez,
      graph,
      n.inputs.ckpt_name,
      'combo',
      2
    )

    // Disconnect & reconnect
    primitive.outputs[0].connections[0].disconnect()
    let { connections } = primitive.outputs[0]
    expect(connections).toHaveLength(0)

    primitive.outputs[0].connectTo(n.inputs.ckpt_name)
    // Read the connections again after reconnecting
    connections = primitive.outputs[0].connections
    expect(connections).toHaveLength(1)
    expect(connections[0].targetNode.id).toBe(n.node.id)

    // Convert back to widget and ensure input is removed
    n.widgets.ckpt_name.convertToWidget()
    expect(n.widgets.ckpt_name.isConvertedToInput).toBeFalsy()
    expect(n.inputs.ckpt_name).toBeFalsy()
    expect(n.inputs.length).toEqual(inputCount)
  })

  test('converted widget works on clone', async () => {
    const { graph, ez } = await start()
    const n = ez.CheckpointLoaderSimple()

    // Convert the widget to an input
    n.widgets.ckpt_name.convertToInput()
    expect(n.widgets.ckpt_name.isConvertedToInput).toBeTruthy()

    // Clone the node
    n.menu['Clone'].call()
    expect(graph.nodes).toHaveLength(2)
    const clone = graph.nodes[1]
    expect(clone.id).not.toEqual(n.id)

    // Ensure the clone has an input
    expect(clone.widgets.ckpt_name.isConvertedToInput).toBeTruthy()
    expect(clone.inputs.ckpt_name).toBeTruthy()

    // Ensure primitive connects to both nodes
    const primitive = ez.PrimitiveNode()
    primitive.outputs[0].connectTo(n.inputs.ckpt_name)
    primitive.outputs[0].connectTo(clone.inputs.ckpt_name)
    expect(primitive.outputs[0].connections).toHaveLength(2)

    // Convert back to widget and ensure input is removed
    clone.widgets.ckpt_name.convertToWidget()
    expect(clone.widgets.ckpt_name.isConvertedToInput).toBeFalsy()
    expect(clone.inputs.ckpt_name).toBeFalsy()
  })

  // Invalid workflow against zod schema now.
  test.skip('shows missing node error on custom node with converted input', async () => {
    const { graph } = await start()

    const dialogShow = jest.spyOn(graph.app.ui.dialog, 'show')

    await graph.app.loadGraphData({
      last_node_id: 3,
      last_link_id: 4,
      nodes: [
        {
          id: 1,
          type: 'TestNode',
          pos: [41.87329101561909, 389.7381480823742],
          size: { 0: 220, 1: 374 },
          flags: {},
          order: 1,
          mode: 0,
          inputs: [
            {
              name: 'test',
              type: 'FLOAT',
              link: 4,
              widget: { name: 'test' },
              slot_index: 0
            }
          ],
          outputs: [],
          properties: { 'Node name for S&R': 'TestNode' },
          widgets_values: [1]
        },
        {
          id: 3,
          type: 'PrimitiveNode',
          pos: [-312, 433],
          size: { 0: 210, 1: 82 },
          flags: {},
          order: 0,
          mode: 0,
          // Missing name and type
          outputs: [{ links: [4], widget: { name: 'test' } }],
          title: 'test',
          properties: {}
        }
      ],
      links: [[4, 3, 0, 1, 6, 'FLOAT']],
      groups: [],
      config: {},
      extra: {},
      version: 0.4
    })

    expect(dialogShow).toHaveBeenCalledTimes(1)
    // @ts-expect-error
    expect(dialogShow.mock.calls[0][0].innerHTML).toContain(
      'the following node types were not found'
    )
    // @ts-expect-error
    expect(dialogShow.mock.calls[0][0].innerHTML).toContain('TestNode')
  })

  test('defaultInput widgets can be converted back to inputs', async () => {
    const { graph, ez } = await start({
      mockNodeDefs: makeNodeDef('TestNode', {
        example: ['INT', { defaultInput: true }]
      })
    })

    // Create test node and ensure it starts as an input
    let n = ez.TestNode()
    let w = n.widgets.example
    expect(w.isConvertedToInput).toBeTruthy()
    let input = w.getConvertedInput()
    expect(input).toBeTruthy()

    // Ensure it can be converted to
    w.convertToWidget()
    expect(w.isConvertedToInput).toBeFalsy()
    expect(n.inputs.length).toEqual(0)
    // and from
    w.convertToInput()
    expect(w.isConvertedToInput).toBeTruthy()
    input = w.getConvertedInput()

    // Reload and ensure it still only has 1 converted widget
    if (!assertNotNullOrUndefined(input)) return

    await connectPrimitiveAndReload(ez, graph, input, 'number', 1)
    n = graph.find(n)
    expect(n.widgets).toHaveLength(1)
    w = n.widgets.example
    expect(w.isConvertedToInput).toBeTruthy()

    // Convert back to widget and ensure it is still a widget after reload
    w.convertToWidget()
    await graph.reload()
    n = graph.find(n)
    expect(n.widgets).toHaveLength(1)
    expect(n.widgets[0].isConvertedToInput).toBeFalsy()
    expect(n.inputs.length).toEqual(0)
  })

  test('forceInput widgets can not be converted back to inputs', async () => {
    const { graph, ez } = await start({
      mockNodeDefs: makeNodeDef('TestNode', {
        example: ['INT', { forceInput: true }]
      })
    })

    // Create test node and ensure it starts as an input
    let n = ez.TestNode()
    const w = n.widgets.example
    expect(w.isConvertedToInput).toBeTruthy()
    const input = w.getConvertedInput()
    expect(input).toBeTruthy()

    // Convert to widget should error
    expect(() => w.convertToWidget()).toThrow()

    // Reload and ensure it still only has 1 converted widget
    if (assertNotNullOrUndefined(input)) {
      await connectPrimitiveAndReload(ez, graph, input, 'number', 1)
      n = graph.find(n)
      expect(n.widgets).toHaveLength(1)
      expect(n.widgets.example.isConvertedToInput).toBeTruthy()
    }
  })

  test('primitive can connect to matching combos on converted widgets', async () => {
    const { ez } = await start({
      mockNodeDefs: {
        ...makeNodeDef('TestNode1', {
          example: [['A', 'B', 'C'], { forceInput: true }]
        }),
        ...makeNodeDef('TestNode2', {
          example: [['A', 'B', 'C'], { forceInput: true }]
        })
      }
    })

    const n1 = ez.TestNode1()
    const n2 = ez.TestNode2()
    const p = ez.PrimitiveNode()
    p.outputs[0].connectTo(n1.inputs[0])
    p.outputs[0].connectTo(n2.inputs[0])
    expect(p.outputs[0].connections).toHaveLength(2)
    const valueWidget = p.widgets.value
    expect(valueWidget.widget.type).toBe('combo')
    expect(valueWidget.widget.options.values).toEqual(['A', 'B', 'C'])
  })

  test('primitive can not connect to non matching combos on converted widgets', async () => {
    const { ez } = await start({
      mockNodeDefs: {
        ...makeNodeDef('TestNode1', {
          example: [['A', 'B', 'C'], { forceInput: true }]
        }),
        ...makeNodeDef('TestNode2', {
          example: [['A', 'B'], { forceInput: true }]
        })
      }
    })

    const n1 = ez.TestNode1()
    const n2 = ez.TestNode2()
    const p = ez.PrimitiveNode()
    p.outputs[0].connectTo(n1.inputs[0])
    expect(() => p.outputs[0].connectTo(n2.inputs[0])).toThrow()
    expect(p.outputs[0].connections).toHaveLength(1)
  })

  test('combo output can not connect to non matching combos list input', async () => {
    const { ez } = await start({
      mockNodeDefs: {
        ...makeNodeDef('TestNode1', {}, [['A', 'B']]),
        ...makeNodeDef('TestNode2', {
          example: [['A', 'B'], { forceInput: true }]
        }),
        ...makeNodeDef('TestNode3', {
          example: [['A', 'B', 'C'], { forceInput: true }]
        })
      }
    })

    const n1 = ez.TestNode1()
    const n2 = ez.TestNode2()
    const n3 = ez.TestNode3()

    n1.outputs[0].connectTo(n2.inputs[0])
    expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow()
  })

  test('combo primitive can filter list when control_after_generate called', async () => {
    const { ez } = await start({
      mockNodeDefs: {
        ...makeNodeDef('TestNode1', {
          example: [
            ['A', 'B', 'C', 'D', 'AA', 'BB', 'CC', 'DD', 'AAA', 'BBB'],
            {}
          ]
        })
      }
    })

    const n1 = ez.TestNode1()
    n1.widgets.example.convertToInput()
    const p = ez.PrimitiveNode()
    p.outputs[0].connectTo(n1.inputs[0])

    const value = p.widgets.value
    const control = p.widgets.control_after_generate.widget
    const filter = p.widgets.control_filter_list

    expect(p.widgets.length).toBe(3)
    control.value = 'increment'
    expect(value.value).toBe('A')

    // Manually trigger after queue when set to increment
    control['afterQueued']()
    expect(value.value).toBe('B')

    // Filter to items containing D
    filter.value = 'D'
    control['afterQueued']()
    expect(value.value).toBe('D')
    control['afterQueued']()
    expect(value.value).toBe('DD')

    // Check decrement
    value.value = 'BBB'
    control.value = 'decrement'
    filter.value = 'B'
    control['afterQueued']()
    expect(value.value).toBe('BB')
    control['afterQueued']()
    expect(value.value).toBe('B')

    // Check regex works
    value.value = 'BBB'
    filter.value = '/[AB]|^C$/'
    control['afterQueued']()
    expect(value.value).toBe('AAA')
    control['afterQueued']()
    expect(value.value).toBe('BB')
    control['afterQueued']()
    expect(value.value).toBe('AA')
    control['afterQueued']()
    expect(value.value).toBe('C')
    control['afterQueued']()
    expect(value.value).toBe('B')
    control['afterQueued']()
    expect(value.value).toBe('A')

    // Check random
    control.value = 'randomize'
    filter.value = '/D/'
    for (let i = 0; i < 100; i++) {
      control['afterQueued']()
      expect(value.value === 'D' || value.value === 'DD').toBeTruthy()
    }

    // Ensure it doesn't apply when fixed
    control.value = 'fixed'
    value.value = 'B'
    filter.value = 'C'
    control['afterQueued']()
    expect(value.value).toBe('B')
  })

  describe('reroutes', () => {
    /**
     * Asserts that `graph.toPrompt()` produces the expected default-workflow
     * output, with selected inputs overridden.
     *
     * @param graph Graph whose prompt output is checked.
     * @param values Optional overrides; recognised keys are `width`, `height`,
     *   `batch_size`, `scheduler` and `filename_prefix`. Missing (null or
     *   undefined) keys fall back to the defaults written below.
     */
    async function checkOutput(graph, values = {}) {
      expect((await graph.toPrompt()).output).toStrictEqual({
        1: {
          inputs: { ckpt_name: 'model1.safetensors' },
          class_type: 'CheckpointLoaderSimple'
        },
        2: {
          inputs: { text: 'positive', clip: ['1', 1] },
          class_type: 'CLIPTextEncode'
        },
        3: {
          inputs: { text: 'negative', clip: ['1', 1] },
          class_type: 'CLIPTextEncode'
        },
        4: {
          inputs: {
            width: values.width ?? 512,
            height: values.height ?? 512,
            batch_size: values.batch_size ?? 1
          },
          class_type: 'EmptyLatentImage'
        },
        5: {
          inputs: {
            seed: 0,
            steps: 20,
            cfg: 8,
            sampler_name: 'euler',
            scheduler: values.scheduler ?? 'normal',
            denoise: 1,
            model: ['1', 0],
            positive: ['2', 0],
            negative: ['3', 0],
            latent_image: ['4', 0]
          },
          class_type: 'KSampler'
        },
        6: {
          inputs: { samples: ['5', 0], vae: ['1', 2] },
          class_type: 'VAEDecode'
        },
        7: {
          inputs: {
            filename_prefix: values.filename_prefix ?? 'ComfyUI',
            images: ['6', 0]
          },
          class_type: 'SaveImage'
        }
      })
    }

    /**
     * Polls until `node.widgets.value` is truthy, checking up to 10 times
     * with a 10ms delay before each check. Resolves without error if the
     * widget never appears; callers assert on the widgets afterwards.
     *
     * @param node Node wrapper whose `widgets.value` is awaited.
     */
    async function waitForWidget(node) {
      // widgets are created slightly after the graph is ready
      // hard to find an exact hook to get these so just wait for them to be ready
      for (let i = 0; i < 10; i++) {
        await new Promise((r) => setTimeout(r, 10))
        if (node.widgets?.value) {
          return
        }
      }
    }

    it('can connect primitive via a reroute path to a widget input', async () => {
      const { ez, graph } = await start()
      const nodes = createDefaultWorkflow(ez, graph)

      nodes.empty.widgets.width.convertToInput()
      nodes.sampler.widgets.scheduler.convertToInput()
      nodes.save.widgets.filename_prefix.convertToInput()

      const widthReroute = ez.Reroute()
      const schedulerReroute = ez.Reroute()
      const fileReroute = ez.Reroute()

      let widthNext = widthReroute
      let schedulerNext = schedulerReroute
      let fileNext = fileReroute

      // Extend each of the three chains by five more reroutes
      for (let i = 0; i < 5; i++) {
        let next = ez.Reroute()
        widthNext.outputs[0].connectTo(next.inputs[0])
        widthNext = next

        next = ez.Reroute()
        schedulerNext.outputs[0].connectTo(next.inputs[0])
        schedulerNext = next

        next = ez.Reroute()
        fileNext.outputs[0].connectTo(next.inputs[0])
        fileNext = next
      }

      widthNext.outputs[0].connectTo(nodes.empty.inputs.width)
      schedulerNext.outputs[0].connectTo(nodes.sampler.inputs.scheduler)
      fileNext.outputs[0].connectTo(nodes.save.inputs.filename_prefix)

      let widthPrimitive = ez.PrimitiveNode()
      let schedulerPrimitive = ez.PrimitiveNode()
      let filePrimitive = ez.PrimitiveNode()

      widthPrimitive.outputs[0].connectTo(widthReroute.inputs[0])
      schedulerPrimitive.outputs[0].connectTo(schedulerReroute.inputs[0])
      filePrimitive.outputs[0].connectTo(fileReroute.inputs[0])
      expect(widthPrimitive.widgets.value.value).toBe(512)
      widthPrimitive.widgets.value.value = 1024
      expect(schedulerPrimitive.widgets.value.value).toBe('normal')
      schedulerPrimitive.widgets.value.value = 'simple'
      expect(filePrimitive.widgets.value.value).toBe('ComfyUI')
      filePrimitive.widgets.value.value = 'ComfyTest'

      await checkBeforeAndAfterReload(graph, async () => {
        widthPrimitive = graph.find(widthPrimitive)
        schedulerPrimitive = graph.find(schedulerPrimitive)
        filePrimitive = graph.find(filePrimitive)
        await waitForWidget(filePrimitive)
        expect(widthPrimitive.widgets.length).toBe(2)
        expect(schedulerPrimitive.widgets.length).toBe(3)
        expect(filePrimitive.widgets.length).toBe(1)

        await checkOutput(graph, {
          width: 1024,
          scheduler: 'simple',
          filename_prefix: 'ComfyTest'
        })
      })
    })

    it('can connect primitive via a reroute path to multiple widget inputs', async () => {
      const { ez, graph } = await start()
      const nodes = createDefaultWorkflow(ez, graph)

      nodes.empty.widgets.width.convertToInput()
      nodes.empty.widgets.height.convertToInput()
      nodes.empty.widgets.batch_size.convertToInput()

      const reroute = ez.Reroute()
      let prevReroute = reroute
      for (let i = 0; i < 5; i++) {
        const next = ez.Reroute()
        prevReroute.outputs[0].connectTo(next.inputs[0])
        prevReroute = next
      }

      // Fan out from the end of the chain: r1 and r2 branch off it,
      // r3 and r4 branch off r2
      const r1 = ez.Reroute(prevReroute.outputs[0])
      const r2 = ez.Reroute(prevReroute.outputs[0])
      const r3 = ez.Reroute(r2.outputs[0])
      const r4 = ez.Reroute(r2.outputs[0])

      r1.outputs[0].connectTo(nodes.empty.inputs.width)
      r3.outputs[0].connectTo(nodes.empty.inputs.height)
      r4.outputs[0].connectTo(nodes.empty.inputs.batch_size)

      let primitive = ez.PrimitiveNode()
      primitive.outputs[0].connectTo(reroute.inputs[0])
      expect(primitive.widgets.value.value).toBe(1)
      primitive.widgets.value.value = 64

      await checkBeforeAndAfterReload(graph, async () => {
        primitive = graph.find(primitive)
        await waitForWidget(primitive)

        // Ensure widget configs are merged
        expect(primitive.widgets.value.widget.options?.min).toBe(16) // width/height min
        expect(primitive.widgets.value.widget.options?.max).toBe(4096) // batch max
        expect(primitive.widgets.value.widget.options?.step).toBe(80) // width/height step * 10

        await checkOutput(graph, {
          width: 64,
          height: 64,
          batch_size: 64
        })
      })
    })

    it('primitive maintains size on reload', async () => {
      const { ez, graph } = await start()
      const primitive = ez.PrimitiveNode()
      const image = ez.EmptyImage()

      image.widgets.width.convertToInput()
      primitive.outputs[0].connectTo(image.inputs.width)

      primitive.node.size = [999, 999]

      await checkBeforeAndAfterReload(graph, async () => {
        const { node } = graph.find(primitive)
        expect(node.size[0]).toBe(999)
        expect(node.size[1]).toBe(999)
      })
    })
  })
})