mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-02-04 07:00:23 +00:00
graph state store impl
This commit is contained in:
130
src/core/graph/state/graphStateStore.test.ts
Normal file
130
src/core/graph/state/graphStateStore.test.ts
Normal file
@@ -0,0 +1,130 @@
|
||||
import { createPinia, setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it } from 'vitest'
|
||||
|
||||
import { useGraphStateStore } from './graphStateStore'
|
||||
|
||||
describe('graphStateStore', () => {
|
||||
beforeEach(() => {
|
||||
setActivePinia(createPinia())
|
||||
})
|
||||
|
||||
describe('execute SetNodeError command', () => {
|
||||
it('sets hasError on new node', () => {
|
||||
const store = useGraphStateStore()
|
||||
|
||||
store.execute({
|
||||
type: 'SetNodeError',
|
||||
version: 1,
|
||||
nodeId: '123',
|
||||
hasError: true
|
||||
})
|
||||
|
||||
expect(store.getNodeState('123')?.hasError).toBe(true)
|
||||
})
|
||||
|
||||
it('updates hasError on existing node', () => {
|
||||
const store = useGraphStateStore()
|
||||
|
||||
store.execute({
|
||||
type: 'SetNodeError',
|
||||
version: 1,
|
||||
nodeId: '123',
|
||||
hasError: true
|
||||
})
|
||||
|
||||
store.execute({
|
||||
type: 'SetNodeError',
|
||||
version: 1,
|
||||
nodeId: '123',
|
||||
hasError: false
|
||||
})
|
||||
|
||||
expect(store.getNodeState('123')?.hasError).toBe(false)
|
||||
})
|
||||
|
||||
it('handles subgraph node locator IDs', () => {
|
||||
const store = useGraphStateStore()
|
||||
|
||||
store.execute({
|
||||
type: 'SetNodeError',
|
||||
version: 1,
|
||||
nodeId: 'uuid-123:456',
|
||||
hasError: true
|
||||
})
|
||||
|
||||
expect(store.getNodeState('uuid-123:456')?.hasError).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('execute ClearAllErrors command', () => {
|
||||
it('clears all error flags', () => {
|
||||
const store = useGraphStateStore()
|
||||
|
||||
store.execute({
|
||||
type: 'SetNodeError',
|
||||
version: 1,
|
||||
nodeId: '1',
|
||||
hasError: true
|
||||
})
|
||||
store.execute({
|
||||
type: 'SetNodeError',
|
||||
version: 1,
|
||||
nodeId: '2',
|
||||
hasError: true
|
||||
})
|
||||
|
||||
store.execute({ type: 'ClearAllErrors', version: 1 })
|
||||
|
||||
expect(store.getNodeState('1')?.hasError).toBe(false)
|
||||
expect(store.getNodeState('2')?.hasError).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getNodesWithErrors', () => {
|
||||
it('returns only nodes with errors', () => {
|
||||
const store = useGraphStateStore()
|
||||
|
||||
store.execute({
|
||||
type: 'SetNodeError',
|
||||
version: 1,
|
||||
nodeId: '1',
|
||||
hasError: true
|
||||
})
|
||||
store.execute({
|
||||
type: 'SetNodeError',
|
||||
version: 1,
|
||||
nodeId: '2',
|
||||
hasError: false
|
||||
})
|
||||
store.execute({
|
||||
type: 'SetNodeError',
|
||||
version: 1,
|
||||
nodeId: '3',
|
||||
hasError: true
|
||||
})
|
||||
|
||||
const nodesWithErrors = store.getNodesWithErrors()
|
||||
|
||||
expect(nodesWithErrors).toHaveLength(2)
|
||||
expect(nodesWithErrors).toContain('1')
|
||||
expect(nodesWithErrors).toContain('3')
|
||||
expect(nodesWithErrors).not.toContain('2')
|
||||
})
|
||||
})
|
||||
|
||||
describe('stateRef reactivity', () => {
|
||||
it('increments revision on command execution', () => {
|
||||
const store = useGraphStateStore()
|
||||
const initialRevision = store.stateRef
|
||||
|
||||
store.execute({
|
||||
type: 'SetNodeError',
|
||||
version: 1,
|
||||
nodeId: '1',
|
||||
hasError: true
|
||||
})
|
||||
|
||||
expect(store.stateRef).not.toBe(initialRevision)
|
||||
})
|
||||
})
|
||||
})
|
||||
79
src/core/graph/state/graphStateStore.ts
Normal file
79
src/core/graph/state/graphStateStore.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
import { defineStore } from 'pinia'
|
||||
import { customRef } from 'vue'
|
||||
|
||||
import type { NodeLocatorId } from '@/types/nodeIdentification'
|
||||
|
||||
export interface NodeState {
|
||||
hasError: boolean
|
||||
}
|
||||
|
||||
export interface SetNodeErrorCommand {
|
||||
type: 'SetNodeError'
|
||||
version: 1
|
||||
nodeId: NodeLocatorId
|
||||
hasError: boolean
|
||||
}
|
||||
|
||||
export interface ClearAllErrorsCommand {
|
||||
type: 'ClearAllErrors'
|
||||
version: 1
|
||||
}
|
||||
|
||||
export type GraphStateCommand = SetNodeErrorCommand | ClearAllErrorsCommand
|
||||
|
||||
export const useGraphStateStore = defineStore('graphState', () => {
|
||||
const nodes = new Map<NodeLocatorId, NodeState>()
|
||||
|
||||
let revision = 0
|
||||
const stateRef = customRef<number>((track, trigger) => ({
|
||||
get() {
|
||||
track()
|
||||
return revision
|
||||
},
|
||||
set() {
|
||||
revision++
|
||||
trigger()
|
||||
}
|
||||
}))
|
||||
|
||||
const execute = (command: GraphStateCommand): void => {
|
||||
switch (command.type) {
|
||||
case 'SetNodeError': {
|
||||
const existing = nodes.get(command.nodeId)
|
||||
if (existing) {
|
||||
existing.hasError = command.hasError
|
||||
} else {
|
||||
nodes.set(command.nodeId, { hasError: command.hasError })
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'ClearAllErrors': {
|
||||
for (const state of nodes.values()) {
|
||||
state.hasError = false
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
stateRef.value = revision + 1
|
||||
}
|
||||
|
||||
const getNodeState = (nodeId: NodeLocatorId): NodeState | undefined => {
|
||||
return nodes.get(nodeId)
|
||||
}
|
||||
|
||||
const getNodesWithErrors = (): NodeLocatorId[] => {
|
||||
const result: NodeLocatorId[] = []
|
||||
for (const [nodeId, state] of nodes) {
|
||||
if (state.hasError) result.push(nodeId)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
return {
|
||||
stateRef,
|
||||
nodes,
|
||||
execute,
|
||||
getNodeState,
|
||||
getNodesWithErrors
|
||||
}
|
||||
})
|
||||
45
src/core/graph/state/useGraphErrorState.ts
Normal file
45
src/core/graph/state/useGraphErrorState.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { watch } from 'vue'
|
||||
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { app } from '@/scripts/app'
|
||||
import {
|
||||
forEachNode,
|
||||
forEachSubgraphNode,
|
||||
getNodeByLocatorId
|
||||
} from '@/utils/graphTraversalUtil'
|
||||
|
||||
import { useGraphStateStore } from './graphStateStore'
|
||||
|
||||
const propagateErrorToParents = (node: LGraphNode): void => {
|
||||
const subgraph = node.graph
|
||||
if (!subgraph || subgraph.isRootGraph) return
|
||||
|
||||
forEachSubgraphNode(app.rootGraph, subgraph.id, (subgraphNode) => {
|
||||
subgraphNode.has_errors = true
|
||||
propagateErrorToParents(subgraphNode)
|
||||
})
|
||||
}
|
||||
|
||||
export const useGraphErrorState = () => {
|
||||
const store = useGraphStateStore()
|
||||
|
||||
watch(
|
||||
() => store.stateRef,
|
||||
() => {
|
||||
if (!app.rootGraph) return
|
||||
|
||||
forEachNode(app.rootGraph, (node) => {
|
||||
node.has_errors = false
|
||||
})
|
||||
|
||||
for (const locatorId of store.getNodesWithErrors()) {
|
||||
const node = getNodeByLocatorId(app.rootGraph, locatorId)
|
||||
if (!node) continue
|
||||
|
||||
node.has_errors = true
|
||||
propagateErrorToParents(node)
|
||||
}
|
||||
},
|
||||
{ immediate: true }
|
||||
)
|
||||
}
|
||||
@@ -47,12 +47,14 @@ import { useSubscription } from '@/platform/cloud/subscription/composables/useSu
|
||||
import { useExtensionService } from '@/services/extensionService'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import { useSubgraphService } from '@/services/subgraphService'
|
||||
import { useGraphErrorState } from '@/core/graph/state/useGraphErrorState'
|
||||
import { useApiKeyAuthStore } from '@/stores/apiKeyAuthStore'
|
||||
import { useCommandStore } from '@/stores/commandStore'
|
||||
import { useDomWidgetStore } from '@/stores/domWidgetStore'
|
||||
import { useExecutionStore } from '@/stores/executionStore'
|
||||
import { useExtensionStore } from '@/stores/extensionStore'
|
||||
import { useFirebaseAuthStore } from '@/stores/firebaseAuthStore'
|
||||
import { useGraphStateStore } from '@/core/graph/state/graphStateStore'
|
||||
import { useNodeOutputStore } from '@/stores/imagePreviewStore'
|
||||
import { KeyComboImpl, useKeybindingStore } from '@/stores/keybindingStore'
|
||||
import { useModelStore } from '@/stores/modelStore'
|
||||
@@ -78,6 +80,7 @@ import {
|
||||
findLegacyRerouteNodes,
|
||||
noNativeReroutes
|
||||
} from '@/utils/migration/migrateReroute'
|
||||
import { collectMissingNodes } from '@/workbench/extensions/manager/utils/graphHasMissingNodes'
|
||||
import { getSelectedModelsMetadata } from '@/workbench/utils/modelMetadataUtil'
|
||||
import { deserialiseAndCreate } from '@/utils/vintageClipboard'
|
||||
|
||||
@@ -764,6 +767,8 @@ export class ComfyApp {
|
||||
void useSubgraphStore().fetchSubgraphs()
|
||||
await useExtensionService().loadExtensions()
|
||||
|
||||
useGraphErrorState()
|
||||
|
||||
this.addProcessKeyHandler()
|
||||
this.addConfigureHandler()
|
||||
this.addApiUpdateHandlers()
|
||||
@@ -1231,6 +1236,23 @@ export class ComfyApp {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const graphStateStore = useGraphStateStore()
|
||||
const missingNodes = collectMissingNodes(
|
||||
this.rootGraph,
|
||||
useNodeDefStore().nodeDefsByName
|
||||
)
|
||||
for (const node of missingNodes) {
|
||||
const locatorId = node.graph?.isRootGraph
|
||||
? String(node.id)
|
||||
: `${node.graph?.id}:${node.id}`
|
||||
graphStateStore.execute({
|
||||
type: 'SetNodeError',
|
||||
version: 1,
|
||||
nodeId: locatorId,
|
||||
hasError: true
|
||||
})
|
||||
}
|
||||
} catch (error) {
|
||||
useDialogService().showErrorDialog(error, {
|
||||
title: t('errorDialog.loadWorkflowTitle'),
|
||||
|
||||
@@ -32,6 +32,7 @@ import { app } from '@/scripts/app'
|
||||
import { useNodeOutputStore } from '@/stores/imagePreviewStore'
|
||||
import type { NodeLocatorId } from '@/types/nodeIdentification'
|
||||
import { createNodeLocatorId } from '@/types/nodeIdentification'
|
||||
import { useGraphStateStore } from '@/core/graph/state/graphStateStore'
|
||||
import { forEachNode, getNodeByExecutionId } from '@/utils/graphTraversalUtil'
|
||||
|
||||
interface QueuedPrompt {
|
||||
@@ -574,56 +575,54 @@ export const useExecutionStore = defineStore('execution', () => {
|
||||
}
|
||||
|
||||
/**
|
||||
* Update node and slot error flags when validation errors change.
|
||||
* Propagates errors up subgraph chains.
|
||||
* Push execution errors to graphStateStore and handle slot errors.
|
||||
*/
|
||||
watch(lastNodeErrors, () => {
|
||||
if (!app.rootGraph) return
|
||||
const graphStateStore = useGraphStateStore()
|
||||
|
||||
// Clear all error flags
|
||||
forEachNode(app.rootGraph, (node) => {
|
||||
node.has_errors = false
|
||||
if (node.inputs) {
|
||||
for (const slot of node.inputs) {
|
||||
slot.hasErrors = false
|
||||
// Clear slot errors
|
||||
if (app.rootGraph) {
|
||||
forEachNode(app.rootGraph, (node) => {
|
||||
if (node.inputs) {
|
||||
for (const slot of node.inputs) {
|
||||
slot.hasErrors = false
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Clear previous execution errors
|
||||
graphStateStore.execute({ type: 'ClearAllErrors', version: 1 })
|
||||
|
||||
if (!lastNodeErrors.value) return
|
||||
|
||||
// Set error flags on nodes and slots
|
||||
// Push execution errors to graphStateStore
|
||||
for (const [executionId, nodeError] of Object.entries(
|
||||
lastNodeErrors.value
|
||||
)) {
|
||||
const node = getNodeByExecutionId(app.rootGraph, executionId)
|
||||
if (!node) continue
|
||||
|
||||
node.has_errors = true
|
||||
|
||||
// Mark input slots with errors
|
||||
if (node.inputs) {
|
||||
for (const error of nodeError.errors) {
|
||||
const slotName = error.extra_info?.input_name
|
||||
if (!slotName) continue
|
||||
|
||||
const slot = node.inputs.find((s) => s.name === slotName)
|
||||
if (slot) {
|
||||
slot.hasErrors = true
|
||||
}
|
||||
}
|
||||
const locatorId = executionIdToNodeLocatorId(executionId)
|
||||
if (locatorId) {
|
||||
graphStateStore.execute({
|
||||
type: 'SetNodeError',
|
||||
version: 1,
|
||||
nodeId: locatorId,
|
||||
hasError: true
|
||||
})
|
||||
}
|
||||
|
||||
// Propagate errors to parent subgraph nodes
|
||||
const parts = executionId.split(':')
|
||||
for (let i = parts.length - 1; i > 0; i--) {
|
||||
const parentExecutionId = parts.slice(0, i).join(':')
|
||||
const parentNode = getNodeByExecutionId(
|
||||
app.rootGraph,
|
||||
parentExecutionId
|
||||
)
|
||||
if (parentNode) {
|
||||
parentNode.has_errors = true
|
||||
// Handle slot errors directly (not part of graphStateStore yet)
|
||||
if (app.rootGraph) {
|
||||
const node = getNodeByExecutionId(app.rootGraph, executionId)
|
||||
if (node?.inputs) {
|
||||
for (const error of nodeError.errors) {
|
||||
const slotName = error.extra_info?.input_name
|
||||
if (!slotName) continue
|
||||
|
||||
const slot = node.inputs.find((s) => s.name === slotName)
|
||||
if (slot) {
|
||||
slot.hasErrors = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user