diff --git a/src/scripts/app.ts b/src/scripts/app.ts index b8c8a81f09..c99823fcd7 100644 --- a/src/scripts/app.ts +++ b/src/scripts/app.ts @@ -38,6 +38,7 @@ import { } from '@/types/comfyWorkflow' import { ExtensionManager } from '@/types/extensionTypes' import { ColorAdjustOptions, adjustColor } from '@/utils/colorUtil' +import { graphToPrompt } from '@/utils/executionUtil' import { isImageNode } from '@/utils/litegraphUtil' import { deserialiseAndCreate } from '@/utils/vintageClipboard' @@ -1226,176 +1227,11 @@ export class ComfyApp { return graph.serialize({ sortNodes }) } - /** - * Converts the current graph workflow for sending to the API. - * Note: Node widgets are updated before serialization to prepare queueing. - * @returns The workflow and node links - */ async graphToPrompt(graph = this.graph, clean = true) { - for (const outerNode of graph.computeExecutionOrder(false)) { - if (outerNode.widgets) { - for (const widget of outerNode.widgets) { - // Allow widgets to run callbacks before a prompt has been queued - // e.g. random seed before every gen - widget.beforeQueued?.() - } - } - - const innerNodes = outerNode.getInnerNodes - ? outerNode.getInnerNodes() - : [outerNode] - for (const node of innerNodes) { - if (node.isVirtualNode) { - // Don't serialize frontend only nodes but let them make changes - if (node.applyToGraph) { - node.applyToGraph() - } - } - } - } - - const workflow = this.serializeGraph(graph) - - // Remove localized_name from the workflow - for (const node of workflow.nodes) { - for (const slot of node.inputs) { - delete slot.localized_name - } - for (const slot of node.outputs) { - delete slot.localized_name - } - } - - const output = {} - // Process nodes in order of execution - for (const outerNode of graph.computeExecutionOrder(false)) { - const skipNode = - outerNode.mode === LGraphEventMode.NEVER || - outerNode.mode === LGraphEventMode.BYPASS - const innerNodes = - !skipNode && outerNode.getInnerNodes - ? outerNode.getInnerNodes() - : [outerNode] - for (const node of innerNodes) { - if (node.isVirtualNode) { - continue - } - - if ( - node.mode === LGraphEventMode.NEVER || - node.mode === LGraphEventMode.BYPASS - ) { - // Don't serialize muted nodes - continue - } - - const inputs = {} - const widgets = node.widgets - - // Store all widget values - if (widgets) { - for (let i = 0; i < widgets.length; i++) { - const widget = widgets[i] - if (!widget.options || widget.options.serialize !== false) { - inputs[widget.name] = widget.serializeValue - ? await widget.serializeValue(node, i) - : widget.value - } - } - } - - // Store all node links - for (let i = 0; i < node.inputs.length; i++) { - let parent = node.getInputNode(i) - if (parent) { - let link = node.getInputLink(i) - while ( - parent.mode === LGraphEventMode.BYPASS || - parent.isVirtualNode - ) { - let found = false - if (parent.isVirtualNode) { - link = parent.getInputLink(link.origin_slot) - if (link) { - parent = parent.getInputNode(link.target_slot) - if (parent) { - found = true - } - } - } else if (link && parent.mode === LGraphEventMode.BYPASS) { - let all_inputs = [link.origin_slot] - if (parent.inputs) { - // @ts-expect-error convert list of strings to list of numbers - all_inputs = all_inputs.concat(Object.keys(parent.inputs)) - for (let parent_input in all_inputs) { - // @ts-expect-error assign string to number - parent_input = all_inputs[parent_input] - if ( - parent.inputs[parent_input]?.type === node.inputs[i].type - ) { - // @ts-expect-error convert string to number - link = parent.getInputLink(parent_input) - if (link) { - // @ts-expect-error convert string to number - parent = parent.getInputNode(parent_input) - } - found = true - break - } - } - } - } - - if (!found) { - break - } - } - - if (link) { - if (parent?.updateLink) { - link = parent.updateLink(link) - } - if (link) { - inputs[node.inputs[i].name] = [ - String(link.origin_id), - // @ts-expect-error link.origin_slot is already number. - parseInt(link.origin_slot) - ] - } - } - } - } - - const node_data = { - inputs, - class_type: node.comfyClass - } - - // Ignored by the backend. - node_data['_meta'] = { - title: node.title - } - - output[String(node.id)] = node_data - } - } - - // Remove inputs connected to removed nodes - if (clean) { - for (const o in output) { - for (const i in output[o].inputs) { - if ( - Array.isArray(output[o].inputs[i]) && - output[o].inputs[i].length === 2 && - !output[output[o].inputs[i][0]] - ) { - delete output[o].inputs[i] - } - } - } - } - - return { workflow, output } + return graphToPrompt(graph, { + clean, + sortNodes: useSettingStore().get('Comfy.Workflow.SortNodeIdOnSave') + }) } #formatPromptError(error) { @@ -1444,7 +1280,6 @@ export class ComfyApp { const p = await this.graphToPrompt() try { - // @ts-expect-error Discrepancies between zod and litegraph - in progress const res = await api.queuePrompt(number, p) this.lastNodeErrors = res.node_errors if (this.lastNodeErrors.length > 0) { diff --git a/src/types/comfyWorkflow.ts b/src/types/comfyWorkflow.ts index 7598438a12..492ae68910 100644 --- a/src/types/comfyWorkflow.ts +++ b/src/types/comfyWorkflow.ts @@ -5,6 +5,7 @@ import { fromZodError } from 'zod-validation-error' // innerNode.id = `${this.node.id}:${i}` // Remove it after GroupNode is redesigned. export const zNodeId = z.union([z.number().int(), z.string()]) +export const zNodeInputName = z.string() export type NodeId = z.infer export const zSlotIndex = z.union([ z.number().int(), @@ -96,7 +97,7 @@ const zNodeOutput = z const zNodeInput = z .object({ - name: z.string(), + name: zNodeInputName, type: zDataType, link: z.number().nullable().optional(), slot_index: zSlotIndex.optional() @@ -251,3 +252,24 @@ export async function validateComfyWorkflow( onError(`Invalid workflow against zod schema:\n${error}`) return null } + +/** + * API format workflow for direct API usage. + */ +const zNodeInputValue = z.union([ + // For widget values (can be any type) + z.any(), + // For node links [nodeId, slotIndex] + z.tuple([zNodeId, zSlotIndex]) +]) + +const zNodeData = z.object({ + inputs: z.record(zNodeInputName, zNodeInputValue), + class_type: z.string(), + _meta: z.object({ + title: z.string() + }) +}) + +export const zComfyApiWorkflow = z.record(zNodeId, zNodeData) +export type ComfyApiWorkflow = z.infer diff --git a/src/utils/executionUtil.ts b/src/utils/executionUtil.ts new file mode 100644 index 0000000000..4a990bff63 --- /dev/null +++ b/src/utils/executionUtil.ts @@ -0,0 +1,184 @@ +import type { LGraph } from '@comfyorg/litegraph' +import { LGraphEventMode } from '@comfyorg/litegraph' + +import type { ComfyApiWorkflow, ComfyWorkflowJSON } from '@/types/comfyWorkflow' + +/** + * Converts the current graph workflow for sending to the API. + * Note: Node widgets are updated before serialization to prepare queueing. + * @returns The workflow and node links + */ +export const graphToPrompt = async ( + graph: LGraph, + options: { clean?: boolean; sortNodes?: boolean } = {} +): Promise<{ workflow: ComfyWorkflowJSON; output: ComfyApiWorkflow }> => { + const { clean = true, sortNodes = false } = options + + for (const outerNode of graph.computeExecutionOrder(false)) { + if (outerNode.widgets) { + for (const widget of outerNode.widgets) { + // Allow widgets to run callbacks before a prompt has been queued + // e.g. random seed before every gen + widget.beforeQueued?.() + } + } + + const innerNodes = outerNode.getInnerNodes + ? outerNode.getInnerNodes() + : [outerNode] + for (const node of innerNodes) { + if (node.isVirtualNode) { + // Don't serialize frontend only nodes but let them make changes + if (node.applyToGraph) { + node.applyToGraph() + } + } + } + } + + const workflow = graph.serialize({ sortNodes }) + + // Remove localized_name from the workflow + for (const node of workflow.nodes) { + for (const slot of node.inputs ?? []) { + delete slot.localized_name + } + for (const slot of node.outputs ?? []) { + delete slot.localized_name + } + } + + const output: ComfyApiWorkflow = {} + // Process nodes in order of execution + for (const outerNode of graph.computeExecutionOrder(false)) { + const skipNode = + outerNode.mode === LGraphEventMode.NEVER || + outerNode.mode === LGraphEventMode.BYPASS + const innerNodes = + !skipNode && outerNode.getInnerNodes + ? outerNode.getInnerNodes() + : [outerNode] + for (const node of innerNodes) { + if (node.isVirtualNode) { + continue + } + + if ( + node.mode === LGraphEventMode.NEVER || + node.mode === LGraphEventMode.BYPASS + ) { + // Don't serialize muted nodes + continue + } + + const inputs: ComfyApiWorkflow[string]['inputs'] = {} + const widgets = node.widgets + + // Store all widget values + if (widgets) { + for (let i = 0; i < widgets.length; i++) { + const widget = widgets[i] + if ( + widget.name && + (!widget.options || widget.options.serialize !== false) + ) { + inputs[widget.name] = widget.serializeValue + ? await widget.serializeValue(node, i) + : widget.value + } + } + } + + // Store all node links + for (let i = 0; i < node.inputs.length; i++) { + let parent = node.getInputNode(i) + if (parent) { + let link = node.getInputLink(i) + while ( + parent.mode === LGraphEventMode.BYPASS || + parent.isVirtualNode + ) { + let found = false + if (parent.isVirtualNode) { + link = link ? parent.getInputLink(link.origin_slot) : null + if (link) { + parent = parent.getInputNode(link.target_slot) + if (parent) { + found = true + } + } + } else if (link && parent.mode === LGraphEventMode.BYPASS) { + let all_inputs = [link.origin_slot] + if (parent.inputs) { + // @ts-expect-error convert list of strings to list of numbers + all_inputs = all_inputs.concat(Object.keys(parent.inputs)) + for (let parent_input in all_inputs) { + // @ts-expect-error assign string to number + parent_input = all_inputs[parent_input] + if ( + parent.inputs[parent_input]?.type === node.inputs[i].type + ) { + // @ts-expect-error convert string to number + link = parent.getInputLink(parent_input) + if (link) { + // @ts-expect-error convert string to number + parent = parent.getInputNode(parent_input) + } + found = true + break + } + } + } + } + + if (!found) { + break + } + } + + if (link) { + if (parent?.updateLink) { + link = parent.updateLink(link) + } + if (link) { + inputs[node.inputs[i].name] = [ + String(link.origin_id), + // @ts-expect-error link.origin_slot is already number. + parseInt(link.origin_slot) + ] + } + } + } + } + + output[String(node.id)] = { + inputs, + // TODO(huchenlei): Filter out all nodes that cannot be mapped to a + // comfyClass. + class_type: node.comfyClass!, + // Ignored by the backend. + _meta: { + title: node.title + } + } + } + } + + // Remove inputs connected to removed nodes + if (clean) { + for (const o in output) { + for (const i in output[o].inputs) { + if ( + Array.isArray(output[o].inputs[i]) && + output[o].inputs[i].length === 2 && + !output[output[o].inputs[i][0]] + ) { + delete output[o].inputs[i] + } + } + } + } + + // @ts-expect-error Convert ISerializedGraph to ComfyWorkflowJSON + return { workflow: workflow as ComfyWorkflowJSON, output } +}