graph state store impl

This commit is contained in:
bymyself
2025-12-13 03:32:48 -08:00
parent 7613e70f63
commit 076acf1b31
5 changed files with 313 additions and 38 deletions

View File

@@ -0,0 +1,130 @@
import { createPinia, setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it } from 'vitest'
import { useGraphStateStore } from './graphStateStore'
describe('graphStateStore', () => {
beforeEach(() => {
setActivePinia(createPinia())
})
describe('execute SetNodeError command', () => {
it('sets hasError on new node', () => {
const store = useGraphStateStore()
store.execute({
type: 'SetNodeError',
version: 1,
nodeId: '123',
hasError: true
})
expect(store.getNodeState('123')?.hasError).toBe(true)
})
it('updates hasError on existing node', () => {
const store = useGraphStateStore()
store.execute({
type: 'SetNodeError',
version: 1,
nodeId: '123',
hasError: true
})
store.execute({
type: 'SetNodeError',
version: 1,
nodeId: '123',
hasError: false
})
expect(store.getNodeState('123')?.hasError).toBe(false)
})
it('handles subgraph node locator IDs', () => {
const store = useGraphStateStore()
store.execute({
type: 'SetNodeError',
version: 1,
nodeId: 'uuid-123:456',
hasError: true
})
expect(store.getNodeState('uuid-123:456')?.hasError).toBe(true)
})
})
describe('execute ClearAllErrors command', () => {
it('clears all error flags', () => {
const store = useGraphStateStore()
store.execute({
type: 'SetNodeError',
version: 1,
nodeId: '1',
hasError: true
})
store.execute({
type: 'SetNodeError',
version: 1,
nodeId: '2',
hasError: true
})
store.execute({ type: 'ClearAllErrors', version: 1 })
expect(store.getNodeState('1')?.hasError).toBe(false)
expect(store.getNodeState('2')?.hasError).toBe(false)
})
})
describe('getNodesWithErrors', () => {
it('returns only nodes with errors', () => {
const store = useGraphStateStore()
store.execute({
type: 'SetNodeError',
version: 1,
nodeId: '1',
hasError: true
})
store.execute({
type: 'SetNodeError',
version: 1,
nodeId: '2',
hasError: false
})
store.execute({
type: 'SetNodeError',
version: 1,
nodeId: '3',
hasError: true
})
const nodesWithErrors = store.getNodesWithErrors()
expect(nodesWithErrors).toHaveLength(2)
expect(nodesWithErrors).toContain('1')
expect(nodesWithErrors).toContain('3')
expect(nodesWithErrors).not.toContain('2')
})
})
describe('stateRef reactivity', () => {
it('increments revision on command execution', () => {
const store = useGraphStateStore()
const initialRevision = store.stateRef
store.execute({
type: 'SetNodeError',
version: 1,
nodeId: '1',
hasError: true
})
expect(store.stateRef).not.toBe(initialRevision)
})
})
})

View File

@@ -0,0 +1,79 @@
import { defineStore } from 'pinia'
import { customRef } from 'vue'
import type { NodeLocatorId } from '@/types/nodeIdentification'
export interface NodeState {
hasError: boolean
}
export interface SetNodeErrorCommand {
type: 'SetNodeError'
version: 1
nodeId: NodeLocatorId
hasError: boolean
}
export interface ClearAllErrorsCommand {
type: 'ClearAllErrors'
version: 1
}
export type GraphStateCommand = SetNodeErrorCommand | ClearAllErrorsCommand
export const useGraphStateStore = defineStore('graphState', () => {
const nodes = new Map<NodeLocatorId, NodeState>()
let revision = 0
const stateRef = customRef<number>((track, trigger) => ({
get() {
track()
return revision
},
set() {
revision++
trigger()
}
}))
const execute = (command: GraphStateCommand): void => {
switch (command.type) {
case 'SetNodeError': {
const existing = nodes.get(command.nodeId)
if (existing) {
existing.hasError = command.hasError
} else {
nodes.set(command.nodeId, { hasError: command.hasError })
}
break
}
case 'ClearAllErrors': {
for (const state of nodes.values()) {
state.hasError = false
}
break
}
}
stateRef.value = revision + 1
}
const getNodeState = (nodeId: NodeLocatorId): NodeState | undefined => {
return nodes.get(nodeId)
}
const getNodesWithErrors = (): NodeLocatorId[] => {
const result: NodeLocatorId[] = []
for (const [nodeId, state] of nodes) {
if (state.hasError) result.push(nodeId)
}
return result
}
return {
stateRef,
nodes,
execute,
getNodeState,
getNodesWithErrors
}
})

View File

@@ -0,0 +1,45 @@
import { watch } from 'vue'
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
import { app } from '@/scripts/app'
import {
forEachNode,
forEachSubgraphNode,
getNodeByLocatorId
} from '@/utils/graphTraversalUtil'
import { useGraphStateStore } from './graphStateStore'
const propagateErrorToParents = (node: LGraphNode): void => {
const subgraph = node.graph
if (!subgraph || subgraph.isRootGraph) return
forEachSubgraphNode(app.rootGraph, subgraph.id, (subgraphNode) => {
subgraphNode.has_errors = true
propagateErrorToParents(subgraphNode)
})
}
export const useGraphErrorState = () => {
const store = useGraphStateStore()
watch(
() => store.stateRef,
() => {
if (!app.rootGraph) return
forEachNode(app.rootGraph, (node) => {
node.has_errors = false
})
for (const locatorId of store.getNodesWithErrors()) {
const node = getNodeByLocatorId(app.rootGraph, locatorId)
if (!node) continue
node.has_errors = true
propagateErrorToParents(node)
}
},
{ immediate: true }
)
}

View File

@@ -47,12 +47,14 @@ import { useSubscription } from '@/platform/cloud/subscription/composables/useSu
import { useExtensionService } from '@/services/extensionService'
import { useLitegraphService } from '@/services/litegraphService'
import { useSubgraphService } from '@/services/subgraphService'
import { useGraphErrorState } from '@/core/graph/state/useGraphErrorState'
import { useApiKeyAuthStore } from '@/stores/apiKeyAuthStore'
import { useCommandStore } from '@/stores/commandStore'
import { useDomWidgetStore } from '@/stores/domWidgetStore'
import { useExecutionStore } from '@/stores/executionStore'
import { useExtensionStore } from '@/stores/extensionStore'
import { useFirebaseAuthStore } from '@/stores/firebaseAuthStore'
import { useGraphStateStore } from '@/core/graph/state/graphStateStore'
import { useNodeOutputStore } from '@/stores/imagePreviewStore'
import { KeyComboImpl, useKeybindingStore } from '@/stores/keybindingStore'
import { useModelStore } from '@/stores/modelStore'
@@ -78,6 +80,7 @@ import {
findLegacyRerouteNodes,
noNativeReroutes
} from '@/utils/migration/migrateReroute'
import { collectMissingNodes } from '@/workbench/extensions/manager/utils/graphHasMissingNodes'
import { getSelectedModelsMetadata } from '@/workbench/utils/modelMetadataUtil'
import { deserialiseAndCreate } from '@/utils/vintageClipboard'
@@ -764,6 +767,8 @@ export class ComfyApp {
void useSubgraphStore().fetchSubgraphs()
await useExtensionService().loadExtensions()
useGraphErrorState()
this.addProcessKeyHandler()
this.addConfigureHandler()
this.addApiUpdateHandlers()
@@ -1231,6 +1236,23 @@ export class ComfyApp {
})
}
}
const graphStateStore = useGraphStateStore()
const missingNodes = collectMissingNodes(
this.rootGraph,
useNodeDefStore().nodeDefsByName
)
for (const node of missingNodes) {
const locatorId = node.graph?.isRootGraph
? String(node.id)
: `${node.graph?.id}:${node.id}`
graphStateStore.execute({
type: 'SetNodeError',
version: 1,
nodeId: locatorId,
hasError: true
})
}
} catch (error) {
useDialogService().showErrorDialog(error, {
title: t('errorDialog.loadWorkflowTitle'),

View File

@@ -32,6 +32,7 @@ import { app } from '@/scripts/app'
import { useNodeOutputStore } from '@/stores/imagePreviewStore'
import type { NodeLocatorId } from '@/types/nodeIdentification'
import { createNodeLocatorId } from '@/types/nodeIdentification'
import { useGraphStateStore } from '@/core/graph/state/graphStateStore'
import { forEachNode, getNodeByExecutionId } from '@/utils/graphTraversalUtil'
interface QueuedPrompt {
@@ -574,56 +575,54 @@ export const useExecutionStore = defineStore('execution', () => {
}
/**
* Update node and slot error flags when validation errors change.
* Propagates errors up subgraph chains.
* Push execution errors to graphStateStore and handle slot errors.
*/
watch(lastNodeErrors, () => {
if (!app.rootGraph) return
const graphStateStore = useGraphStateStore()
// Clear all error flags
forEachNode(app.rootGraph, (node) => {
node.has_errors = false
if (node.inputs) {
for (const slot of node.inputs) {
slot.hasErrors = false
// Clear slot errors
if (app.rootGraph) {
forEachNode(app.rootGraph, (node) => {
if (node.inputs) {
for (const slot of node.inputs) {
slot.hasErrors = false
}
}
}
})
})
}
// Clear previous execution errors
graphStateStore.execute({ type: 'ClearAllErrors', version: 1 })
if (!lastNodeErrors.value) return
// Set error flags on nodes and slots
// Push execution errors to graphStateStore
for (const [executionId, nodeError] of Object.entries(
lastNodeErrors.value
)) {
const node = getNodeByExecutionId(app.rootGraph, executionId)
if (!node) continue
node.has_errors = true
// Mark input slots with errors
if (node.inputs) {
for (const error of nodeError.errors) {
const slotName = error.extra_info?.input_name
if (!slotName) continue
const slot = node.inputs.find((s) => s.name === slotName)
if (slot) {
slot.hasErrors = true
}
}
const locatorId = executionIdToNodeLocatorId(executionId)
if (locatorId) {
graphStateStore.execute({
type: 'SetNodeError',
version: 1,
nodeId: locatorId,
hasError: true
})
}
// Propagate errors to parent subgraph nodes
const parts = executionId.split(':')
for (let i = parts.length - 1; i > 0; i--) {
const parentExecutionId = parts.slice(0, i).join(':')
const parentNode = getNodeByExecutionId(
app.rootGraph,
parentExecutionId
)
if (parentNode) {
parentNode.has_errors = true
// Handle slot errors directly (not part of graphStateStore yet)
if (app.rootGraph) {
const node = getNodeByExecutionId(app.rootGraph, executionId)
if (node?.inputs) {
for (const error of nodeError.errors) {
const slotName = error.extra_info?.input_name
if (!slotName) continue
const slot = node.inputs.find((s) => s.name === slotName)
if (slot) {
slot.hasErrors = true
}
}
}
}
}