diff --git a/src/scripts/app.ts b/src/scripts/app.ts index e70fe8d22..a5e0b9134 100644 --- a/src/scripts/app.ts +++ b/src/scripts/app.ts @@ -39,6 +39,7 @@ import { getSvgMetadata } from '@/scripts/metadata/svg' import { useDialogService } from '@/services/dialogService' import { useExtensionService } from '@/services/extensionService' import { useLitegraphService } from '@/services/litegraphService' +import { useSubgraphService } from '@/services/subgraphService' import { useWorkflowService } from '@/services/workflowService' import { useApiKeyAuthStore } from '@/stores/apiKeyAuthStore' import { useCommandStore } from '@/stores/commandStore' @@ -764,6 +765,20 @@ export class ComfyApp { this.#graph = new LGraph() + // Register the subgraph - adds type wrapper for Litegraph's `createNode` factory + this.graph.events.addEventListener('subgraph-created', (e) => { + try { + useSubgraphService().registerNewSubgraph(e.detail) + } catch (err) { + console.error('Failed to register subgraph', err) + useToastStore().add({ + severity: 'error', + summary: 'Failed to register subgraph', + detail: err instanceof Error ? err.message : String(err) + }) + } + }) + this.#addAfterConfigureHandler() this.canvas = new LGraphCanvas(canvasEl, this.graph) @@ -1012,6 +1027,7 @@ export class ComfyApp { }) } useWorkflowService().beforeLoadNewGraph() + useSubgraphService().loadSubgraphs(graphData) const missingNodeTypes: MissingNodeType[] = [] const missingModels: ModelFile[] = [] diff --git a/src/services/litegraphService.ts b/src/services/litegraphService.ts index 256c06723..d8b123a56 100644 --- a/src/services/litegraphService.ts +++ b/src/services/litegraphService.ts @@ -5,10 +5,13 @@ import { LGraphNode, LiteGraph, RenderShape, + type Subgraph, + SubgraphNode, type Vector2, createBounds } from '@comfyorg/litegraph' import type { + ExportedSubgraphInstance, ISerialisableNodeInput, ISerialisableNodeOutput, ISerialisedNode @@ -56,6 +59,260 @@ export const useLitegraphService = () => { const widgetStore = useWidgetStore() const canvasStore = useCanvasStore() + // TODO: Dedupe `registerNodeDef`; this should remain synchronous. + function registerSubgraphNodeDef( + nodeDefV1: ComfyNodeDefV1, + subgraph: Subgraph, + instanceData: ExportedSubgraphInstance + ) { + const node = class ComfyNode extends SubgraphNode { + static comfyClass: string + static override title: string + static override category: string + static nodeData: ComfyNodeDefV1 & ComfyNodeDefV2 + + /** + * @internal The initial minimum size of the node. + */ + #initialMinSize = { width: 1, height: 1 } + /** + * @internal The key for the node definition in the i18n file. + */ + get #nodeKey(): string { + return `nodeDefs.${normalizeI18nKey(ComfyNode.nodeData.name)}` + } + + constructor() { + super(app.graph, subgraph, instanceData) + + this.#setupStrokeStyles() + this.#addInputs(ComfyNode.nodeData.inputs) + this.#addOutputs(ComfyNode.nodeData.outputs) + this.#setInitialSize() + this.serialize_widgets = true + void extensionService.invokeExtensionsAsync('nodeCreated', this) + } + + /** + * @internal Setup stroke styles for the node under various conditions. + */ + #setupStrokeStyles() { + this.strokeStyles['running'] = function (this: LGraphNode) { + if (this.id == app.runningNodeId) { + return { color: '#0f0' } + } + } + this.strokeStyles['nodeError'] = function (this: LGraphNode) { + if (app.lastNodeErrors?.[this.id]?.errors) { + return { color: 'red' } + } + } + this.strokeStyles['dragOver'] = function (this: LGraphNode) { + if (app.dragOverNode?.id == this.id) { + return { color: 'dodgerblue' } + } + } + this.strokeStyles['executionError'] = function (this: LGraphNode) { + if (app.lastExecutionError?.node_id == this.id) { + return { color: '#f0f', lineWidth: 2 } + } + } + } + + /** + * @internal Add input sockets to the node. (No widget) + */ + #addInputSocket(inputSpec: InputSpec) { + const inputName = inputSpec.name + const nameKey = `${this.#nodeKey}.inputs.${normalizeI18nKey(inputName)}.name` + const widgetConstructor = widgetStore.widgets.get( + inputSpec.widgetType ?? inputSpec.type + ) + if (widgetConstructor && !inputSpec.forceInput) return + + this.addInput(inputName, inputSpec.type, { + shape: inputSpec.isOptional ? RenderShape.HollowCircle : undefined, + localized_name: st(nameKey, inputName) + }) + } + + /** + * @internal Add a widget to the node. For both primitive types and custom widgets + * (unless `socketless`), an input socket is also added. + */ + #addInputWidget(inputSpec: InputSpec) { + const widgetInputSpec = { ...inputSpec } + if (inputSpec.widgetType) { + widgetInputSpec.type = inputSpec.widgetType + } + const inputName = inputSpec.name + const nameKey = `${this.#nodeKey}.inputs.${normalizeI18nKey(inputName)}.name` + const widgetConstructor = widgetStore.widgets.get(widgetInputSpec.type) + if (!widgetConstructor || inputSpec.forceInput) return + + const { + widget, + minWidth = 1, + minHeight = 1 + } = widgetConstructor( + this, + inputName, + transformInputSpecV2ToV1(widgetInputSpec), + app + ) ?? {} + + if (widget) { + widget.label = st(nameKey, widget.label ?? inputName) + widget.options ??= {} + Object.assign(widget.options, { + advanced: inputSpec.advanced, + hidden: inputSpec.hidden + }) + } + + if (!widget?.options?.socketless) { + const inputSpecV1 = transformInputSpecV2ToV1(widgetInputSpec) + this.addInput(inputName, inputSpec.type, { + shape: inputSpec.isOptional ? RenderShape.HollowCircle : undefined, + localized_name: st(nameKey, inputName), + widget: { name: inputName, [GET_CONFIG]: () => inputSpecV1 } + }) + } + + this.#initialMinSize.width = Math.max( + this.#initialMinSize.width, + minWidth + ) + this.#initialMinSize.height = Math.max( + this.#initialMinSize.height, + minHeight + ) + } + + /** + * @internal Add inputs to the node. + */ + #addInputs(inputs: Record) { + for (const inputSpec of Object.values(inputs)) + this.#addInputSocket(inputSpec) + for (const inputSpec of Object.values(inputs)) + this.#addInputWidget(inputSpec) + } + + /** + * @internal Add outputs to the node. + */ + #addOutputs(outputs: OutputSpec[]) { + for (const output of outputs) { + const { name, type, is_list } = output + const shapeOptions = is_list ? { shape: LiteGraph.GRID_SHAPE } : {} + const nameKey = `${this.#nodeKey}.outputs.${output.index}.name` + const typeKey = `dataTypes.${normalizeI18nKey(type)}` + const outputOptions = { + ...shapeOptions, + // If the output name is different from the output type, use the output name. + // e.g. + // - type ("INT"); name ("Positive") => translate name + // - type ("FLOAT"); name ("FLOAT") => translate type + localized_name: + type !== name ? st(nameKey, name) : st(typeKey, name) + } + this.addOutput(name, type, outputOptions) + } + } + + /** + * @internal Set the initial size of the node. + */ + #setInitialSize() { + const s = this.computeSize() + // Expand the width a little to fit widget values on screen. + const pad = + this.widgets?.length && + !useSettingStore().get('LiteGraph.Node.DefaultPadding') + s[0] = Math.max(this.#initialMinSize.width, s[0] + (pad ? 60 : 0)) + s[1] = Math.max(this.#initialMinSize.height, s[1]) + this.setSize(s) + } + + /** + * Configure the node from a serialised node. Keep 'name', 'type', 'shape', + * and 'localized_name' information from the original node definition. + */ + override configure(data: ISerialisedNode): void { + const RESERVED_KEYS = ['name', 'type', 'shape', 'localized_name'] + + // Note: input name is unique in a node definition, so we can lookup + // input by name. + const inputByName = new Map( + data.inputs?.map((input) => [input.name, input]) ?? [] + ) + // Inputs defined by the node definition. + const definedInputNames = new Set( + this.inputs.map((input) => input.name) + ) + const definedInputs = this.inputs.map((input) => { + const inputData = inputByName.get(input.name) + return inputData + ? { + ...inputData, + // Whether the input has associated widget follows the + // original node definition. + ..._.pick(input, RESERVED_KEYS.concat('widget')) + } + : input + }) + // Extra inputs that potentially dynamically added by custom js logic. + const extraInputs = data.inputs?.filter( + (input) => !definedInputNames.has(input.name) + ) + data.inputs = [...definedInputs, ...(extraInputs ?? [])] + + // Note: output name is not unique, so we cannot lookup output by name. + // Use index instead. + data.outputs = _.zip(this.outputs, data.outputs).map( + ([output, outputData]) => { + // If there are extra outputs in the serialised node, use them directly. + // There are currently custom nodes that dynamically add outputs via + // js logic. + if (!output) return outputData as ISerialisableNodeOutput + + return outputData + ? { + ...outputData, + ..._.pick(output, RESERVED_KEYS) + } + : output + } + ) + + data.widgets_values = migrateWidgetsValues( + ComfyNode.nodeData.inputs, + this.widgets ?? [], + data.widgets_values ?? [] + ) + + super.configure(data) + } + } + + addNodeContextMenuHandler(node) + addDrawBackgroundHandler(node) + addNodeKeyHandler(node) + // Note: Some extensions expects node.comfyClass to be set in + // `beforeRegisterNodeDef`. + node.prototype.comfyClass = nodeDefV1.name + node.comfyClass = nodeDefV1.name + + const nodeDef = new ComfyNodeDefImpl(nodeDefV1) + node.nodeData = nodeDef + LiteGraph.registerNodeType(subgraph.id, node) + // Note: Do not following assignments before `LiteGraph.registerNodeType` + // because `registerNodeType` will overwrite the assignments. + node.category = nodeDef.category + node.title = nodeDef.display_name || nodeDef.name + } + async function registerNodeDef(nodeId: string, nodeDefV1: ComfyNodeDefV1) { const node = class ComfyNode extends LGraphNode { static comfyClass: string @@ -665,6 +922,7 @@ export const useLitegraphService = () => { return { registerNodeDef, + registerSubgraphNodeDef, addNodeOnGraph, getCanvasCenter, goToNode, diff --git a/src/services/subgraphService.ts b/src/services/subgraphService.ts new file mode 100644 index 000000000..97d7f3cac --- /dev/null +++ b/src/services/subgraphService.ts @@ -0,0 +1,85 @@ +import { + type ExportedSubgraph, + type ExportedSubgraphInstance, + type Subgraph +} from '@comfyorg/litegraph' + +import type { ComfyWorkflowJSON } from '@/schemas/comfyWorkflowSchema' +import type { ComfyNodeDef as ComfyNodeDefV1 } from '@/schemas/nodeDefSchema' +import { app as comfyApp } from '@/scripts/app' +import { useNodeDefStore } from '@/stores/nodeDefStore' + +import { useLitegraphService } from './litegraphService' + +export const useSubgraphService = () => { + /** @todo Move to store */ + const subgraphs: Subgraph[] = [] + + /** Loads a single subgraph definition and registers it with the node def store */ + const deserialiseSubgraph = ( + subgraph: Subgraph, + exportedSubgraph: ExportedSubgraph + ) => { + const { id, name } = exportedSubgraph + + const nodeDef: ComfyNodeDefV1 = { + input: { required: {} }, + output: [], + output_is_list: [], + output_name: [], + output_tooltips: [], + name: id, + display_name: name, + description: `Subgraph node for ${name}`, + category: 'subgraph', + output_node: false, + python_module: 'nodes' + } + + useNodeDefStore().addNodeDef(nodeDef) + + const instanceData: ExportedSubgraphInstance = { + id: -1, + type: exportedSubgraph.id, + pos: [0, 0], + size: [100, 100], + inputs: [], + outputs: [], + flags: {}, + order: 0, + mode: 0 + } + + useLitegraphService().registerSubgraphNodeDef( + nodeDef, + subgraph, + instanceData + ) + } + + /** Loads all exported subgraph definitionsfrom workflow */ + const loadSubgraphs = (graphData: ComfyWorkflowJSON) => { + if (!graphData.definitions?.subgraphs) return + + for (const subgraphData of graphData.definitions.subgraphs) { + const subgraph = + subgraphs.find((x) => x.id === subgraphData.id) ?? + comfyApp.graph.createSubgraph(subgraphData as ExportedSubgraph) + + // @ts-expect-error Zod + deserialiseSubgraph(subgraph, subgraphData) + } + } + + /** Registers a new subgraph (e.g. user converted from nodes) */ + const registerNewSubgraph = (subgraph: Subgraph) => { + subgraphs.push(subgraph) + + deserialiseSubgraph(subgraph, subgraph.asSerialisable()) + } + + return { + loadSubgraphs, + registerNewSubgraph + } +} diff --git a/src/types/litegraph-augmentation.d.ts b/src/types/litegraph-augmentation.d.ts index dbdc388f2..fee3e515a 100644 --- a/src/types/litegraph-augmentation.d.ts +++ b/src/types/litegraph-augmentation.d.ts @@ -75,6 +75,22 @@ declare module '@comfyorg/litegraph' { // eslint-disable-next-line @typescript-eslint/no-empty-object-type interface BaseWidget extends IBaseWidget {} + /** Actual members required for execution. */ + type ExecutableLGraphNode = Pick< + LGraphNode, + | 'id' + | 'type' + | 'comfyClass' + | 'title' + | 'mode' + | 'inputs' + | 'widgets' + | 'isVirtualNode' + | 'applyToGraph' + | 'getInputNode' + | 'getInputLink' + > + interface LGraphNode { constructor: LGraphNodeConstructor @@ -88,7 +104,10 @@ declare module '@comfyorg/litegraph' { /** @deprecated groupNode */ setInnerNodes?(nodes: LGraphNode[]): void /** Originally a group node API. */ - getInnerNodes?(): LGraphNode[] + getInnerNodes?( + nodes?: ExecutableLGraphNode[], + subgraphs?: WeakSet + ): ExecutableLGraphNode[] /** @deprecated groupNode */ convertToNodes?(): LGraphNode[] recreate?(): Promise