Manage widget definitions with Pinia store (#1510)

* Fix compile

* nit

* Remove extensions.test

* nit
This commit is contained in:
Chenlei Hu
2024-11-11 17:23:52 -05:00
committed by GitHub
parent 64ef0f18b1
commit 1ff6e27d9c
12 changed files with 111 additions and 269 deletions

View File

@@ -80,12 +80,13 @@ https://github.com/Nuked88/ComfyUI-N-Sidebar/blob/7ae7da4a9761009fb6629bc04c6830
</template>
<script setup lang="ts">
import { ComfyNodeDefImpl, useNodeDefStore } from '@/stores/nodeDefStore'
import { ComfyNodeDefImpl } from '@/stores/nodeDefStore'
import {
getColorPalette,
defaultColorPalette
} from '@/extensions/core/colorPalette'
import _ from 'lodash'
import { useWidgetStore } from '@/stores/widgetStore'
const props = defineProps({
nodeDef: {
@@ -101,16 +102,16 @@ const props = defineProps({
const colors = getColorPalette()?.colors?.litegraph_base
const litegraphColors = colors ?? defaultColorPalette.colors.litegraph_base
const nodeDefStore = useNodeDefStore()
const widgetStore = useWidgetStore()
const nodeDef = props.nodeDef
const allInputDefs = nodeDef.input.all
const allOutputDefs = nodeDef.output.all
const slotInputDefs = allInputDefs.filter(
(input) => !nodeDefStore.inputIsWidget(input)
(input) => !widgetStore.inputIsWidget(input)
)
const widgetInputDefs = allInputDefs.filter((input) =>
nodeDefStore.inputIsWidget(input)
widgetStore.inputIsWidget(input)
)
const truncateDefaultValue = (value: any, charLimit: number = 32): string => {
let stringValue: string

View File

@@ -5,6 +5,7 @@ import type { IWidget } from '@comfyorg/litegraph'
import type { DOMWidget } from '@/scripts/domWidget'
import { ComfyNodeDef } from '@/types/apiTypes'
import { useToastStore } from '@/stores/toastStore'
import { Widgets } from '@/types/comfy'
type FolderType = 'input' | 'output' | 'temp'
@@ -107,7 +108,7 @@ app.registerExtension({
audioUIWidget.element.classList.add('empty-audio-widget')
// Populate the audio widget UI on node execution.
const onExecuted = node.onExecuted
node.onExecuted = function (message) {
node.onExecuted = function (message: any) {
onExecuted?.apply(this, arguments)
const audios = message.audio
if (!audios) return
@@ -120,7 +121,7 @@ app.registerExtension({
}
return { widget: audioUIWidget }
}
}
} as Widgets
},
onNodeOutputsUpdated(nodeOutputs: Record<number, any>) {
for (const [nodeId, output] of Object.entries(nodeOutputs)) {
@@ -153,9 +154,9 @@ app.registerExtension({
const audioWidget: IWidget = node.widgets.find(
(w: IWidget) => w.name === 'audio'
)
const audioUIWidget: DOMWidget<HTMLAudioElement> = node.widgets.find(
const audioUIWidget = node.widgets.find(
(w: IWidget) => w.name === 'audioUI'
)
) as DOMWidget<HTMLAudioElement>
const onAudioWidgetUpdate = () => {
audioUIWidget.element.src = api.apiURL(

View File

@@ -1,6 +1,10 @@
// @ts-strict-ignore
import { ComfyLogging } from './logging'
import { ComfyWidgetConstructor, ComfyWidgets, initWidgets } from './widgets'
import {
type ComfyWidgetConstructor,
ComfyWidgets,
initWidgets
} from './widgets'
import { ComfyUI, $el } from './ui'
import { api } from './api'
import { defaultGraph } from './defaultGraph'
@@ -60,6 +64,7 @@ import { useCommandStore } from '@/stores/commandStore'
import { shallowReactive } from 'vue'
import { type IBaseWidget } from '@comfyorg/litegraph/dist/types/widgets'
import { workflowService } from '@/services/workflowService'
import { useWidgetStore } from '@/stores/widgetStore'
export const ANIM_PREVIEW_WIDGET = '$$comfy_animation_preview'
@@ -142,7 +147,6 @@ export class ComfyApp {
storageLocation: StorageLocation
multiUserServer: boolean
ctx: CanvasRenderingContext2D
widgets: Record<string, ComfyWidgetConstructor>
bodyTop: HTMLElement
bodyLeft: HTMLElement
bodyRight: HTMLElement
@@ -167,6 +171,16 @@ export class ComfyApp {
return useWorkspaceStore().shiftDown
}
/**
* @deprecated Use useWidgetStore().widgets instead
*/
get widgets(): Record<string, ComfyWidgetConstructor> {
if (this.vueAppReady) {
return useWidgetStore().widgets
}
return ComfyWidgets
}
constructor() {
this.vueAppReady = false
this.ui = new ComfyUI(this)
@@ -1551,10 +1565,7 @@ export class ComfyApp {
}
const node = this.graph.getNodeById(detail.display_node || detail.node)
if (node) {
// @ts-expect-error
if (node.onExecuted)
// @ts-expect-error
node.onExecuted(detail.output)
if (node.onExecuted) node.onExecuted(detail.output)
}
})
@@ -1653,7 +1664,6 @@ export class ComfyApp {
app.graph.onConfigure = function () {
// Fire callbacks before the onConfigure, this is used by widget inputs to setup the config
for (const node of app.graph.nodes) {
// @ts-expect-error
node.onGraphConfigured?.()
}
@@ -1920,7 +1930,6 @@ export class ComfyApp {
const nodeDefArray: ComfyNodeDef[] = Object.values(allNodeDefs)
this.#invokeExtensions('beforeRegisterVueAppNodeDefs', nodeDefArray, this)
nodeDefStore.updateNodeDefs(nodeDefArray)
nodeDefStore.widgets = this.widgets
}
/**
@@ -1938,7 +1947,11 @@ export class ComfyApp {
}
}
getWidgetType(inputData, inputName) {
/**
* Remove the impl after groupNode jest tests are removed.
* @deprecated Use useWidgetStore().getWidgetType instead
*/
getWidgetType(inputData, inputName: string) {
const type = inputData[0]
if (Array.isArray(type)) {
@@ -2092,13 +2105,6 @@ export class ComfyApp {
async registerNodesFromDefs(defs: Record<string, ComfyNodeDef>) {
await this.#invokeExtensionsAsync('addCustomNodeDefs', defs)
// Generate list of known widgets
this.widgets = Object.assign(
{},
ComfyWidgets,
...(await this.#invokeExtensionsAsync('getCustomWidgets')).filter(Boolean)
)
// Register a node for each definition
for (const nodeId in defs) {
this.registerNodeDef(nodeId, defs[nodeId])

View File

@@ -268,7 +268,7 @@ LGraphNode.prototype.addDOMWidget = function (
name: string,
type: string,
element: HTMLElement,
options: Record<string, any>
options: Record<string, any> = {}
): DOMWidget {
options = { hideOnZoom: true, selectOn: ['focus', 'click'], ...options }

View File

@@ -2,10 +2,11 @@
import { api } from './api'
import './domWidget'
import type { ComfyApp } from './app'
import type { IWidget, LGraphNode } from '@comfyorg/litegraph'
import type { LGraphNode } from '@comfyorg/litegraph'
import { InputSpec } from '@/types/apiTypes'
import { useSettingStore } from '@/stores/settingStore'
import { useToastStore } from '@/stores/toastStore'
import type { IWidget } from '@comfyorg/litegraph'
export type ComfyWidgetConstructor = (
node: LGraphNode,

View File

@@ -7,6 +7,7 @@ import { useSettingStore } from './settingStore'
import { app } from '@/scripts/app'
import { useMenuItemStore } from './menuItemStore'
import { useBottomPanelStore } from './workspace/bottomPanelStore'
import { useWidgetStore } from './widgetStore'
export const useExtensionStore = defineStore('extension', () => {
// For legacy reasons, the name uniquely identifies an extension
@@ -50,6 +51,16 @@ export const useExtensionStore = defineStore('extension', () => {
useMenuItemStore().loadExtensionMenuCommands(extension)
useSettingStore().loadExtensionSettings(extension)
useBottomPanelStore().registerExtensionBottomPanelTabs(extension)
if (extension.getCustomWidgets) {
// TODO(huchenlei): We should deprecate the async return value of
// getCustomWidgets.
;(async () => {
if (extension.getCustomWidgets) {
const widgets = await extension.getCustomWidgets(app)
useWidgetStore().registerCustomWidgets(widgets)
}
})()
}
/*
* Extensions are currently stored in both extensionStore and app.extensions.
* Legacy jest tests still depend on app.extensions being populated.

View File

@@ -323,20 +323,6 @@ export const useNodeDefStore = defineStore('nodeDef', () => {
nodeDefsByName.value[nodeDef.name] = nodeDefImpl
nodeDefsByDisplayName.value[nodeDef.display_name] = nodeDefImpl
}
function getWidgetType(type: string, inputName: string) {
if (type === 'COMBO') {
return 'COMBO'
} else if (`${type}:${inputName}` in widgets.value) {
return `${type}:${inputName}`
} else if (type in widgets.value) {
return type
} else {
return null
}
}
function inputIsWidget(spec: BaseInputSpec) {
return getWidgetType(spec.type, spec.name) !== null
}
function fromLGraphNode(node: LGraphNode): ComfyNodeDefImpl | null {
// Frontend-only nodes don't have nodeDef
return nodeDefsByName.value[node.constructor?.nodeData?.name] ?? null
@@ -345,7 +331,6 @@ export const useNodeDefStore = defineStore('nodeDef', () => {
return {
nodeDefsByName,
nodeDefsByDisplayName,
widgets,
showDeprecated,
showExperimental,
@@ -356,8 +341,6 @@ export const useNodeDefStore = defineStore('nodeDef', () => {
updateNodeDefs,
addNodeDef,
getWidgetType,
inputIsWidget,
fromLGraphNode
}
})

45
src/stores/widgetStore.ts Normal file
View File

@@ -0,0 +1,45 @@
import { ComfyWidgets, ComfyWidgetConstructor } from '@/scripts/widgets'
import { defineStore } from 'pinia'
import { ref, computed } from 'vue'
import type { BaseInputSpec } from './nodeDefStore'
export const useWidgetStore = defineStore('widget', () => {
const coreWidgets = ComfyWidgets
const customWidgets = ref<Record<string, ComfyWidgetConstructor>>({})
const widgets = computed(() => ({
...customWidgets.value,
...coreWidgets
}))
function getWidgetType(type: string, inputName: string) {
if (type === 'COMBO') {
return 'COMBO'
} else if (`${type}:${inputName}` in widgets.value) {
return `${type}:${inputName}`
} else if (type in widgets.value) {
return type
} else {
return null
}
}
function inputIsWidget(spec: BaseInputSpec) {
return getWidgetType(spec.type, spec.name) !== null
}
function registerCustomWidgets(
newWidgets: Record<string, ComfyWidgetConstructor>
) {
customWidgets.value = {
...customWidgets.value,
...newWidgets
}
}
return {
widgets,
getWidgetType,
inputIsWidget,
registerCustomWidgets
}
})

19
src/types/comfy.d.ts vendored
View File

@@ -1,20 +1,13 @@
import { LGraphNode, IWidget } from './litegraph'
import { ComfyApp } from '../scripts/app'
import type { LGraphNode } from './litegraph'
import type { ComfyApp } from '../scripts/app'
import type { ComfyNodeDef } from '@/types/apiTypes'
import type { Keybinding } from '@/types/keyBindingTypes'
import type { ComfyCommand } from '@/stores/commandStore'
import { SettingParams } from './settingTypes'
import type { SettingParams } from './settingTypes'
import type { BottomPanelExtension } from './extensionTypes'
import type { ComfyWidgetConstructor } from '@/scripts/widgets'
export type Widgets = Record<
string,
(
node,
inputName,
inputData,
app?: ComfyApp
) => { widget?: IWidget; minWidth?: number; minHeight?: number }
>
export type Widgets = Record<string, ComfyWidgetConstructor>
export interface AboutPageBadge {
label: string
@@ -94,6 +87,8 @@ export interface ComfyExtension {
defs: Record<string, ComfyNodeDef>,
app: ComfyApp
): Promise<void> | void
// TODO(huchenlei): We should deprecate the async return value of
// getCustomWidgets.
/**
* Allows the extension to add custom widgets
* @param app The ComfyUI app instance

View File

@@ -22,6 +22,8 @@ declare module '@comfyorg/litegraph' {
* Callback fired on each node after the graph is configured
*/
onAfterGraphConfigured?(): void
onGraphConfigured?(): void
onExecuted?(output: any): void
onNodeCreated?(this: LGraphNode): void
setInnerNodes?(nodes: LGraphNode[]): void
applyToGraph?(extraLinks?: LLink[]): void
@@ -38,7 +40,7 @@ declare module '@comfyorg/litegraph' {
name: string,
type: string,
element: HTMLElement,
options: Record<string, any>
options?: Record<string, any>
): DOMWidget
}

View File

@@ -69,6 +69,16 @@ module.exports = async function () {
}
})
jest.mock('@/stores/widgetStore', () => {
const widgets = {}
return {
useWidgetStore: () => ({
widgets,
registerCustomWidgets: jest.fn()
})
}
})
jest.mock('vue-i18n', () => {
return {
useI18n: jest.fn()

View File

@@ -1,213 +0,0 @@
// @ts-strict-ignore
import { start } from '../../utils'
import lg from '../../utils/litegraph'
describe('extensions', () => {
beforeEach(() => {
lg.setup(global)
})
afterEach(() => {
lg.teardown(global)
})
it('calls each extension hook', async () => {
const mockExtension = {
name: 'TestExtension',
init: jest.fn(),
setup: jest.fn(),
addCustomNodeDefs: jest.fn(),
getCustomWidgets: jest.fn(),
beforeRegisterNodeDef: jest.fn(),
registerCustomNodes: jest.fn(),
loadedGraphNode: jest.fn(),
nodeCreated: jest.fn(),
beforeConfigureGraph: jest.fn(),
afterConfigureGraph: jest.fn()
}
const { app, ez, graph } = await start({
async preSetup(app) {
app.registerExtension(mockExtension)
}
})
// Basic initialisation hooks should be called once, with app
expect(mockExtension.init).toHaveBeenCalledTimes(1)
expect(mockExtension.init).toHaveBeenCalledWith(app)
// Adding custom node defs should be passed the full list of nodes
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1)
expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app)
const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0]
expect(defs).toHaveProperty('KSampler')
expect(defs).toHaveProperty('LoadImage')
// Get custom widgets is called once and should return new widget types
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1)
expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app)
// Before register node def will be called once per node type
const nodeNames = Object.keys(defs)
const nodeCount = nodeNames.length
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount)
for (let i = 0; i < 10; i++) {
// It should be send the JS class and the original JSON definition
const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0]
const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1]
expect(nodeClass.name).toBe('ComfyNode')
expect(nodeClass.comfyClass).toBe(nodeNames[i])
expect(nodeDef.name).toBe(nodeNames[i])
expect(nodeDef).toHaveProperty('input')
expect(nodeDef).toHaveProperty('output')
}
// Register custom nodes is called once after registerNode defs to allow adding other frontend nodes
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1)
// Before configure graph will be called here as the default graph is being loaded
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1)
// it gets sent the graph data that is going to be loaded
const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0]
// A node created is fired for each node constructor that is called
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(
graphData.nodes.length
)
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(
graphData.nodes[i].type
)
}
// Each node then calls loadedGraphNode to allow them to be updated
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(
graphData.nodes.length
)
for (let i = 0; i < graphData.nodes.length; i++) {
expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(
graphData.nodes[i].type
)
}
// After configure is then called once all the setup is done
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1)
expect(mockExtension.setup).toHaveBeenCalledTimes(1)
expect(mockExtension.setup).toHaveBeenCalledWith(app)
// Ensure hooks are called in the correct order
const callOrder = [
'init',
'addCustomNodeDefs',
'getCustomWidgets',
'beforeRegisterNodeDef',
'registerCustomNodes',
'beforeConfigureGraph',
'nodeCreated',
'loadedGraphNode',
'afterConfigureGraph',
'setup'
]
for (let i = 1; i < callOrder.length; i++) {
const fn1 = mockExtension[callOrder[i - 1]]
const fn2 = mockExtension[callOrder[i]]
expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(
fn2.mock.invocationCallOrder[0]
)
}
graph.clear()
// Ensure adding a new node calls the correct callback
ez.LoadImage()
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(
graphData.nodes.length
)
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(
graphData.nodes.length + 1
)
expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe('LoadImage')
// Reload the graph to ensure correct hooks are fired
await graph.reload()
// These hooks should not be fired again
expect(mockExtension.init).toHaveBeenCalledTimes(1)
expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1)
expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1)
expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1)
expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount)
expect(mockExtension.setup).toHaveBeenCalledTimes(1)
// These should be called again
expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2)
expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(
graphData.nodes.length + 2
)
expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(
graphData.nodes.length + 1
)
expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2)
}, 15000)
it('allows custom nodeDefs and widgets to be registered', async () => {
const widgetMock = jest.fn((node, inputName, inputData, app) => {
expect(node.constructor.comfyClass).toBe('TestNode')
expect(inputName).toBe('test_input')
expect(inputData[0]).toBe('CUSTOMWIDGET')
expect(inputData[1]?.hello).toBe('world')
expect(app).toStrictEqual(app)
return {
widget: node.addWidget('button', inputName, 'hello', () => {})
}
})
// Register our extension that adds a custom node + widget type
const mockExtension = {
name: 'TestExtension',
addCustomNodeDefs: (nodeDefs) => {
nodeDefs['TestNode'] = {
output: [],
output_name: [],
output_is_list: [],
name: 'TestNode',
display_name: 'TestNode',
category: 'Test',
input: {
required: {
test_input: ['CUSTOMWIDGET', { hello: 'world' }]
}
}
}
},
getCustomWidgets: jest.fn(() => {
return {
CUSTOMWIDGET: widgetMock
}
})
}
const { graph, ez } = await start({
async preSetup(app) {
app.registerExtension(mockExtension)
}
})
expect(mockExtension.getCustomWidgets).toBeCalledTimes(1)
graph.clear()
expect(widgetMock).toBeCalledTimes(0)
const node = ez.TestNode()
expect(widgetMock).toBeCalledTimes(1)
// Ensure our custom widget is created
expect(node.inputs.length).toBe(0)
expect(node.widgets.length).toBe(1)
const w = node.widgets[0].widget
expect(w.name).toBe('test_input')
expect(w.type).toBe('button')
})
})