From 076acf1b31ee4331d5032b708cf56917fe9557ae Mon Sep 17 00:00:00 2001 From: bymyself Date: Sat, 13 Dec 2025 03:32:48 -0800 Subject: [PATCH] graph state store impl --- src/core/graph/state/graphStateStore.test.ts | 130 +++++++++++++++++++ src/core/graph/state/graphStateStore.ts | 79 +++++++++++ src/core/graph/state/useGraphErrorState.ts | 45 +++++++ src/scripts/app.ts | 22 ++++ src/stores/executionStore.ts | 75 ++++++----- 5 files changed, 313 insertions(+), 38 deletions(-) create mode 100644 src/core/graph/state/graphStateStore.test.ts create mode 100644 src/core/graph/state/graphStateStore.ts create mode 100644 src/core/graph/state/useGraphErrorState.ts diff --git a/src/core/graph/state/graphStateStore.test.ts b/src/core/graph/state/graphStateStore.test.ts new file mode 100644 index 000000000..9d0943218 --- /dev/null +++ b/src/core/graph/state/graphStateStore.test.ts @@ -0,0 +1,130 @@ +import { createPinia, setActivePinia } from 'pinia' +import { beforeEach, describe, expect, it } from 'vitest' + +import { useGraphStateStore } from './graphStateStore' + +describe('graphStateStore', () => { + beforeEach(() => { + setActivePinia(createPinia()) + }) + + describe('execute SetNodeError command', () => { + it('sets hasError on new node', () => { + const store = useGraphStateStore() + + store.execute({ + type: 'SetNodeError', + version: 1, + nodeId: '123', + hasError: true + }) + + expect(store.getNodeState('123')?.hasError).toBe(true) + }) + + it('updates hasError on existing node', () => { + const store = useGraphStateStore() + + store.execute({ + type: 'SetNodeError', + version: 1, + nodeId: '123', + hasError: true + }) + + store.execute({ + type: 'SetNodeError', + version: 1, + nodeId: '123', + hasError: false + }) + + expect(store.getNodeState('123')?.hasError).toBe(false) + }) + + it('handles subgraph node locator IDs', () => { + const store = useGraphStateStore() + + store.execute({ + type: 'SetNodeError', + version: 1, + nodeId: 'uuid-123:456', + hasError: true + }) + + expect(store.getNodeState('uuid-123:456')?.hasError).toBe(true) + }) + }) + + describe('execute ClearAllErrors command', () => { + it('clears all error flags', () => { + const store = useGraphStateStore() + + store.execute({ + type: 'SetNodeError', + version: 1, + nodeId: '1', + hasError: true + }) + store.execute({ + type: 'SetNodeError', + version: 1, + nodeId: '2', + hasError: true + }) + + store.execute({ type: 'ClearAllErrors', version: 1 }) + + expect(store.getNodeState('1')?.hasError).toBe(false) + expect(store.getNodeState('2')?.hasError).toBe(false) + }) + }) + + describe('getNodesWithErrors', () => { + it('returns only nodes with errors', () => { + const store = useGraphStateStore() + + store.execute({ + type: 'SetNodeError', + version: 1, + nodeId: '1', + hasError: true + }) + store.execute({ + type: 'SetNodeError', + version: 1, + nodeId: '2', + hasError: false + }) + store.execute({ + type: 'SetNodeError', + version: 1, + nodeId: '3', + hasError: true + }) + + const nodesWithErrors = store.getNodesWithErrors() + + expect(nodesWithErrors).toHaveLength(2) + expect(nodesWithErrors).toContain('1') + expect(nodesWithErrors).toContain('3') + expect(nodesWithErrors).not.toContain('2') + }) + }) + + describe('stateRef reactivity', () => { + it('increments revision on command execution', () => { + const store = useGraphStateStore() + const initialRevision = store.stateRef + + store.execute({ + type: 'SetNodeError', + version: 1, + nodeId: '1', + hasError: true + }) + + expect(store.stateRef).not.toBe(initialRevision) + }) + }) +}) diff --git a/src/core/graph/state/graphStateStore.ts b/src/core/graph/state/graphStateStore.ts new file mode 100644 index 000000000..4e609dbb4 --- /dev/null +++ b/src/core/graph/state/graphStateStore.ts @@ -0,0 +1,79 @@ +import { defineStore } from 'pinia' +import { customRef } from 'vue' + +import type { NodeLocatorId } from '@/types/nodeIdentification' + +export interface NodeState { + hasError: boolean +} + +export interface SetNodeErrorCommand { + type: 'SetNodeError' + version: 1 + nodeId: NodeLocatorId + hasError: boolean +} + +export interface ClearAllErrorsCommand { + type: 'ClearAllErrors' + version: 1 +} + +export type GraphStateCommand = SetNodeErrorCommand | ClearAllErrorsCommand + +export const useGraphStateStore = defineStore('graphState', () => { + const nodes = new Map() + + let revision = 0 + const stateRef = customRef((track, trigger) => ({ + get() { + track() + return revision + }, + set() { + revision++ + trigger() + } + })) + + const execute = (command: GraphStateCommand): void => { + switch (command.type) { + case 'SetNodeError': { + const existing = nodes.get(command.nodeId) + if (existing) { + existing.hasError = command.hasError + } else { + nodes.set(command.nodeId, { hasError: command.hasError }) + } + break + } + case 'ClearAllErrors': { + for (const state of nodes.values()) { + state.hasError = false + } + break + } + } + stateRef.value = revision + 1 + } + + const getNodeState = (nodeId: NodeLocatorId): NodeState | undefined => { + return nodes.get(nodeId) + } + + const getNodesWithErrors = (): NodeLocatorId[] => { + const result: NodeLocatorId[] = [] + for (const [nodeId, state] of nodes) { + if (state.hasError) result.push(nodeId) + } + return result + } + + return { + stateRef, + nodes, + execute, + getNodeState, + getNodesWithErrors + } +}) diff --git a/src/core/graph/state/useGraphErrorState.ts b/src/core/graph/state/useGraphErrorState.ts new file mode 100644 index 000000000..f3c60ab5f --- /dev/null +++ b/src/core/graph/state/useGraphErrorState.ts @@ -0,0 +1,45 @@ +import { watch } from 'vue' + +import type { LGraphNode } from '@/lib/litegraph/src/litegraph' +import { app } from '@/scripts/app' +import { + forEachNode, + forEachSubgraphNode, + getNodeByLocatorId +} from '@/utils/graphTraversalUtil' + +import { useGraphStateStore } from './graphStateStore' + +const propagateErrorToParents = (node: LGraphNode): void => { + const subgraph = node.graph + if (!subgraph || subgraph.isRootGraph) return + + forEachSubgraphNode(app.rootGraph, subgraph.id, (subgraphNode) => { + subgraphNode.has_errors = true + propagateErrorToParents(subgraphNode) + }) +} + +export const useGraphErrorState = () => { + const store = useGraphStateStore() + + watch( + () => store.stateRef, + () => { + if (!app.rootGraph) return + + forEachNode(app.rootGraph, (node) => { + node.has_errors = false + }) + + for (const locatorId of store.getNodesWithErrors()) { + const node = getNodeByLocatorId(app.rootGraph, locatorId) + if (!node) continue + + node.has_errors = true + propagateErrorToParents(node) + } + }, + { immediate: true } + ) +} diff --git a/src/scripts/app.ts b/src/scripts/app.ts index 7201fb68c..9dc0d775f 100644 --- a/src/scripts/app.ts +++ b/src/scripts/app.ts @@ -47,12 +47,14 @@ import { useSubscription } from '@/platform/cloud/subscription/composables/useSu import { useExtensionService } from '@/services/extensionService' import { useLitegraphService } from '@/services/litegraphService' import { useSubgraphService } from '@/services/subgraphService' +import { useGraphErrorState } from '@/core/graph/state/useGraphErrorState' import { useApiKeyAuthStore } from '@/stores/apiKeyAuthStore' import { useCommandStore } from '@/stores/commandStore' import { useDomWidgetStore } from '@/stores/domWidgetStore' import { useExecutionStore } from '@/stores/executionStore' import { useExtensionStore } from '@/stores/extensionStore' import { useFirebaseAuthStore } from '@/stores/firebaseAuthStore' +import { useGraphStateStore } from '@/core/graph/state/graphStateStore' import { useNodeOutputStore } from '@/stores/imagePreviewStore' import { KeyComboImpl, useKeybindingStore } from '@/stores/keybindingStore' import { useModelStore } from '@/stores/modelStore' @@ -78,6 +80,7 @@ import { findLegacyRerouteNodes, noNativeReroutes } from '@/utils/migration/migrateReroute' +import { collectMissingNodes } from '@/workbench/extensions/manager/utils/graphHasMissingNodes' import { getSelectedModelsMetadata } from '@/workbench/utils/modelMetadataUtil' import { deserialiseAndCreate } from '@/utils/vintageClipboard' @@ -764,6 +767,8 @@ export class ComfyApp { void useSubgraphStore().fetchSubgraphs() await useExtensionService().loadExtensions() + useGraphErrorState() + this.addProcessKeyHandler() this.addConfigureHandler() this.addApiUpdateHandlers() @@ -1231,6 +1236,23 @@ export class ComfyApp { }) } } + + const graphStateStore = useGraphStateStore() + const missingNodes = collectMissingNodes( + this.rootGraph, + useNodeDefStore().nodeDefsByName + ) + for (const node of missingNodes) { + const locatorId = node.graph?.isRootGraph + ? String(node.id) + : `${node.graph?.id}:${node.id}` + graphStateStore.execute({ + type: 'SetNodeError', + version: 1, + nodeId: locatorId, + hasError: true + }) + } } catch (error) { useDialogService().showErrorDialog(error, { title: t('errorDialog.loadWorkflowTitle'), diff --git a/src/stores/executionStore.ts b/src/stores/executionStore.ts index b9f89c29a..7798f837b 100644 --- a/src/stores/executionStore.ts +++ b/src/stores/executionStore.ts @@ -32,6 +32,7 @@ import { app } from '@/scripts/app' import { useNodeOutputStore } from '@/stores/imagePreviewStore' import type { NodeLocatorId } from '@/types/nodeIdentification' import { createNodeLocatorId } from '@/types/nodeIdentification' +import { useGraphStateStore } from '@/core/graph/state/graphStateStore' import { forEachNode, getNodeByExecutionId } from '@/utils/graphTraversalUtil' interface QueuedPrompt { @@ -574,56 +575,54 @@ export const useExecutionStore = defineStore('execution', () => { } /** - * Update node and slot error flags when validation errors change. - * Propagates errors up subgraph chains. + * Push execution errors to graphStateStore and handle slot errors. */ watch(lastNodeErrors, () => { - if (!app.rootGraph) return + const graphStateStore = useGraphStateStore() - // Clear all error flags - forEachNode(app.rootGraph, (node) => { - node.has_errors = false - if (node.inputs) { - for (const slot of node.inputs) { - slot.hasErrors = false + // Clear slot errors + if (app.rootGraph) { + forEachNode(app.rootGraph, (node) => { + if (node.inputs) { + for (const slot of node.inputs) { + slot.hasErrors = false + } } - } - }) + }) + } + + // Clear previous execution errors + graphStateStore.execute({ type: 'ClearAllErrors', version: 1 }) if (!lastNodeErrors.value) return - // Set error flags on nodes and slots + // Push execution errors to graphStateStore for (const [executionId, nodeError] of Object.entries( lastNodeErrors.value )) { - const node = getNodeByExecutionId(app.rootGraph, executionId) - if (!node) continue - - node.has_errors = true - - // Mark input slots with errors - if (node.inputs) { - for (const error of nodeError.errors) { - const slotName = error.extra_info?.input_name - if (!slotName) continue - - const slot = node.inputs.find((s) => s.name === slotName) - if (slot) { - slot.hasErrors = true - } - } + const locatorId = executionIdToNodeLocatorId(executionId) + if (locatorId) { + graphStateStore.execute({ + type: 'SetNodeError', + version: 1, + nodeId: locatorId, + hasError: true + }) } - // Propagate errors to parent subgraph nodes - const parts = executionId.split(':') - for (let i = parts.length - 1; i > 0; i--) { - const parentExecutionId = parts.slice(0, i).join(':') - const parentNode = getNodeByExecutionId( - app.rootGraph, - parentExecutionId - ) - if (parentNode) { - parentNode.has_errors = true + // Handle slot errors directly (not part of graphStateStore yet) + if (app.rootGraph) { + const node = getNodeByExecutionId(app.rootGraph, executionId) + if (node?.inputs) { + for (const error of nodeError.errors) { + const slotName = error.extra_info?.input_name + if (!slotName) continue + + const slot = node.inputs.find((s) => s.name === slotName) + if (slot) { + slot.hasErrors = true + } + } } } }