Manage widget definitions with Pinia store (#1510)

* Fix compile

* nit

* Remove extensions.test

* nit
This commit is contained in:
Chenlei Hu
2024-11-11 17:23:52 -05:00
committed by GitHub
parent 64ef0f18b1
commit 1ff6e27d9c
12 changed files with 111 additions and 269 deletions

View File

@@ -80,12 +80,13 @@ https://github.com/Nuked88/ComfyUI-N-Sidebar/blob/7ae7da4a9761009fb6629bc04c6830
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ComfyNodeDefImpl, useNodeDefStore } from '@/stores/nodeDefStore' import { ComfyNodeDefImpl } from '@/stores/nodeDefStore'
import { import {
getColorPalette, getColorPalette,
defaultColorPalette defaultColorPalette
} from '@/extensions/core/colorPalette' } from '@/extensions/core/colorPalette'
import _ from 'lodash' import _ from 'lodash'
import { useWidgetStore } from '@/stores/widgetStore'
const props = defineProps({ const props = defineProps({
nodeDef: { nodeDef: {
@@ -101,16 +102,16 @@ const props = defineProps({
const colors = getColorPalette()?.colors?.litegraph_base const colors = getColorPalette()?.colors?.litegraph_base
const litegraphColors = colors ?? defaultColorPalette.colors.litegraph_base const litegraphColors = colors ?? defaultColorPalette.colors.litegraph_base
const nodeDefStore = useNodeDefStore() const widgetStore = useWidgetStore()
const nodeDef = props.nodeDef const nodeDef = props.nodeDef
const allInputDefs = nodeDef.input.all const allInputDefs = nodeDef.input.all
const allOutputDefs = nodeDef.output.all const allOutputDefs = nodeDef.output.all
const slotInputDefs = allInputDefs.filter( const slotInputDefs = allInputDefs.filter(
(input) => !nodeDefStore.inputIsWidget(input) (input) => !widgetStore.inputIsWidget(input)
) )
const widgetInputDefs = allInputDefs.filter((input) => const widgetInputDefs = allInputDefs.filter((input) =>
nodeDefStore.inputIsWidget(input) widgetStore.inputIsWidget(input)
) )
const truncateDefaultValue = (value: any, charLimit: number = 32): string => { const truncateDefaultValue = (value: any, charLimit: number = 32): string => {
let stringValue: string let stringValue: string

View File

@@ -5,6 +5,7 @@ import type { IWidget } from '@comfyorg/litegraph'
import type { DOMWidget } from '@/scripts/domWidget' import type { DOMWidget } from '@/scripts/domWidget'
import { ComfyNodeDef } from '@/types/apiTypes' import { ComfyNodeDef } from '@/types/apiTypes'
import { useToastStore } from '@/stores/toastStore' import { useToastStore } from '@/stores/toastStore'
import { Widgets } from '@/types/comfy'
type FolderType = 'input' | 'output' | 'temp' type FolderType = 'input' | 'output' | 'temp'
@@ -107,7 +108,7 @@ app.registerExtension({
audioUIWidget.element.classList.add('empty-audio-widget') audioUIWidget.element.classList.add('empty-audio-widget')
// Populate the audio widget UI on node execution. // Populate the audio widget UI on node execution.
const onExecuted = node.onExecuted const onExecuted = node.onExecuted
node.onExecuted = function (message) { node.onExecuted = function (message: any) {
onExecuted?.apply(this, arguments) onExecuted?.apply(this, arguments)
const audios = message.audio const audios = message.audio
if (!audios) return if (!audios) return
@@ -120,7 +121,7 @@ app.registerExtension({
} }
return { widget: audioUIWidget } return { widget: audioUIWidget }
} }
} } as Widgets
}, },
onNodeOutputsUpdated(nodeOutputs: Record<number, any>) { onNodeOutputsUpdated(nodeOutputs: Record<number, any>) {
for (const [nodeId, output] of Object.entries(nodeOutputs)) { for (const [nodeId, output] of Object.entries(nodeOutputs)) {
@@ -153,9 +154,9 @@ app.registerExtension({
const audioWidget: IWidget = node.widgets.find( const audioWidget: IWidget = node.widgets.find(
(w: IWidget) => w.name === 'audio' (w: IWidget) => w.name === 'audio'
) )
const audioUIWidget: DOMWidget<HTMLAudioElement> = node.widgets.find( const audioUIWidget = node.widgets.find(
(w: IWidget) => w.name === 'audioUI' (w: IWidget) => w.name === 'audioUI'
) ) as DOMWidget<HTMLAudioElement>
const onAudioWidgetUpdate = () => { const onAudioWidgetUpdate = () => {
audioUIWidget.element.src = api.apiURL( audioUIWidget.element.src = api.apiURL(

View File

@@ -1,6 +1,10 @@
// @ts-strict-ignore // @ts-strict-ignore
import { ComfyLogging } from './logging' import { ComfyLogging } from './logging'
import { ComfyWidgetConstructor, ComfyWidgets, initWidgets } from './widgets' import {
type ComfyWidgetConstructor,
ComfyWidgets,
initWidgets
} from './widgets'
import { ComfyUI, $el } from './ui' import { ComfyUI, $el } from './ui'
import { api } from './api' import { api } from './api'
import { defaultGraph } from './defaultGraph' import { defaultGraph } from './defaultGraph'
@@ -60,6 +64,7 @@ import { useCommandStore } from '@/stores/commandStore'
import { shallowReactive } from 'vue' import { shallowReactive } from 'vue'
import { type IBaseWidget } from '@comfyorg/litegraph/dist/types/widgets' import { type IBaseWidget } from '@comfyorg/litegraph/dist/types/widgets'
import { workflowService } from '@/services/workflowService' import { workflowService } from '@/services/workflowService'
import { useWidgetStore } from '@/stores/widgetStore'
export const ANIM_PREVIEW_WIDGET = '$$comfy_animation_preview' export const ANIM_PREVIEW_WIDGET = '$$comfy_animation_preview'
@@ -142,7 +147,6 @@ export class ComfyApp {
storageLocation: StorageLocation storageLocation: StorageLocation
multiUserServer: boolean multiUserServer: boolean
ctx: CanvasRenderingContext2D ctx: CanvasRenderingContext2D
widgets: Record<string, ComfyWidgetConstructor>
bodyTop: HTMLElement bodyTop: HTMLElement
bodyLeft: HTMLElement bodyLeft: HTMLElement
bodyRight: HTMLElement bodyRight: HTMLElement
@@ -167,6 +171,16 @@ export class ComfyApp {
return useWorkspaceStore().shiftDown return useWorkspaceStore().shiftDown
} }
/**
* @deprecated Use useWidgetStore().widgets instead
*/
get widgets(): Record<string, ComfyWidgetConstructor> {
if (this.vueAppReady) {
return useWidgetStore().widgets
}
return ComfyWidgets
}
constructor() { constructor() {
this.vueAppReady = false this.vueAppReady = false
this.ui = new ComfyUI(this) this.ui = new ComfyUI(this)
@@ -1551,10 +1565,7 @@ export class ComfyApp {
} }
const node = this.graph.getNodeById(detail.display_node || detail.node) const node = this.graph.getNodeById(detail.display_node || detail.node)
if (node) { if (node) {
// @ts-expect-error if (node.onExecuted) node.onExecuted(detail.output)
if (node.onExecuted)
// @ts-expect-error
node.onExecuted(detail.output)
} }
}) })
@@ -1653,7 +1664,6 @@ export class ComfyApp {
app.graph.onConfigure = function () { app.graph.onConfigure = function () {
// Fire callbacks before the onConfigure, this is used by widget inputs to setup the config // Fire callbacks before the onConfigure, this is used by widget inputs to setup the config
for (const node of app.graph.nodes) { for (const node of app.graph.nodes) {
// @ts-expect-error
node.onGraphConfigured?.() node.onGraphConfigured?.()
} }
@@ -1920,7 +1930,6 @@ export class ComfyApp {
const nodeDefArray: ComfyNodeDef[] = Object.values(allNodeDefs) const nodeDefArray: ComfyNodeDef[] = Object.values(allNodeDefs)
this.#invokeExtensions('beforeRegisterVueAppNodeDefs', nodeDefArray, this) this.#invokeExtensions('beforeRegisterVueAppNodeDefs', nodeDefArray, this)
nodeDefStore.updateNodeDefs(nodeDefArray) nodeDefStore.updateNodeDefs(nodeDefArray)
nodeDefStore.widgets = this.widgets
} }
/** /**
@@ -1938,7 +1947,11 @@ export class ComfyApp {
} }
} }
getWidgetType(inputData, inputName) { /**
* Remove the impl after groupNode jest tests are removed.
* @deprecated Use useWidgetStore().getWidgetType instead
*/
getWidgetType(inputData, inputName: string) {
const type = inputData[0] const type = inputData[0]
if (Array.isArray(type)) { if (Array.isArray(type)) {
@@ -2092,13 +2105,6 @@ export class ComfyApp {
async registerNodesFromDefs(defs: Record<string, ComfyNodeDef>) { async registerNodesFromDefs(defs: Record<string, ComfyNodeDef>) {
await this.#invokeExtensionsAsync('addCustomNodeDefs', defs) await this.#invokeExtensionsAsync('addCustomNodeDefs', defs)
// Generate list of known widgets
this.widgets = Object.assign(
{},
ComfyWidgets,
...(await this.#invokeExtensionsAsync('getCustomWidgets')).filter(Boolean)
)
// Register a node for each definition // Register a node for each definition
for (const nodeId in defs) { for (const nodeId in defs) {
this.registerNodeDef(nodeId, defs[nodeId]) this.registerNodeDef(nodeId, defs[nodeId])

View File

@@ -268,7 +268,7 @@ LGraphNode.prototype.addDOMWidget = function (
name: string, name: string,
type: string, type: string,
element: HTMLElement, element: HTMLElement,
options: Record<string, any> options: Record<string, any> = {}
): DOMWidget { ): DOMWidget {
options = { hideOnZoom: true, selectOn: ['focus', 'click'], ...options } options = { hideOnZoom: true, selectOn: ['focus', 'click'], ...options }

View File

@@ -2,10 +2,11 @@
import { api } from './api' import { api } from './api'
import './domWidget' import './domWidget'
import type { ComfyApp } from './app' import type { ComfyApp } from './app'
import type { IWidget, LGraphNode } from '@comfyorg/litegraph' import type { LGraphNode } from '@comfyorg/litegraph'
import { InputSpec } from '@/types/apiTypes' import { InputSpec } from '@/types/apiTypes'
import { useSettingStore } from '@/stores/settingStore' import { useSettingStore } from '@/stores/settingStore'
import { useToastStore } from '@/stores/toastStore' import { useToastStore } from '@/stores/toastStore'
import type { IWidget } from '@comfyorg/litegraph'
export type ComfyWidgetConstructor = ( export type ComfyWidgetConstructor = (
node: LGraphNode, node: LGraphNode,

View File

@@ -7,6 +7,7 @@ import { useSettingStore } from './settingStore'
import { app } from '@/scripts/app' import { app } from '@/scripts/app'
import { useMenuItemStore } from './menuItemStore' import { useMenuItemStore } from './menuItemStore'
import { useBottomPanelStore } from './workspace/bottomPanelStore' import { useBottomPanelStore } from './workspace/bottomPanelStore'
import { useWidgetStore } from './widgetStore'
export const useExtensionStore = defineStore('extension', () => { export const useExtensionStore = defineStore('extension', () => {
// For legacy reasons, the name uniquely identifies an extension // For legacy reasons, the name uniquely identifies an extension
@@ -50,6 +51,16 @@ export const useExtensionStore = defineStore('extension', () => {
useMenuItemStore().loadExtensionMenuCommands(extension) useMenuItemStore().loadExtensionMenuCommands(extension)
useSettingStore().loadExtensionSettings(extension) useSettingStore().loadExtensionSettings(extension)
useBottomPanelStore().registerExtensionBottomPanelTabs(extension) useBottomPanelStore().registerExtensionBottomPanelTabs(extension)
if (extension.getCustomWidgets) {
// TODO(huchenlei): We should deprecate the async return value of
// getCustomWidgets.
;(async () => {
if (extension.getCustomWidgets) {
const widgets = await extension.getCustomWidgets(app)
useWidgetStore().registerCustomWidgets(widgets)
}
})()
}
/* /*
* Extensions are currently stored in both extensionStore and app.extensions. * Extensions are currently stored in both extensionStore and app.extensions.
* Legacy jest tests still depend on app.extensions being populated. * Legacy jest tests still depend on app.extensions being populated.

View File

@@ -323,20 +323,6 @@ export const useNodeDefStore = defineStore('nodeDef', () => {
nodeDefsByName.value[nodeDef.name] = nodeDefImpl nodeDefsByName.value[nodeDef.name] = nodeDefImpl
nodeDefsByDisplayName.value[nodeDef.display_name] = nodeDefImpl nodeDefsByDisplayName.value[nodeDef.display_name] = nodeDefImpl
} }
function getWidgetType(type: string, inputName: string) {
if (type === 'COMBO') {
return 'COMBO'
} else if (`${type}:${inputName}` in widgets.value) {
return `${type}:${inputName}`
} else if (type in widgets.value) {
return type
} else {
return null
}
}
function inputIsWidget(spec: BaseInputSpec) {
return getWidgetType(spec.type, spec.name) !== null
}
function fromLGraphNode(node: LGraphNode): ComfyNodeDefImpl | null { function fromLGraphNode(node: LGraphNode): ComfyNodeDefImpl | null {
// Frontend-only nodes don't have nodeDef // Frontend-only nodes don't have nodeDef
return nodeDefsByName.value[node.constructor?.nodeData?.name] ?? null return nodeDefsByName.value[node.constructor?.nodeData?.name] ?? null
@@ -345,7 +331,6 @@ export const useNodeDefStore = defineStore('nodeDef', () => {
return { return {
nodeDefsByName, nodeDefsByName,
nodeDefsByDisplayName, nodeDefsByDisplayName,
widgets,
showDeprecated, showDeprecated,
showExperimental, showExperimental,
@@ -356,8 +341,6 @@ export const useNodeDefStore = defineStore('nodeDef', () => {
updateNodeDefs, updateNodeDefs,
addNodeDef, addNodeDef,
getWidgetType,
inputIsWidget,
fromLGraphNode fromLGraphNode
} }
}) })

45
src/stores/widgetStore.ts Normal file
View File

@@ -0,0 +1,45 @@
import { ComfyWidgets, ComfyWidgetConstructor } from '@/scripts/widgets'
import { defineStore } from 'pinia'
import { ref, computed } from 'vue'
import type { BaseInputSpec } from './nodeDefStore'
export const useWidgetStore = defineStore('widget', () => {
const coreWidgets = ComfyWidgets
const customWidgets = ref<Record<string, ComfyWidgetConstructor>>({})
const widgets = computed(() => ({
...customWidgets.value,
...coreWidgets
}))
function getWidgetType(type: string, inputName: string) {
if (type === 'COMBO') {
return 'COMBO'
} else if (`${type}:${inputName}` in widgets.value) {
return `${type}:${inputName}`
} else if (type in widgets.value) {
return type
} else {
return null
}
}
function inputIsWidget(spec: BaseInputSpec) {
return getWidgetType(spec.type, spec.name) !== null
}
function registerCustomWidgets(
newWidgets: Record<string, ComfyWidgetConstructor>
) {
customWidgets.value = {
...customWidgets.value,
...newWidgets
}
}
return {
widgets,
getWidgetType,
inputIsWidget,
registerCustomWidgets
}
})

19
src/types/comfy.d.ts vendored
View File

@@ -1,20 +1,13 @@
import { LGraphNode, IWidget } from './litegraph' import type { LGraphNode } from './litegraph'
import { ComfyApp } from '../scripts/app' import type { ComfyApp } from '../scripts/app'
import type { ComfyNodeDef } from '@/types/apiTypes' import type { ComfyNodeDef } from '@/types/apiTypes'
import type { Keybinding } from '@/types/keyBindingTypes' import type { Keybinding } from '@/types/keyBindingTypes'
import type { ComfyCommand } from '@/stores/commandStore' import type { ComfyCommand } from '@/stores/commandStore'
import { SettingParams } from './settingTypes' import type { SettingParams } from './settingTypes'
import type { BottomPanelExtension } from './extensionTypes' import type { BottomPanelExtension } from './extensionTypes'
import type { ComfyWidgetConstructor } from '@/scripts/widgets'
export type Widgets = Record< export type Widgets = Record<string, ComfyWidgetConstructor>
string,
(
node,
inputName,
inputData,
app?: ComfyApp
) => { widget?: IWidget; minWidth?: number; minHeight?: number }
>
export interface AboutPageBadge { export interface AboutPageBadge {
label: string label: string
@@ -94,6 +87,8 @@ export interface ComfyExtension {
defs: Record<string, ComfyNodeDef>, defs: Record<string, ComfyNodeDef>,
app: ComfyApp app: ComfyApp
): Promise<void> | void ): Promise<void> | void
// TODO(huchenlei): We should deprecate the async return value of
// getCustomWidgets.
/** /**
* Allows the extension to add custom widgets * Allows the extension to add custom widgets
* @param app The ComfyUI app instance * @param app The ComfyUI app instance

View File

@@ -22,6 +22,8 @@ declare module '@comfyorg/litegraph' {
* Callback fired on each node after the graph is configured * Callback fired on each node after the graph is configured
*/ */
onAfterGraphConfigured?(): void onAfterGraphConfigured?(): void
onGraphConfigured?(): void
onExecuted?(output: any): void
onNodeCreated?(this: LGraphNode): void onNodeCreated?(this: LGraphNode): void
setInnerNodes?(nodes: LGraphNode[]): void setInnerNodes?(nodes: LGraphNode[]): void
applyToGraph?(extraLinks?: LLink[]): void applyToGraph?(extraLinks?: LLink[]): void
@@ -38,7 +40,7 @@ declare module '@comfyorg/litegraph' {
name: string, name: string,
type: string, type: string,
element: HTMLElement, element: HTMLElement,
options: Record<string, any> options?: Record<string, any>
): DOMWidget ): DOMWidget
} }

View File

@@ -69,6 +69,16 @@ module.exports = async function () {
} }
}) })
jest.mock('@/stores/widgetStore', () => {
const widgets = {}
return {
useWidgetStore: () => ({
widgets,
registerCustomWidgets: jest.fn()
})
}
})
jest.mock('vue-i18n', () => { jest.mock('vue-i18n', () => {
return { return {
useI18n: jest.fn() useI18n: jest.fn()

View File

@@ -1,213 +0,0 @@
// @ts-strict-ignore
import { start } from '../../utils'
import lg from '../../utils/litegraph'
describe('extensions', () => {
beforeEach(() => {
lg.setup(global)
})
afterEach(() => {
lg.teardown(global)
})
it('calls each extension hook', async () => {
const mockExtension = {
name: 'TestExtension',
init: jest.fn(),
setup: jest.fn(),
addCustomNodeDefs: jest.fn(),
getCustomWidgets: jest.fn(),
beforeRegisterNodeDef: jest.fn(),
registerCustomNodes: jest.fn(),
loadedGraphNode: jest.fn(),
nodeCreated: jest.fn(),
beforeConfigureGraph: jest.fn(),
afterConfigureGraph: jest.fn()
}
const { app, ez, graph } = await start({
async preSetup(app) {
app.registerExtension(mockExtension)
}
})
// Basic initialisation hooks should be called once, with app
expect(mockExtension.init).toHaveBeenCalledTimes(1)
expect(mockExtension.init).toHaveBeenCalledWith(app)
// Adding custom node defs should be passed the full list of nodes
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1)
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app)
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0]
expect(defs).toHaveProperty('KSampler')
expect(defs).toHaveProperty('LoadImage')
// Get custom widgets is called once and should return new widget types
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1)
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app)
// Before register node def will be called once per node type
const nodeNames = Object.keys(defs)
const nodeCount = nodeNames.length
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount)
for (let i = 0; i < 10; i++) {
// It should be send the JS class and the original JSON definition
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0]
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1]
expect(nodeClass.name).toBe('ComfyNode')
expect(nodeClass.comfyClass).toBe(nodeNames[i])
expect(nodeDef.name).toBe(nodeNames[i])
expect(nodeDef).toHaveProperty('input')
expect(nodeDef).toHaveProperty('output')
}
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1)
// Before configure graph will be called here as the default graph is being loaded
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1)
// it gets sent the graph data that is going to be loaded
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0]
// A node created is fired for each node constructor that is called
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(
graphData.nodes.length
)
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(
graphData.nodes[i].type
)
}
// Each node then calls loadedGraphNode to allow them to be updated
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(
graphData.nodes.length
)
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(
graphData.nodes[i].type
)
}
// After configure is then called once all the setup is done
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1)
expect(mockExtension.setup).toHaveBeenCalledTimes(1)
expect(mockExtension.setup).toHaveBeenCalledWith(app)
// Ensure hooks are called in the correct order
const callOrder = [
'init',
'addCustomNodeDefs',
'getCustomWidgets',
'beforeRegisterNodeDef',
'registerCustomNodes',
'beforeConfigureGraph',
'nodeCreated',
'loadedGraphNode',
'afterConfigureGraph',
'setup'
]
for (let i = 1; i < callOrder.length; i++) {
const fn1 = mockExtension[callOrder[i - 1]]
const fn2 = mockExtension[callOrder[i]]
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(
fn2.mock.invocationCallOrder[0]
)
}
graph.clear()
// Ensure adding a new node calls the correct callback
ez.LoadImage()
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(
graphData.nodes.length
)
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(
graphData.nodes.length + 1
)
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe('LoadImage')
// Reload the graph to ensure correct hooks are fired
await graph.reload()
// These hooks should not be fired again
expect(mockExtension.init).toHaveBeenCalledTimes(1)
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1)
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1)
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1)
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount)
expect(mockExtension.setup).toHaveBeenCalledTimes(1)
// These should be called again
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2)
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(
graphData.nodes.length + 2
)
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(
graphData.nodes.length + 1
)
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2)
}, 15000)
it('allows custom nodeDefs and widgets to be registered', async () => {
const widgetMock = jest.fn((node, inputName, inputData, app) => {
expect(node.constructor.comfyClass).toBe('TestNode')
expect(inputName).toBe('test_input')
expect(inputData[0]).toBe('CUSTOMWIDGET')
expect(inputData[1]?.hello).toBe('world')
expect(app).toStrictEqual(app)
return {
widget: node.addWidget('button', inputName, 'hello', () => {})
}
})
// Register our extension that adds a custom node + widget type
const mockExtension = {
name: 'TestExtension',
addCustomNodeDefs: (nodeDefs) => {
nodeDefs['TestNode'] = {
output: [],
output_name: [],
output_is_list: [],
name: 'TestNode',
display_name: 'TestNode',
category: 'Test',
input: {
required: {
test_input: ['CUSTOMWIDGET', { hello: 'world' }]
}
}
}
},
getCustomWidgets: jest.fn(() => {
return {
CUSTOMWIDGET: widgetMock
}
})
}
const { graph, ez } = await start({
async preSetup(app) {
app.registerExtension(mockExtension)
}
})
expect(mockExtension.getCustomWidgets).toBeCalledTimes(1)
graph.clear()
expect(widgetMock).toBeCalledTimes(0)
const node = ez.TestNode()
expect(widgetMock).toBeCalledTimes(1)
// Ensure our custom widget is created
expect(node.inputs.length).toBe(0)
expect(node.widgets.length).toBe(1)
const w = node.widgets[0].widget
expect(w.name).toBe('test_input')
expect(w.type).toBe('button')
})
})