mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-02-07 08:30:06 +00:00
Manage widget definitions with Pinia store (#1510)
* Fix compile * nit * Remove extensions.test * nit
This commit is contained in:
@@ -80,12 +80,13 @@ https://github.com/Nuked88/ComfyUI-N-Sidebar/blob/7ae7da4a9761009fb6629bc04c6830
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ComfyNodeDefImpl, useNodeDefStore } from '@/stores/nodeDefStore'
|
||||
import { ComfyNodeDefImpl } from '@/stores/nodeDefStore'
|
||||
import {
|
||||
getColorPalette,
|
||||
defaultColorPalette
|
||||
} from '@/extensions/core/colorPalette'
|
||||
import _ from 'lodash'
|
||||
import { useWidgetStore } from '@/stores/widgetStore'
|
||||
|
||||
const props = defineProps({
|
||||
nodeDef: {
|
||||
@@ -101,16 +102,16 @@ const props = defineProps({
|
||||
const colors = getColorPalette()?.colors?.litegraph_base
|
||||
const litegraphColors = colors ?? defaultColorPalette.colors.litegraph_base
|
||||
|
||||
const nodeDefStore = useNodeDefStore()
|
||||
const widgetStore = useWidgetStore()
|
||||
|
||||
const nodeDef = props.nodeDef
|
||||
const allInputDefs = nodeDef.input.all
|
||||
const allOutputDefs = nodeDef.output.all
|
||||
const slotInputDefs = allInputDefs.filter(
|
||||
(input) => !nodeDefStore.inputIsWidget(input)
|
||||
(input) => !widgetStore.inputIsWidget(input)
|
||||
)
|
||||
const widgetInputDefs = allInputDefs.filter((input) =>
|
||||
nodeDefStore.inputIsWidget(input)
|
||||
widgetStore.inputIsWidget(input)
|
||||
)
|
||||
const truncateDefaultValue = (value: any, charLimit: number = 32): string => {
|
||||
let stringValue: string
|
||||
|
||||
@@ -5,6 +5,7 @@ import type { IWidget } from '@comfyorg/litegraph'
|
||||
import type { DOMWidget } from '@/scripts/domWidget'
|
||||
import { ComfyNodeDef } from '@/types/apiTypes'
|
||||
import { useToastStore } from '@/stores/toastStore'
|
||||
import { Widgets } from '@/types/comfy'
|
||||
|
||||
type FolderType = 'input' | 'output' | 'temp'
|
||||
|
||||
@@ -107,7 +108,7 @@ app.registerExtension({
|
||||
audioUIWidget.element.classList.add('empty-audio-widget')
|
||||
// Populate the audio widget UI on node execution.
|
||||
const onExecuted = node.onExecuted
|
||||
node.onExecuted = function (message) {
|
||||
node.onExecuted = function (message: any) {
|
||||
onExecuted?.apply(this, arguments)
|
||||
const audios = message.audio
|
||||
if (!audios) return
|
||||
@@ -120,7 +121,7 @@ app.registerExtension({
|
||||
}
|
||||
return { widget: audioUIWidget }
|
||||
}
|
||||
}
|
||||
} as Widgets
|
||||
},
|
||||
onNodeOutputsUpdated(nodeOutputs: Record<number, any>) {
|
||||
for (const [nodeId, output] of Object.entries(nodeOutputs)) {
|
||||
@@ -153,9 +154,9 @@ app.registerExtension({
|
||||
const audioWidget: IWidget = node.widgets.find(
|
||||
(w: IWidget) => w.name === 'audio'
|
||||
)
|
||||
const audioUIWidget: DOMWidget<HTMLAudioElement> = node.widgets.find(
|
||||
const audioUIWidget = node.widgets.find(
|
||||
(w: IWidget) => w.name === 'audioUI'
|
||||
)
|
||||
) as DOMWidget<HTMLAudioElement>
|
||||
|
||||
const onAudioWidgetUpdate = () => {
|
||||
audioUIWidget.element.src = api.apiURL(
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
// @ts-strict-ignore
|
||||
import { ComfyLogging } from './logging'
|
||||
import { ComfyWidgetConstructor, ComfyWidgets, initWidgets } from './widgets'
|
||||
import {
|
||||
type ComfyWidgetConstructor,
|
||||
ComfyWidgets,
|
||||
initWidgets
|
||||
} from './widgets'
|
||||
import { ComfyUI, $el } from './ui'
|
||||
import { api } from './api'
|
||||
import { defaultGraph } from './defaultGraph'
|
||||
@@ -60,6 +64,7 @@ import { useCommandStore } from '@/stores/commandStore'
|
||||
import { shallowReactive } from 'vue'
|
||||
import { type IBaseWidget } from '@comfyorg/litegraph/dist/types/widgets'
|
||||
import { workflowService } from '@/services/workflowService'
|
||||
import { useWidgetStore } from '@/stores/widgetStore'
|
||||
|
||||
export const ANIM_PREVIEW_WIDGET = '$$comfy_animation_preview'
|
||||
|
||||
@@ -142,7 +147,6 @@ export class ComfyApp {
|
||||
storageLocation: StorageLocation
|
||||
multiUserServer: boolean
|
||||
ctx: CanvasRenderingContext2D
|
||||
widgets: Record<string, ComfyWidgetConstructor>
|
||||
bodyTop: HTMLElement
|
||||
bodyLeft: HTMLElement
|
||||
bodyRight: HTMLElement
|
||||
@@ -167,6 +171,16 @@ export class ComfyApp {
|
||||
return useWorkspaceStore().shiftDown
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use useWidgetStore().widgets instead
|
||||
*/
|
||||
get widgets(): Record<string, ComfyWidgetConstructor> {
|
||||
if (this.vueAppReady) {
|
||||
return useWidgetStore().widgets
|
||||
}
|
||||
return ComfyWidgets
|
||||
}
|
||||
|
||||
constructor() {
|
||||
this.vueAppReady = false
|
||||
this.ui = new ComfyUI(this)
|
||||
@@ -1551,10 +1565,7 @@ export class ComfyApp {
|
||||
}
|
||||
const node = this.graph.getNodeById(detail.display_node || detail.node)
|
||||
if (node) {
|
||||
// @ts-expect-error
|
||||
if (node.onExecuted)
|
||||
// @ts-expect-error
|
||||
node.onExecuted(detail.output)
|
||||
if (node.onExecuted) node.onExecuted(detail.output)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1653,7 +1664,6 @@ export class ComfyApp {
|
||||
app.graph.onConfigure = function () {
|
||||
// Fire callbacks before the onConfigure, this is used by widget inputs to setup the config
|
||||
for (const node of app.graph.nodes) {
|
||||
// @ts-expect-error
|
||||
node.onGraphConfigured?.()
|
||||
}
|
||||
|
||||
@@ -1920,7 +1930,6 @@ export class ComfyApp {
|
||||
const nodeDefArray: ComfyNodeDef[] = Object.values(allNodeDefs)
|
||||
this.#invokeExtensions('beforeRegisterVueAppNodeDefs', nodeDefArray, this)
|
||||
nodeDefStore.updateNodeDefs(nodeDefArray)
|
||||
nodeDefStore.widgets = this.widgets
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1938,7 +1947,11 @@ export class ComfyApp {
|
||||
}
|
||||
}
|
||||
|
||||
getWidgetType(inputData, inputName) {
|
||||
/**
|
||||
* Remove the impl after groupNode jest tests are removed.
|
||||
* @deprecated Use useWidgetStore().getWidgetType instead
|
||||
*/
|
||||
getWidgetType(inputData, inputName: string) {
|
||||
const type = inputData[0]
|
||||
|
||||
if (Array.isArray(type)) {
|
||||
@@ -2092,13 +2105,6 @@ export class ComfyApp {
|
||||
async registerNodesFromDefs(defs: Record<string, ComfyNodeDef>) {
|
||||
await this.#invokeExtensionsAsync('addCustomNodeDefs', defs)
|
||||
|
||||
// Generate list of known widgets
|
||||
this.widgets = Object.assign(
|
||||
{},
|
||||
ComfyWidgets,
|
||||
...(await this.#invokeExtensionsAsync('getCustomWidgets')).filter(Boolean)
|
||||
)
|
||||
|
||||
// Register a node for each definition
|
||||
for (const nodeId in defs) {
|
||||
this.registerNodeDef(nodeId, defs[nodeId])
|
||||
|
||||
@@ -268,7 +268,7 @@ LGraphNode.prototype.addDOMWidget = function (
|
||||
name: string,
|
||||
type: string,
|
||||
element: HTMLElement,
|
||||
options: Record<string, any>
|
||||
options: Record<string, any> = {}
|
||||
): DOMWidget {
|
||||
options = { hideOnZoom: true, selectOn: ['focus', 'click'], ...options }
|
||||
|
||||
|
||||
@@ -2,10 +2,11 @@
|
||||
import { api } from './api'
|
||||
import './domWidget'
|
||||
import type { ComfyApp } from './app'
|
||||
import type { IWidget, LGraphNode } from '@comfyorg/litegraph'
|
||||
import type { LGraphNode } from '@comfyorg/litegraph'
|
||||
import { InputSpec } from '@/types/apiTypes'
|
||||
import { useSettingStore } from '@/stores/settingStore'
|
||||
import { useToastStore } from '@/stores/toastStore'
|
||||
import type { IWidget } from '@comfyorg/litegraph'
|
||||
|
||||
export type ComfyWidgetConstructor = (
|
||||
node: LGraphNode,
|
||||
|
||||
@@ -7,6 +7,7 @@ import { useSettingStore } from './settingStore'
|
||||
import { app } from '@/scripts/app'
|
||||
import { useMenuItemStore } from './menuItemStore'
|
||||
import { useBottomPanelStore } from './workspace/bottomPanelStore'
|
||||
import { useWidgetStore } from './widgetStore'
|
||||
|
||||
export const useExtensionStore = defineStore('extension', () => {
|
||||
// For legacy reasons, the name uniquely identifies an extension
|
||||
@@ -50,6 +51,16 @@ export const useExtensionStore = defineStore('extension', () => {
|
||||
useMenuItemStore().loadExtensionMenuCommands(extension)
|
||||
useSettingStore().loadExtensionSettings(extension)
|
||||
useBottomPanelStore().registerExtensionBottomPanelTabs(extension)
|
||||
if (extension.getCustomWidgets) {
|
||||
// TODO(huchenlei): We should deprecate the async return value of
|
||||
// getCustomWidgets.
|
||||
;(async () => {
|
||||
if (extension.getCustomWidgets) {
|
||||
const widgets = await extension.getCustomWidgets(app)
|
||||
useWidgetStore().registerCustomWidgets(widgets)
|
||||
}
|
||||
})()
|
||||
}
|
||||
/*
|
||||
* Extensions are currently stored in both extensionStore and app.extensions.
|
||||
* Legacy jest tests still depend on app.extensions being populated.
|
||||
|
||||
@@ -323,20 +323,6 @@ export const useNodeDefStore = defineStore('nodeDef', () => {
|
||||
nodeDefsByName.value[nodeDef.name] = nodeDefImpl
|
||||
nodeDefsByDisplayName.value[nodeDef.display_name] = nodeDefImpl
|
||||
}
|
||||
function getWidgetType(type: string, inputName: string) {
|
||||
if (type === 'COMBO') {
|
||||
return 'COMBO'
|
||||
} else if (`${type}:${inputName}` in widgets.value) {
|
||||
return `${type}:${inputName}`
|
||||
} else if (type in widgets.value) {
|
||||
return type
|
||||
} else {
|
||||
return null
|
||||
}
|
||||
}
|
||||
function inputIsWidget(spec: BaseInputSpec) {
|
||||
return getWidgetType(spec.type, spec.name) !== null
|
||||
}
|
||||
function fromLGraphNode(node: LGraphNode): ComfyNodeDefImpl | null {
|
||||
// Frontend-only nodes don't have nodeDef
|
||||
return nodeDefsByName.value[node.constructor?.nodeData?.name] ?? null
|
||||
@@ -345,7 +331,6 @@ export const useNodeDefStore = defineStore('nodeDef', () => {
|
||||
return {
|
||||
nodeDefsByName,
|
||||
nodeDefsByDisplayName,
|
||||
widgets,
|
||||
showDeprecated,
|
||||
showExperimental,
|
||||
|
||||
@@ -356,8 +341,6 @@ export const useNodeDefStore = defineStore('nodeDef', () => {
|
||||
|
||||
updateNodeDefs,
|
||||
addNodeDef,
|
||||
getWidgetType,
|
||||
inputIsWidget,
|
||||
fromLGraphNode
|
||||
}
|
||||
})
|
||||
|
||||
45
src/stores/widgetStore.ts
Normal file
45
src/stores/widgetStore.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { ComfyWidgets, ComfyWidgetConstructor } from '@/scripts/widgets'
|
||||
import { defineStore } from 'pinia'
|
||||
import { ref, computed } from 'vue'
|
||||
import type { BaseInputSpec } from './nodeDefStore'
|
||||
|
||||
export const useWidgetStore = defineStore('widget', () => {
|
||||
const coreWidgets = ComfyWidgets
|
||||
const customWidgets = ref<Record<string, ComfyWidgetConstructor>>({})
|
||||
const widgets = computed(() => ({
|
||||
...customWidgets.value,
|
||||
...coreWidgets
|
||||
}))
|
||||
|
||||
function getWidgetType(type: string, inputName: string) {
|
||||
if (type === 'COMBO') {
|
||||
return 'COMBO'
|
||||
} else if (`${type}:${inputName}` in widgets.value) {
|
||||
return `${type}:${inputName}`
|
||||
} else if (type in widgets.value) {
|
||||
return type
|
||||
} else {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
function inputIsWidget(spec: BaseInputSpec) {
|
||||
return getWidgetType(spec.type, spec.name) !== null
|
||||
}
|
||||
|
||||
function registerCustomWidgets(
|
||||
newWidgets: Record<string, ComfyWidgetConstructor>
|
||||
) {
|
||||
customWidgets.value = {
|
||||
...customWidgets.value,
|
||||
...newWidgets
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
widgets,
|
||||
getWidgetType,
|
||||
inputIsWidget,
|
||||
registerCustomWidgets
|
||||
}
|
||||
})
|
||||
19
src/types/comfy.d.ts
vendored
19
src/types/comfy.d.ts
vendored
@@ -1,20 +1,13 @@
|
||||
import { LGraphNode, IWidget } from './litegraph'
|
||||
import { ComfyApp } from '../scripts/app'
|
||||
import type { LGraphNode } from './litegraph'
|
||||
import type { ComfyApp } from '../scripts/app'
|
||||
import type { ComfyNodeDef } from '@/types/apiTypes'
|
||||
import type { Keybinding } from '@/types/keyBindingTypes'
|
||||
import type { ComfyCommand } from '@/stores/commandStore'
|
||||
import { SettingParams } from './settingTypes'
|
||||
import type { SettingParams } from './settingTypes'
|
||||
import type { BottomPanelExtension } from './extensionTypes'
|
||||
import type { ComfyWidgetConstructor } from '@/scripts/widgets'
|
||||
|
||||
export type Widgets = Record<
|
||||
string,
|
||||
(
|
||||
node,
|
||||
inputName,
|
||||
inputData,
|
||||
app?: ComfyApp
|
||||
) => { widget?: IWidget; minWidth?: number; minHeight?: number }
|
||||
>
|
||||
export type Widgets = Record<string, ComfyWidgetConstructor>
|
||||
|
||||
export interface AboutPageBadge {
|
||||
label: string
|
||||
@@ -94,6 +87,8 @@ export interface ComfyExtension {
|
||||
defs: Record<string, ComfyNodeDef>,
|
||||
app: ComfyApp
|
||||
): Promise<void> | void
|
||||
// TODO(huchenlei): We should deprecate the async return value of
|
||||
// getCustomWidgets.
|
||||
/**
|
||||
* Allows the extension to add custom widgets
|
||||
* @param app The ComfyUI app instance
|
||||
|
||||
4
src/types/litegraph-augmentation.d.ts
vendored
4
src/types/litegraph-augmentation.d.ts
vendored
@@ -22,6 +22,8 @@ declare module '@comfyorg/litegraph' {
|
||||
* Callback fired on each node after the graph is configured
|
||||
*/
|
||||
onAfterGraphConfigured?(): void
|
||||
onGraphConfigured?(): void
|
||||
onExecuted?(output: any): void
|
||||
onNodeCreated?(this: LGraphNode): void
|
||||
setInnerNodes?(nodes: LGraphNode[]): void
|
||||
applyToGraph?(extraLinks?: LLink[]): void
|
||||
@@ -38,7 +40,7 @@ declare module '@comfyorg/litegraph' {
|
||||
name: string,
|
||||
type: string,
|
||||
element: HTMLElement,
|
||||
options: Record<string, any>
|
||||
options?: Record<string, any>
|
||||
): DOMWidget
|
||||
}
|
||||
|
||||
|
||||
@@ -69,6 +69,16 @@ module.exports = async function () {
|
||||
}
|
||||
})
|
||||
|
||||
jest.mock('@/stores/widgetStore', () => {
|
||||
const widgets = {}
|
||||
return {
|
||||
useWidgetStore: () => ({
|
||||
widgets,
|
||||
registerCustomWidgets: jest.fn()
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
jest.mock('vue-i18n', () => {
|
||||
return {
|
||||
useI18n: jest.fn()
|
||||
|
||||
@@ -1,213 +0,0 @@
|
||||
// @ts-strict-ignore
|
||||
import { start } from '../../utils'
|
||||
import lg from '../../utils/litegraph'
|
||||
|
||||
describe('extensions', () => {
|
||||
beforeEach(() => {
|
||||
lg.setup(global)
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
lg.teardown(global)
|
||||
})
|
||||
|
||||
it('calls each extension hook', async () => {
|
||||
const mockExtension = {
|
||||
name: 'TestExtension',
|
||||
init: jest.fn(),
|
||||
setup: jest.fn(),
|
||||
addCustomNodeDefs: jest.fn(),
|
||||
getCustomWidgets: jest.fn(),
|
||||
beforeRegisterNodeDef: jest.fn(),
|
||||
registerCustomNodes: jest.fn(),
|
||||
loadedGraphNode: jest.fn(),
|
||||
nodeCreated: jest.fn(),
|
||||
beforeConfigureGraph: jest.fn(),
|
||||
afterConfigureGraph: jest.fn()
|
||||
}
|
||||
|
||||
const { app, ez, graph } = await start({
|
||||
async preSetup(app) {
|
||||
app.registerExtension(mockExtension)
|
||||
}
|
||||
})
|
||||
|
||||
// Basic initialisation hooks should be called once, with app
|
||||
expect(mockExtension.init).toHaveBeenCalledTimes(1)
|
||||
expect(mockExtension.init).toHaveBeenCalledWith(app)
|
||||
|
||||
// Adding custom node defs should be passed the full list of nodes
|
||||
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1)
|
||||
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app)
|
||||
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0]
|
||||
expect(defs).toHaveProperty('KSampler')
|
||||
expect(defs).toHaveProperty('LoadImage')
|
||||
|
||||
// Get custom widgets is called once and should return new widget types
|
||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1)
|
||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app)
|
||||
|
||||
// Before register node def will be called once per node type
|
||||
const nodeNames = Object.keys(defs)
|
||||
const nodeCount = nodeNames.length
|
||||
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount)
|
||||
for (let i = 0; i < 10; i++) {
|
||||
// It should be send the JS class and the original JSON definition
|
||||
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0]
|
||||
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1]
|
||||
|
||||
expect(nodeClass.name).toBe('ComfyNode')
|
||||
expect(nodeClass.comfyClass).toBe(nodeNames[i])
|
||||
expect(nodeDef.name).toBe(nodeNames[i])
|
||||
expect(nodeDef).toHaveProperty('input')
|
||||
expect(nodeDef).toHaveProperty('output')
|
||||
}
|
||||
|
||||
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
|
||||
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1)
|
||||
|
||||
// Before configure graph will be called here as the default graph is being loaded
|
||||
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1)
|
||||
// it gets sent the graph data that is going to be loaded
|
||||
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0]
|
||||
|
||||
// A node created is fired for each node constructor that is called
|
||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(
|
||||
graphData.nodes.length
|
||||
)
|
||||
for (let i = 0; i < graphData.nodes.length; i++) {
|
||||
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(
|
||||
graphData.nodes[i].type
|
||||
)
|
||||
}
|
||||
|
||||
// Each node then calls loadedGraphNode to allow them to be updated
|
||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(
|
||||
graphData.nodes.length
|
||||
)
|
||||
for (let i = 0; i < graphData.nodes.length; i++) {
|
||||
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(
|
||||
graphData.nodes[i].type
|
||||
)
|
||||
}
|
||||
|
||||
// After configure is then called once all the setup is done
|
||||
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1)
|
||||
|
||||
expect(mockExtension.setup).toHaveBeenCalledTimes(1)
|
||||
expect(mockExtension.setup).toHaveBeenCalledWith(app)
|
||||
|
||||
// Ensure hooks are called in the correct order
|
||||
const callOrder = [
|
||||
'init',
|
||||
'addCustomNodeDefs',
|
||||
'getCustomWidgets',
|
||||
'beforeRegisterNodeDef',
|
||||
'registerCustomNodes',
|
||||
'beforeConfigureGraph',
|
||||
'nodeCreated',
|
||||
'loadedGraphNode',
|
||||
'afterConfigureGraph',
|
||||
'setup'
|
||||
]
|
||||
for (let i = 1; i < callOrder.length; i++) {
|
||||
const fn1 = mockExtension[callOrder[i - 1]]
|
||||
const fn2 = mockExtension[callOrder[i]]
|
||||
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(
|
||||
fn2.mock.invocationCallOrder[0]
|
||||
)
|
||||
}
|
||||
|
||||
graph.clear()
|
||||
|
||||
// Ensure adding a new node calls the correct callback
|
||||
ez.LoadImage()
|
||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(
|
||||
graphData.nodes.length
|
||||
)
|
||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(
|
||||
graphData.nodes.length + 1
|
||||
)
|
||||
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe('LoadImage')
|
||||
|
||||
// Reload the graph to ensure correct hooks are fired
|
||||
await graph.reload()
|
||||
|
||||
// These hooks should not be fired again
|
||||
expect(mockExtension.init).toHaveBeenCalledTimes(1)
|
||||
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1)
|
||||
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1)
|
||||
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1)
|
||||
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount)
|
||||
expect(mockExtension.setup).toHaveBeenCalledTimes(1)
|
||||
|
||||
// These should be called again
|
||||
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2)
|
||||
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(
|
||||
graphData.nodes.length + 2
|
||||
)
|
||||
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(
|
||||
graphData.nodes.length + 1
|
||||
)
|
||||
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2)
|
||||
}, 15000)
|
||||
|
||||
it('allows custom nodeDefs and widgets to be registered', async () => {
|
||||
const widgetMock = jest.fn((node, inputName, inputData, app) => {
|
||||
expect(node.constructor.comfyClass).toBe('TestNode')
|
||||
expect(inputName).toBe('test_input')
|
||||
expect(inputData[0]).toBe('CUSTOMWIDGET')
|
||||
expect(inputData[1]?.hello).toBe('world')
|
||||
expect(app).toStrictEqual(app)
|
||||
|
||||
return {
|
||||
widget: node.addWidget('button', inputName, 'hello', () => {})
|
||||
}
|
||||
})
|
||||
|
||||
// Register our extension that adds a custom node + widget type
|
||||
const mockExtension = {
|
||||
name: 'TestExtension',
|
||||
addCustomNodeDefs: (nodeDefs) => {
|
||||
nodeDefs['TestNode'] = {
|
||||
output: [],
|
||||
output_name: [],
|
||||
output_is_list: [],
|
||||
name: 'TestNode',
|
||||
display_name: 'TestNode',
|
||||
category: 'Test',
|
||||
input: {
|
||||
required: {
|
||||
test_input: ['CUSTOMWIDGET', { hello: 'world' }]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
getCustomWidgets: jest.fn(() => {
|
||||
return {
|
||||
CUSTOMWIDGET: widgetMock
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const { graph, ez } = await start({
|
||||
async preSetup(app) {
|
||||
app.registerExtension(mockExtension)
|
||||
}
|
||||
})
|
||||
|
||||
expect(mockExtension.getCustomWidgets).toBeCalledTimes(1)
|
||||
|
||||
graph.clear()
|
||||
expect(widgetMock).toBeCalledTimes(0)
|
||||
const node = ez.TestNode()
|
||||
expect(widgetMock).toBeCalledTimes(1)
|
||||
|
||||
// Ensure our custom widget is created
|
||||
expect(node.inputs.length).toBe(0)
|
||||
expect(node.widgets.length).toBe(1)
|
||||
const w = node.widgets[0].widget
|
||||
expect(w.name).toBe('test_input')
|
||||
expect(w.type).toBe('button')
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user