mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-03-21 04:47:34 +00:00
Compare commits
5 Commits
refactor/m
...
perf/batch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5087f3e27b | ||
|
|
139ffcc411 | ||
|
|
69a077241e | ||
|
|
222845d5bc | ||
|
|
b1785a27d9 |
@@ -1,6 +1,7 @@
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { app } from '@/scripts/app'
|
||||
import { api } from '@/scripts/api'
|
||||
import { MAX_PROGRESS_JOBS, useExecutionStore } from '@/stores/executionStore'
|
||||
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
|
||||
import { executionIdToNodeLocatorId } from '@/utils/graphTraversalUtil'
|
||||
@@ -15,6 +16,27 @@ import type { NodeProgressState } from '@/schemas/apiSchema'
|
||||
import { createMockLGraphNode } from '@/utils/__tests__/litegraphTestUtils'
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
addEventListener: vi.fn(),
|
||||
removeEventListener: vi.fn(),
|
||||
clientId: null,
|
||||
apiURL: vi.fn((path: string) => path)
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/imagePreviewStore', () => ({
|
||||
useNodeOutputStore: () => ({
|
||||
revokePreviewsByExecutionId: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/jobPreviewStore', () => ({
|
||||
useJobPreviewStore: () => ({
|
||||
clearPreview: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
// Mock the workflowStore
|
||||
vi.mock('@/platform/workflow/management/stores/workflowStore', async () => {
|
||||
const { ComfyWorkflow } = await vi.importActual<typeof WorkflowStoreModule>(
|
||||
@@ -330,9 +352,12 @@ describe('useExecutionStore - nodeProgressStatesByJob eviction', () => {
|
||||
handler(
|
||||
new CustomEvent('progress_state', { detail: { nodes, prompt_id: jobId } })
|
||||
)
|
||||
// Flush the RAF so the batched update is applied immediately
|
||||
vi.advanceTimersByTime(16)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
vi.clearAllMocks()
|
||||
apiEventHandlers.clear()
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
@@ -340,6 +365,10 @@ describe('useExecutionStore - nodeProgressStatesByJob eviction', () => {
|
||||
store.bindExecutionEvents()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should retain entries below the limit', () => {
|
||||
for (let i = 0; i < 5; i++) {
|
||||
fireProgressState(`job-${i}`, makeProgressNodes(`${i}`, `job-${i}`))
|
||||
@@ -694,3 +723,309 @@ describe('useExecutionErrorStore - setMissingNodeTypes', () => {
|
||||
expect(store.missingNodesError?.nodeTypes).toEqual(input)
|
||||
})
|
||||
})
|
||||
|
||||
describe('useExecutionStore - RAF batching', () => {
|
||||
let store: ReturnType<typeof useExecutionStore>
|
||||
|
||||
function getRegisteredHandler(eventName: string) {
|
||||
const calls = vi.mocked(api.addEventListener).mock.calls
|
||||
const call = calls.find(([name]) => name === eventName)
|
||||
return call?.[1] as (e: CustomEvent) => void
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
vi.clearAllMocks()
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
store = useExecutionStore()
|
||||
store.bindExecutionEvents()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
describe('handleProgress', () => {
|
||||
function makeProgressEvent(value: number, max: number): CustomEvent {
|
||||
return new CustomEvent('progress', {
|
||||
detail: { value, max, prompt_id: 'job-1', node: '1' }
|
||||
})
|
||||
}
|
||||
|
||||
it('batches multiple progress events into one reactive update per frame', () => {
|
||||
const handler = getRegisteredHandler('progress')
|
||||
|
||||
handler(makeProgressEvent(1, 10))
|
||||
handler(makeProgressEvent(5, 10))
|
||||
handler(makeProgressEvent(9, 10))
|
||||
|
||||
expect(store._executingNodeProgress).toBeNull()
|
||||
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(store._executingNodeProgress).toEqual({
|
||||
value: 9,
|
||||
max: 10,
|
||||
prompt_id: 'job-1',
|
||||
node: '1'
|
||||
})
|
||||
})
|
||||
|
||||
it('does not update reactive state before RAF fires', () => {
|
||||
const handler = getRegisteredHandler('progress')
|
||||
|
||||
handler(makeProgressEvent(3, 10))
|
||||
|
||||
expect(store._executingNodeProgress).toBeNull()
|
||||
})
|
||||
|
||||
it('allows a new batch after the previous RAF fires', () => {
|
||||
const handler = getRegisteredHandler('progress')
|
||||
|
||||
handler(makeProgressEvent(1, 10))
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(store._executingNodeProgress).toEqual(
|
||||
expect.objectContaining({ value: 1 })
|
||||
)
|
||||
|
||||
handler(makeProgressEvent(7, 10))
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(store._executingNodeProgress).toEqual(
|
||||
expect.objectContaining({ value: 7 })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleProgressState', () => {
|
||||
function makeProgressStateEvent(
|
||||
nodeId: string,
|
||||
state: string,
|
||||
value = 0,
|
||||
max = 10
|
||||
): CustomEvent {
|
||||
return new CustomEvent('progress_state', {
|
||||
detail: {
|
||||
prompt_id: 'job-1',
|
||||
nodes: {
|
||||
[nodeId]: {
|
||||
value,
|
||||
max,
|
||||
state,
|
||||
node_id: nodeId,
|
||||
prompt_id: 'job-1',
|
||||
display_node_id: nodeId
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
it('batches multiple progress_state events into one reactive update per frame', () => {
|
||||
const handler = getRegisteredHandler('progress_state')
|
||||
|
||||
handler(makeProgressStateEvent('1', 'running', 1))
|
||||
handler(makeProgressStateEvent('1', 'running', 5))
|
||||
handler(makeProgressStateEvent('1', 'running', 9))
|
||||
|
||||
expect(Object.keys(store.nodeProgressStates)).toHaveLength(0)
|
||||
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(store.nodeProgressStates['1']).toEqual(
|
||||
expect.objectContaining({ value: 9, state: 'running' })
|
||||
)
|
||||
})
|
||||
|
||||
it('does not update reactive state before RAF fires', () => {
|
||||
const handler = getRegisteredHandler('progress_state')
|
||||
|
||||
handler(makeProgressStateEvent('1', 'running'))
|
||||
|
||||
expect(Object.keys(store.nodeProgressStates)).toHaveLength(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('pending RAF is discarded when execution completes', () => {
|
||||
it('discards pending progress RAF on execution_success', () => {
|
||||
const progressHandler = getRegisteredHandler('progress')
|
||||
const startHandler = getRegisteredHandler('execution_start')
|
||||
const successHandler = getRegisteredHandler('execution_success')
|
||||
|
||||
startHandler(
|
||||
new CustomEvent('execution_start', {
|
||||
detail: { prompt_id: 'job-1', timestamp: 0 }
|
||||
})
|
||||
)
|
||||
|
||||
progressHandler(
|
||||
new CustomEvent('progress', {
|
||||
detail: { value: 5, max: 10, prompt_id: 'job-1', node: '1' }
|
||||
})
|
||||
)
|
||||
|
||||
successHandler(
|
||||
new CustomEvent('execution_success', {
|
||||
detail: { prompt_id: 'job-1', timestamp: 0 }
|
||||
})
|
||||
)
|
||||
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(store._executingNodeProgress).toBeNull()
|
||||
})
|
||||
|
||||
it('discards pending progress_state RAF on execution_success', () => {
|
||||
const progressStateHandler = getRegisteredHandler('progress_state')
|
||||
const startHandler = getRegisteredHandler('execution_start')
|
||||
const successHandler = getRegisteredHandler('execution_success')
|
||||
|
||||
startHandler(
|
||||
new CustomEvent('execution_start', {
|
||||
detail: { prompt_id: 'job-1', timestamp: 0 }
|
||||
})
|
||||
)
|
||||
|
||||
progressStateHandler(
|
||||
new CustomEvent('progress_state', {
|
||||
detail: {
|
||||
prompt_id: 'job-1',
|
||||
nodes: {
|
||||
'1': {
|
||||
value: 5,
|
||||
max: 10,
|
||||
state: 'running',
|
||||
node_id: '1',
|
||||
prompt_id: 'job-1',
|
||||
display_node_id: '1'
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
successHandler(
|
||||
new CustomEvent('execution_success', {
|
||||
detail: { prompt_id: 'job-1', timestamp: 0 }
|
||||
})
|
||||
)
|
||||
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(Object.keys(store.nodeProgressStates)).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('discards pending progress RAF on execution_error', () => {
|
||||
const progressHandler = getRegisteredHandler('progress')
|
||||
const startHandler = getRegisteredHandler('execution_start')
|
||||
const errorHandler = getRegisteredHandler('execution_error')
|
||||
|
||||
startHandler(
|
||||
new CustomEvent('execution_start', {
|
||||
detail: { prompt_id: 'job-1', timestamp: 0 }
|
||||
})
|
||||
)
|
||||
|
||||
progressHandler(
|
||||
new CustomEvent('progress', {
|
||||
detail: { value: 5, max: 10, prompt_id: 'job-1', node: '1' }
|
||||
})
|
||||
)
|
||||
|
||||
errorHandler(
|
||||
new CustomEvent('execution_error', {
|
||||
detail: {
|
||||
prompt_id: 'job-1',
|
||||
node_id: '1',
|
||||
node_type: 'TestNode',
|
||||
exception_message: 'error',
|
||||
exception_type: 'RuntimeError',
|
||||
traceback: []
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(store._executingNodeProgress).toBeNull()
|
||||
})
|
||||
|
||||
it('discards pending progress RAF on execution_interrupted', () => {
|
||||
const progressHandler = getRegisteredHandler('progress')
|
||||
const startHandler = getRegisteredHandler('execution_start')
|
||||
const interruptedHandler = getRegisteredHandler('execution_interrupted')
|
||||
|
||||
startHandler(
|
||||
new CustomEvent('execution_start', {
|
||||
detail: { prompt_id: 'job-1', timestamp: 0 }
|
||||
})
|
||||
)
|
||||
|
||||
progressHandler(
|
||||
new CustomEvent('progress', {
|
||||
detail: { value: 5, max: 10, prompt_id: 'job-1', node: '1' }
|
||||
})
|
||||
)
|
||||
|
||||
interruptedHandler(
|
||||
new CustomEvent('execution_interrupted', {
|
||||
detail: {
|
||||
prompt_id: 'job-1',
|
||||
node_id: '1',
|
||||
node_type: 'TestNode',
|
||||
executed: []
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(store._executingNodeProgress).toBeNull()
|
||||
})
|
||||
})
|
||||
|
||||
describe('unbindExecutionEvents cancels pending RAFs', () => {
|
||||
it('cancels pending progress RAF on unbind', () => {
|
||||
const handler = getRegisteredHandler('progress')
|
||||
|
||||
handler(
|
||||
new CustomEvent('progress', {
|
||||
detail: { value: 5, max: 10, prompt_id: 'job-1', node: '1' }
|
||||
})
|
||||
)
|
||||
|
||||
store.unbindExecutionEvents()
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(store._executingNodeProgress).toBeNull()
|
||||
})
|
||||
|
||||
it('cancels pending progress_state RAF on unbind', () => {
|
||||
const handler = getRegisteredHandler('progress_state')
|
||||
|
||||
handler(
|
||||
new CustomEvent('progress_state', {
|
||||
detail: {
|
||||
prompt_id: 'job-1',
|
||||
nodes: {
|
||||
'1': {
|
||||
value: 0,
|
||||
max: 10,
|
||||
state: 'running',
|
||||
node_id: '1',
|
||||
prompt_id: 'job-1',
|
||||
display_node_id: '1'
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
store.unbindExecutionEvents()
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(Object.keys(store.nodeProgressStates)).toHaveLength(0)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -33,6 +33,7 @@ import { useExecutionErrorStore } from '@/stores/executionErrorStore'
|
||||
import type { NodeLocatorId } from '@/types/nodeIdentification'
|
||||
import { classifyCloudValidationError } from '@/utils/executionErrorUtil'
|
||||
import { executionIdToNodeLocatorId } from '@/utils/graphTraversalUtil'
|
||||
import { createRafBatch } from '@/utils/rafBatch'
|
||||
|
||||
interface QueuedJob {
|
||||
/**
|
||||
@@ -242,6 +243,11 @@ export const useExecutionStore = defineStore('execution', () => {
|
||||
api.removeEventListener('status', handleStatus)
|
||||
api.removeEventListener('execution_error', handleExecutionError)
|
||||
api.removeEventListener('progress_text', handleProgressText)
|
||||
|
||||
progressBatch.cancel()
|
||||
_pendingProgress = null
|
||||
progressStateBatch.cancel()
|
||||
_pendingProgressState = null
|
||||
}
|
||||
|
||||
function handleExecutionStart(e: CustomEvent<ExecutionStartWsMessage>) {
|
||||
@@ -290,6 +296,11 @@ export const useExecutionStore = defineStore('execution', () => {
|
||||
}
|
||||
|
||||
function handleExecuting(e: CustomEvent<NodeId | null>): void {
|
||||
// Cancel any pending progress RAF before clearing state to prevent
|
||||
// stale data from being written back on the next frame.
|
||||
progressBatch.cancel()
|
||||
_pendingProgress = null
|
||||
|
||||
// Clear the current node progress when a new node starts executing
|
||||
_executingNodeProgress.value = null
|
||||
|
||||
@@ -332,8 +343,21 @@ export const useExecutionStore = defineStore('execution', () => {
|
||||
nodeProgressStatesByJob.value = pruned
|
||||
}
|
||||
|
||||
let _pendingProgressState: ProgressStateWsMessage | null = null
|
||||
const progressStateBatch = createRafBatch(() => {
|
||||
if (_pendingProgressState) {
|
||||
_applyProgressState(_pendingProgressState)
|
||||
_pendingProgressState = null
|
||||
}
|
||||
})
|
||||
|
||||
function handleProgressState(e: CustomEvent<ProgressStateWsMessage>) {
|
||||
const { nodes, prompt_id: jobId } = e.detail
|
||||
_pendingProgressState = e.detail
|
||||
progressStateBatch.schedule()
|
||||
}
|
||||
|
||||
function _applyProgressState(detail: ProgressStateWsMessage) {
|
||||
const { nodes, prompt_id: jobId } = detail
|
||||
|
||||
// Revoke previews for nodes that are starting to execute
|
||||
const previousForJob = nodeProgressStatesByJob.value[jobId] || {}
|
||||
@@ -369,8 +393,17 @@ export const useExecutionStore = defineStore('execution', () => {
|
||||
}
|
||||
}
|
||||
|
||||
let _pendingProgress: ProgressWsMessage | null = null
|
||||
const progressBatch = createRafBatch(() => {
|
||||
if (_pendingProgress) {
|
||||
_executingNodeProgress.value = _pendingProgress
|
||||
_pendingProgress = null
|
||||
}
|
||||
})
|
||||
|
||||
function handleProgress(e: CustomEvent<ProgressWsMessage>) {
|
||||
_executingNodeProgress.value = e.detail
|
||||
_pendingProgress = e.detail
|
||||
progressBatch.schedule()
|
||||
}
|
||||
|
||||
function handleStatus() {
|
||||
@@ -492,6 +525,13 @@ export const useExecutionStore = defineStore('execution', () => {
|
||||
* Reset execution-related state after a run completes or is stopped.
|
||||
*/
|
||||
function resetExecutionState(jobIdParam?: string | null) {
|
||||
// Cancel pending RAFs before clearing state to prevent stale data
|
||||
// from being written back on the next frame.
|
||||
progressBatch.cancel()
|
||||
_pendingProgress = null
|
||||
progressStateBatch.cancel()
|
||||
_pendingProgressState = null
|
||||
|
||||
executionIdToLocatorCache.clear()
|
||||
nodeProgressStates.value = {}
|
||||
const jobId = jobIdParam ?? activeJobId.value ?? null
|
||||
|
||||
Reference in New Issue
Block a user