import { start } from '../utils'
import lg from '../utils/litegraph'

/**
 * Tests for the extension hook lifecycle.
 *
 * Each test registers a mock extension from the `preSetup` callback given to
 * `start()` and then asserts how often, with which arguments, and in which
 * order the extension hooks were invoked.
 */
describe('extensions', () => {
  beforeEach(() => {
    lg.setup(global)
  })

  afterEach(() => {
    lg.teardown(global)
  })

  it('calls each extension hook', async () => {
    const mockExtension = {
      name: 'TestExtension',
      init: jest.fn(),
      setup: jest.fn(),
      addCustomNodeDefs: jest.fn(),
      getCustomWidgets: jest.fn(),
      beforeRegisterNodeDef: jest.fn(),
      registerCustomNodes: jest.fn(),
      loadedGraphNode: jest.fn(),
      nodeCreated: jest.fn(),
      beforeConfigureGraph: jest.fn(),
      afterConfigureGraph: jest.fn()
    }

    const { app, ez, graph } = await start({
      async preSetup(app) {
        app.registerExtension(mockExtension)
      }
    })

    // Basic initialisation hooks should be called once, with app
    expect(mockExtension.init).toHaveBeenCalledTimes(1)
    expect(mockExtension.init).toHaveBeenCalledWith(app)

    // Adding custom node defs should be passed the full list of nodes
    expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1)
    expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app)
    const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0]
    expect(defs).toHaveProperty('KSampler')
    expect(defs).toHaveProperty('LoadImage')

    // Get custom widgets is called once and should return new widget types
    expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1)
    expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app)

    // Before register node def will be called once per node type
    const nodeNames = Object.keys(defs)
    const nodeCount = nodeNames.length
    expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount)

    // Only the first registrations are inspected in detail. The bound is
    // capped at nodeCount so the loop never indexes past the recorded calls
    // (which would throw a TypeError rather than fail an assertion).
    const inspectedCount = Math.min(nodeCount, 10)
    for (let i = 0; i < inspectedCount; i++) {
      // It should be sent the JS class and the original JSON definition
      const [nodeClass, nodeDef] = mockExtension.beforeRegisterNodeDef.mock.calls[i]

      expect(nodeClass.name).toBe('ComfyNode')
      expect(nodeClass.comfyClass).toBe(nodeNames[i])
      expect(nodeDef.name).toBe(nodeNames[i])
      expect(nodeDef).toHaveProperty('input')
      expect(nodeDef).toHaveProperty('output')
    }

    // Register custom nodes is called once after registerNode defs to allow
    // adding other frontend nodes
    expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1)

    // Before configure graph will be called here as the default graph is
    // being loaded
    expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1)
    // It gets sent the graph data that is going to be loaded
    const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0]

    // A node created is fired for each node constructor that is called
    expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(
      graphData.nodes.length
    )
    for (let i = 0; i < graphData.nodes.length; i++) {
      expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(
        graphData.nodes[i].type
      )
    }

    // Each node then calls loadedGraphNode to allow them to be updated
    expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(
      graphData.nodes.length
    )
    for (let i = 0; i < graphData.nodes.length; i++) {
      expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(
        graphData.nodes[i].type
      )
    }

    // After configure is then called once all the setup is done
    expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1)

    expect(mockExtension.setup).toHaveBeenCalledTimes(1)
    expect(mockExtension.setup).toHaveBeenCalledWith(app)

    // Ensure hooks are called in the correct order: the first invocation of
    // each hook must precede the first invocation of the next one
    const callOrder = [
      'init',
      'addCustomNodeDefs',
      'getCustomWidgets',
      'beforeRegisterNodeDef',
      'registerCustomNodes',
      'beforeConfigureGraph',
      'nodeCreated',
      'loadedGraphNode',
      'afterConfigureGraph',
      'setup'
    ]
    for (let i = 1; i < callOrder.length; i++) {
      const fn1 = mockExtension[callOrder[i - 1]]
      const fn2 = mockExtension[callOrder[i]]
      expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(
        fn2.mock.invocationCallOrder[0]
      )
    }

    graph.clear()

    // Ensure adding a new node calls the correct callback: nodeCreated fires,
    // loadedGraphNode does not
    ez.LoadImage()
    expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(
      graphData.nodes.length
    )
    expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(
      graphData.nodes.length + 1
    )
    expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe('LoadImage')

    // Reload the graph to ensure correct hooks are fired
    await graph.reload()

    // These hooks should not be fired again
    expect(mockExtension.init).toHaveBeenCalledTimes(1)
    expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1)
    expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1)
    expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1)
    expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount)
    expect(mockExtension.setup).toHaveBeenCalledTimes(1)

    // These should be called again
    expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2)
    expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(
      graphData.nodes.length + 2
    )
    expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(
      graphData.nodes.length + 1
    )
    expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2)
  }, 15000)

  it('allows custom nodeDefs and widgets to be registered', async () => {
    // The app instance handed to preSetup; the widget factory must receive
    // this same instance as its fourth argument.
    let registeredApp

    const widgetMock = jest.fn((node, inputName, inputData, widgetApp) => {
      expect(node.constructor.comfyClass).toBe('TestNode')
      expect(inputName).toBe('test_input')
      expect(inputData[0]).toBe('CUSTOMWIDGET')
      expect(inputData[1]?.hello).toBe('world')
      // Compare against the captured instance; comparing the parameter with
      // itself would always pass.
      expect(widgetApp).toBe(registeredApp)

      return {
        widget: node.addWidget('button', inputName, 'hello', () => {})
      }
    })

    // Register our extension that adds a custom node + widget type
    const mockExtension = {
      name: 'TestExtension',
      addCustomNodeDefs: (nodeDefs) => {
        nodeDefs['TestNode'] = {
          output: [],
          output_name: [],
          output_is_list: [],
          name: 'TestNode',
          display_name: 'TestNode',
          category: 'Test',
          input: {
            required: {
              test_input: ['CUSTOMWIDGET', { hello: 'world' }]
            }
          }
        }
      },
      getCustomWidgets: jest.fn(() => {
        return {
          CUSTOMWIDGET: widgetMock
        }
      })
    }

    const { graph, ez } = await start({
      async preSetup(app) {
        registeredApp = app
        app.registerExtension(mockExtension)
      }
    })

    expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1)

    graph.clear()
    // The widget factory must not run until a node using it is created
    expect(widgetMock).toHaveBeenCalledTimes(0)
    const node = ez.TestNode()
    expect(widgetMock).toHaveBeenCalledTimes(1)

    // Ensure our custom widget is created
    expect(node.inputs.length).toBe(0)
    expect(node.widgets.length).toBe(1)
    const w = node.widgets[0].widget
    expect(w.name).toBe('test_input')
    expect(w.type).toBe('button')
  })
})