From a489c19b079cc212e41b39000042bdf4cb99605a Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Tue, 8 Apr 2025 18:32:43 -0400 Subject: [PATCH] Upstream rgthree's link fixer (#3350) --- browser_tests/assets/bad_link.json | 126 ++++++ browser_tests/fixtures/ComfyPage.ts | 2 +- browser_tests/tests/graph.spec.ts | 5 + src/composables/useWorkflowValidation.ts | 95 +++++ src/scripts/app.ts | 14 +- src/utils/linkFixer.ts | 477 +++++++++++++++++++++++ 6 files changed, 709 insertions(+), 10 deletions(-) create mode 100644 browser_tests/assets/bad_link.json create mode 100644 src/composables/useWorkflowValidation.ts create mode 100644 src/utils/linkFixer.ts diff --git a/browser_tests/assets/bad_link.json b/browser_tests/assets/bad_link.json new file mode 100644 index 000000000..c2806bbbb --- /dev/null +++ b/browser_tests/assets/bad_link.json @@ -0,0 +1,126 @@ +{ + "id": "51b9b184-770d-40ac-a478-8cc31667ff23", + "revision": 0, + "last_node_id": 5, + "last_link_id": 3, + "nodes": [ + { + "id": 4, + "type": "KSampler", + "pos": [ + 867.4669799804688, + 347.22369384765625 + ], + "size": [ + 315, + 262 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [ + { + "name": "model", + "type": "MODEL", + "link": null + }, + { + "name": "positive", + "type": "CONDITIONING", + "link": null + }, + { + "name": "negative", + "type": "CONDITIONING", + "link": null + }, + { + "name": "latent_image", + "type": "LATENT", + "link": null + }, + { + "name": "steps", + "type": "INT", + "widget": { + "name": "steps" + }, + "link": 3 + } + ], + "outputs": [ + { + "name": "LATENT", + "type": "LATENT", + "links": null + } + ], + "properties": { + "Node name for S&R": "KSampler" + }, + "widgets_values": [ + 0, + "randomize", + 20, + 8, + "euler", + "normal", + 1 + ] + }, + { + "id": 5, + "type": "PrimitiveInt", + "pos": [ + 443.0852355957031, + 441.131591796875 + ], + "size": [ + 315, + 82 + ], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "INT", + "type": "INT", + "links": [ + 3 + ] + } + ], + "properties": { + "Node name for S&R": "PrimitiveInt" + }, + "widgets_values": [ + 0, + "randomize" + ] + } + ], + "links": [ + [ + 3, + 5, + 0, + 4, + 5, + "INT" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 1.9487171000000016, + "offset": [ + -325.57196748514497, + -168.13150517966463 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/browser_tests/fixtures/ComfyPage.ts b/browser_tests/fixtures/ComfyPage.ts index e0e8af7a6..b61a55fd2 100644 --- a/browser_tests/fixtures/ComfyPage.ts +++ b/browser_tests/fixtures/ComfyPage.ts @@ -412,7 +412,7 @@ export class ComfyPage { } async getVisibleToastCount() { - return await this.page.locator('.p-toast:visible').count() + return await this.page.locator('.p-toast-message:visible').count() } async clickTextEncodeNode1() { diff --git a/browser_tests/tests/graph.spec.ts b/browser_tests/tests/graph.spec.ts index 7ba290ce8..a209ee5ba 100644 --- a/browser_tests/tests/graph.spec.ts +++ b/browser_tests/tests/graph.spec.ts @@ -13,4 +13,9 @@ test.describe('Graph', () => { }) ).toBe(1) }) + + test('Validate workflow links', async ({ comfyPage }) => { + await comfyPage.loadWorkflow('bad_link') + await expect(comfyPage.getVisibleToastCount()).resolves.toBe(2) + }) }) diff --git a/src/composables/useWorkflowValidation.ts b/src/composables/useWorkflowValidation.ts new file mode 100644 index 000000000..f16b18139 --- /dev/null +++ b/src/composables/useWorkflowValidation.ts @@ -0,0 +1,95 @@ +import type { ISerialisedGraph } from '@comfyorg/litegraph/dist/types/serialisation' + +import type { ComfyWorkflowJSON } from '@/schemas/comfyWorkflowSchema' +import { validateComfyWorkflow } from '@/schemas/comfyWorkflowSchema' +import { useToastStore } from '@/stores/toastStore' +import { fixBadLinks } from '@/utils/linkFixer' + +export interface ValidationResult { + graphData: ComfyWorkflowJSON | null + linksFixes?: { + patched: number + deleted: number + } +} + +export function useWorkflowValidation() { + const toastStore = useToastStore() + + /** + * Validates a workflow, including link validation and schema validation + */ + async function validateWorkflow( + graphData: ComfyWorkflowJSON, + options: { + silent?: boolean + } = {} + ): Promise { + const { silent = false } = options + + let linksFixes + let validatedData: ComfyWorkflowJSON | null = null + + // First do schema validation + const validatedGraphData = await validateComfyWorkflow( + graphData, + /* onError=*/ (err) => { + if (!silent) { + toastStore.addAlert(err) + } + } + ) + + if (validatedGraphData) { + // Collect all logs in an array + const logs: string[] = [] + // Then validate and fix links if schema validation passed + const linkValidation = fixBadLinks( + validatedGraphData as unknown as ISerialisedGraph, + { + fix: true, + silent, + logger: { + log: (message: string) => { + logs.push(message) + } + } + } + ) + + if (!silent && logs.length > 0) { + toastStore.add({ + severity: 'warn', + summary: 'Workflow Validation', + detail: logs.join('\n') + }) + } + + // If links were fixed, notify the user + if (linkValidation.fixed) { + if (!silent) { + toastStore.add({ + severity: 'info', + summary: 'Workflow Links Fixed', + detail: `Fixed ${linkValidation.patched} node connections and removed ${linkValidation.deleted} invalid links.` + }) + } + + validatedData = linkValidation.graph as unknown as ComfyWorkflowJSON + linksFixes = { + patched: linkValidation.patched, + deleted: linkValidation.deleted + } + } + } + + return { + graphData: validatedData, + linksFixes + } + } + + return { + validateWorkflow + } +} diff --git a/src/scripts/app.ts b/src/scripts/app.ts index 5d1e032bb..bd92274e3 100644 --- a/src/scripts/app.ts +++ b/src/scripts/app.ts @@ -11,6 +11,7 @@ import type { ToastMessageOptions } from 'primevue/toast' import { reactive } from 'vue' import { useCanvasPositionConversion } from '@/composables/element/useCanvasPositionConversion' +import { useWorkflowValidation } from '@/composables/useWorkflowValidation' import { st, t } from '@/i18n' import type { ExecutionErrorWsMessage, @@ -21,8 +22,7 @@ import { ComfyApiWorkflow, type ComfyWorkflowJSON, type ModelFile, - type NodeId, - validateComfyWorkflow + type NodeId } from '@/schemas/comfyWorkflowSchema' import type { ComfyNodeDef as ComfyNodeDefV1 } from '@/schemas/nodeDefSchema' import { getFromWebmFile } from '@/scripts/metadata/ebml' @@ -981,13 +981,9 @@ export class ComfyApp { graphData = clone(graphData) if (useSettingStore().get('Comfy.Validation.Workflows')) { - // TODO: Show validation error in a dialog. - const validatedGraphData = await validateComfyWorkflow( - graphData, - /* onError=*/ (err) => { - useToastStore().addAlert(err) - } - ) + const { graphData: validatedGraphData } = + await useWorkflowValidation().validateWorkflow(graphData) + // If the validation failed, use the original graph data. // Ideally we should not block users from loading the workflow. graphData = validatedGraphData ?? graphData diff --git a/src/utils/linkFixer.ts b/src/utils/linkFixer.ts new file mode 100644 index 000000000..10dcec67e --- /dev/null +++ b/src/utils/linkFixer.ts @@ -0,0 +1,477 @@ +/** + * This code is adapted from rgthree-comfy's link_fixer.ts + * @see https://github.com/rgthree/rgthree-comfy/blob/b84f39c7c224de765de0b54c55b967329011819d/src_web/common/link_fixer.ts + * + * MIT License + * + * Copyright (c) 2023 Regis Gaughan, III (rgthree) + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +import type { LGraph, LGraphNode, LLink } from '@comfyorg/litegraph' +import type { NodeId } from '@comfyorg/litegraph/dist/LGraphNode' +import type { SerialisedLLinkArray } from '@comfyorg/litegraph/dist/LLink' +import type { + ISerialisedGraph, + ISerialisedNode +} from '@comfyorg/litegraph/dist/types/serialisation' + +export interface BadLinksData { + hasBadLinks: boolean + fixed: boolean + graph: T + patched: number + deleted: number +} + +enum IoDirection { + INPUT, + OUTPUT +} + +function getNodeById(graph: ISerialisedGraph | LGraph, id: NodeId) { + if ((graph as LGraph).getNodeById) { + return (graph as LGraph).getNodeById(id) + } + graph = graph as ISerialisedGraph + return graph.nodes.find((node: ISerialisedNode) => node.id == id)! +} + +function extendLink(link: SerialisedLLinkArray) { + return { + link: link, + id: link[0], + origin_id: link[1], + origin_slot: link[2], + target_id: link[3], + target_slot: link[4], + type: link[5] + } +} + +/** + * Takes a ISerialisedGraph or live LGraph and inspects the links and nodes to ensure the linking + * makes logical sense. Can apply fixes when passed the `fix` argument as true. + * + * Note that fixes are a best-effort attempt. Seems to get it correct in most cases, but there is a + * chance it correct an anomoly that results in placing an incorrect link (say, if there were two + * links in the data). Users should take care to not overwrite work until manually checking the + * result. + */ +export function fixBadLinks( + graph: ISerialisedGraph | LGraph, + options: { + fix?: boolean + silent?: boolean + logger?: { log: (...args: any[]) => void } + } = {} +): BadLinksData { + const { fix = false, silent = false, logger: _logger = console } = options + const logger = { + log: (...args: any[]) => { + if (!silent) { + _logger.log(...args) + } + } + } + + const patchedNodeSlots: { + [nodeId: string]: { + inputs?: { [slot: number]: number | null } + outputs?: { + [slots: number]: { + links: number[] + changes: { [linkId: number]: 'ADD' | 'REMOVE' } + } + } + } + } = {} + + const data: { + patchedNodes: Array + deletedLinks: number[] + } = { + patchedNodes: [], + deletedLinks: [] + } + + /** + * Internal patch node. We keep track of changes in patchedNodeSlots in case we're in a dry run. + */ + async function patchNodeSlot( + node: ISerialisedNode | LGraphNode, + ioDir: IoDirection, + slot: number, + linkId: number, + op: 'ADD' | 'REMOVE' + ) { + patchedNodeSlots[node.id] = patchedNodeSlots[node.id] || {} + const patchedNode = patchedNodeSlots[node.id]! + if (ioDir == IoDirection.INPUT) { + patchedNode['inputs'] = patchedNode['inputs'] || {} + // We can set to null (delete), so undefined means we haven't set it at all. + if (patchedNode['inputs']![slot] !== undefined) { + logger.log( + ` > Already set ${node.id}.inputs[${slot}] to ${patchedNode[ + 'inputs' + ]![slot]!} Skipping.` + ) + return false + } + const linkIdToSet = op === 'REMOVE' ? null : linkId + patchedNode['inputs']![slot] = linkIdToSet + if (fix) { + // node.inputs[slot]!.link = linkIdToSet; + } + } else { + patchedNode['outputs'] = patchedNode['outputs'] || {} + patchedNode['outputs']![slot] = patchedNode['outputs']![slot] || { + links: [...(node.outputs?.[slot]?.links || [])], + changes: {} + } + if (patchedNode['outputs']![slot]!['changes']![linkId] !== undefined) { + logger.log( + ` > Already set ${node.id}.outputs[${slot}] to ${ + patchedNode['inputs']![slot] + }! Skipping.` + ) + return false + } + patchedNode['outputs']![slot]!['changes']![linkId] = op + if (op === 'ADD') { + const linkIdIndex = + patchedNode['outputs']![slot]!['links'].indexOf(linkId) + if (linkIdIndex !== -1) { + logger.log( + ` > Hmmm.. asked to add ${linkId} but it is already in list...` + ) + return false + } + patchedNode['outputs']![slot]!['links'].push(linkId) + if (fix) { + node.outputs = node.outputs || [] + node.outputs[slot] = node.outputs[slot] || ({} as any) + node.outputs[slot]!.links = node.outputs[slot]!.links || [] + node.outputs[slot]!.links!.push(linkId) + } + } else { + const linkIdIndex = + patchedNode['outputs']![slot]!['links'].indexOf(linkId) + if (linkIdIndex === -1) { + logger.log( + ` > Hmmm.. asked to remove ${linkId} but it doesn't exist...` + ) + return false + } + patchedNode['outputs']![slot]!['links'].splice(linkIdIndex, 1) + if (fix) { + node.outputs?.[slot]!.links!.splice(linkIdIndex, 1) + } + } + } + data.patchedNodes.push(node) + return true + } + + /** + * Internal to check if a node (or patched data) has a linkId. + */ + function nodeHasLinkId( + node: ISerialisedNode | LGraphNode, + ioDir: IoDirection, + slot: number, + linkId: number + ) { + // Patched data should be canonical. We can double check if fixing too. + let has = false + if (ioDir === IoDirection.INPUT) { + const nodeHasIt = node.inputs?.[slot]?.link === linkId + if (patchedNodeSlots[node.id]?.['inputs']) { + const patchedHasIt = + patchedNodeSlots[node.id]!['inputs']![slot] === linkId + // If we're fixing, double check that node matches. + if (fix && nodeHasIt !== patchedHasIt) { + throw Error('Error. Expected node to match patched data.') + } + has = patchedHasIt + } else { + has = !!nodeHasIt + } + } else { + const nodeHasIt = node.outputs?.[slot]?.links?.includes(linkId) + if (patchedNodeSlots[node.id]?.['outputs']?.[slot]?.['changes'][linkId]) { + const patchedHasIt = + patchedNodeSlots[node.id]!['outputs']![slot]?.links.includes(linkId) + // If we're fixing, double check that node matches. + if (fix && nodeHasIt !== patchedHasIt) { + throw Error('Error. Expected node to match patched data.') + } + has = !!patchedHasIt + } else { + has = !!nodeHasIt + } + } + return has + } + + /** + * Internal to check if a node (or patched data) has a linkId. + */ + function nodeHasAnyLink( + node: ISerialisedNode | LGraphNode, + ioDir: IoDirection, + slot: number + ) { + // Patched data should be canonical. We can double check if fixing too. + let hasAny = false + if (ioDir === IoDirection.INPUT) { + const nodeHasAny = node.inputs?.[slot]?.link != null + if (patchedNodeSlots[node.id]?.['inputs']) { + const patchedHasAny = + patchedNodeSlots[node.id]!['inputs']![slot] != null + // If we're fixing, double check that node matches. + if (fix && nodeHasAny !== patchedHasAny) { + throw Error('Error. Expected node to match patched data.') + } + hasAny = patchedHasAny + } else { + hasAny = !!nodeHasAny + } + } else { + const nodeHasAny = node.outputs?.[slot]?.links?.length + if (patchedNodeSlots[node.id]?.['outputs']?.[slot]?.['changes']) { + const patchedHasAny = + patchedNodeSlots[node.id]!['outputs']![slot]?.links.length + // If we're fixing, double check that node matches. + if (fix && nodeHasAny !== patchedHasAny) { + throw Error('Error. Expected node to match patched data.') + } + hasAny = !!patchedHasAny + } else { + hasAny = !!nodeHasAny + } + } + return hasAny + } + + let links: Array = [] + if (!Array.isArray(graph.links)) { + links = Object.values(graph.links).reduce((acc, v) => { + acc[v.id] = v + return acc + }, links) + } else { + links = graph.links + } + + const linksReverse = [...links] + linksReverse.reverse() + for (const l of linksReverse) { + if (!l) continue + const link = + (l as LLink).origin_slot != null + ? (l as LLink) + : extendLink(l as SerialisedLLinkArray) + + const originNode = getNodeById(graph, link.origin_id) + const originHasLink = () => + nodeHasLinkId(originNode!, IoDirection.OUTPUT, link.origin_slot, link.id) + const patchOrigin = (op: 'ADD' | 'REMOVE', id = link.id) => + patchNodeSlot(originNode!, IoDirection.OUTPUT, link.origin_slot, id, op) + + const targetNode = getNodeById(graph, link.target_id) + const targetHasLink = () => + nodeHasLinkId(targetNode!, IoDirection.INPUT, link.target_slot, link.id) + const targetHasAnyLink = () => + nodeHasAnyLink(targetNode!, IoDirection.INPUT, link.target_slot) + const patchTarget = (op: 'ADD' | 'REMOVE', id = link.id) => + patchNodeSlot(targetNode!, IoDirection.INPUT, link.target_slot, id, op) + + const originLog = `origin(${link.origin_id}).outputs[${link.origin_slot}].links` + const targetLog = `target(${link.target_id}).inputs[${link.target_slot}].link` + + if (!originNode || !targetNode) { + if (!originNode && !targetNode) { + logger.log( + `Link ${link.id} is invalid, ` + + `both origin ${link.origin_id} and target ${link.target_id} do not exist` + ) + } else if (!originNode) { + logger.log( + `Link ${link.id} is funky... ` + + `origin ${link.origin_id} does not exist, but target ${link.target_id} does.` + ) + if (targetHasLink()) { + logger.log( + ` > [PATCH] ${targetLog} does have link, will remove the inputs' link first.` + ) + patchTarget('REMOVE', -1) + } + } else if (!targetNode) { + logger.log( + `Link ${link.id} is funky... ` + + `target ${link.target_id} does not exist, but origin ${link.origin_id} does.` + ) + if (originHasLink()) { + logger.log( + ` > [PATCH] Origin's links' has ${link.id}; will remove the link first.` + ) + patchOrigin('REMOVE') + } + } + continue + } + + if (targetHasLink() || originHasLink()) { + if (!originHasLink()) { + logger.log( + `${link.id} is funky... ${originLog} does NOT contain it, but ${targetLog} does.` + ) + + logger.log( + ` > [PATCH] Attempt a fix by adding this ${link.id} to ${originLog}.` + ) + patchOrigin('ADD') + } else if (!targetHasLink()) { + logger.log( + `${link.id} is funky... ${targetLog} is NOT correct (is ${ + targetNode.inputs?.[link.target_slot]?.link + }), but ${originLog} contains it` + ) + if (!targetHasAnyLink()) { + logger.log( + ` > [PATCH] ${targetLog} is not defined, will set to ${link.id}.` + ) + let patched = patchTarget('ADD') + if (!patched) { + logger.log( + ` > [PATCH] Nvm, ${targetLog} already patched. Removing ${link.id} from ${originLog}.` + ) + patched = patchOrigin('REMOVE') + } + } else { + logger.log( + ` > [PATCH] ${targetLog} is defined, removing ${link.id} from ${originLog}.` + ) + patchOrigin('REMOVE') + } + } + } + } + + // Now that we've cleaned up the inputs, outputs, run through it looking for dangling links., + for (const l of linksReverse) { + if (!l) continue + const link = + (l as LLink).origin_slot != null + ? (l as LLink) + : extendLink(l as SerialisedLLinkArray) + const originNode = getNodeById(graph, link.origin_id) + const targetNode = getNodeById(graph, link.target_id) + // Now that we've manipulated the linking, check again if they both exist. + if ( + (!originNode || + !nodeHasLinkId( + originNode, + IoDirection.OUTPUT, + link.origin_slot, + link.id + )) && + (!targetNode || + !nodeHasLinkId( + targetNode, + IoDirection.INPUT, + link.target_slot, + link.id + )) + ) { + logger.log( + `${link.id} is def invalid; BOTH origin node ${link.origin_id} ${ + !originNode ? 'is removed' : `doesn't have ${link.id}` + } and ${link.origin_id} target node ${ + !targetNode ? 'is removed' : `doesn't have ${link.id}` + }.` + ) + data.deletedLinks.push(link.id) + continue + } + } + + // If we're fixing, then we've been patching along the way. Now go through and actually delete + // the zombie links from `app.graph.links` + if (fix) { + for (let i = data.deletedLinks.length - 1; i >= 0; i--) { + logger.log(`Deleting link #${data.deletedLinks[i]}.`) + if ((graph as LGraph).getNodeById) { + delete graph.links[data.deletedLinks[i]!] + } else { + graph = graph as ISerialisedGraph + // Sometimes we got objects for links if passed after ComfyUI's loadGraphData modifies the + // data. We make a copy now, but can handle the bastardized objects just in case. + const idx = graph.links.findIndex( + (l) => + l && + (l[0] === data.deletedLinks[i] || + (l as any).id === data.deletedLinks[i]) + ) + if (idx === -1) { + logger.log(`INDEX NOT FOUND for #${data.deletedLinks[i]}`) + } + logger.log(`splicing ${idx} from links`) + graph.links.splice(idx, 1) + } + } + // If we're a serialized graph, we can filter out the links because it's just an array. + if (!(graph as LGraph).getNodeById) { + graph.links = (graph as ISerialisedGraph).links.filter((l) => !!l) + } + } + if (!data.patchedNodes.length && !data.deletedLinks.length) { + return { + hasBadLinks: false, + fixed: false, + graph, + patched: data.patchedNodes.length, + deleted: data.deletedLinks.length + } + } + + logger.log( + `${fix ? 'Made' : 'Would make'} ${data.patchedNodes.length || 'no'} node link patches, and ${ + data.deletedLinks.length || 'no' + } stale link removals.` + ) + + let hasBadLinks: boolean = !!( + data.patchedNodes.length || data.deletedLinks.length + ) + // If we're fixing, then let's run it again to see if there are no more bad links. + if (fix && !silent) { + const rerun = fixBadLinks(graph, { fix: false, silent: true }) + hasBadLinks = rerun.hasBadLinks + } + + return { + hasBadLinks, + fixed: !!hasBadLinks && fix, + graph, + patched: data.patchedNodes.length, + deleted: data.deletedLinks.length + } +}