mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-04-20 14:30:41 +00:00
fix: RAF-batch WebSocket progress events to reduce reactive update storms
Amp-Thread-ID: https://ampcode.com/threads/T-019ca43d-b7a5-759f-b88f-1319faac8a01
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { app } from '@/scripts/app'
|
||||
import { api } from '@/scripts/api'
|
||||
import { MAX_PROGRESS_JOBS, useExecutionStore } from '@/stores/executionStore'
|
||||
import { useExecutionErrorStore } from '@/stores/executionErrorStore'
|
||||
import { executionIdToNodeLocatorId } from '@/utils/graphTraversalUtil'
|
||||
@@ -15,6 +16,27 @@ import type { NodeProgressState } from '@/schemas/apiSchema'
|
||||
import { createMockLGraphNode } from '@/utils/__tests__/litegraphTestUtils'
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
|
||||
vi.mock('@/scripts/api', () => ({
|
||||
api: {
|
||||
addEventListener: vi.fn(),
|
||||
removeEventListener: vi.fn(),
|
||||
clientId: null,
|
||||
apiURL: vi.fn((path: string) => path)
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/imagePreviewStore', () => ({
|
||||
useNodeOutputStore: () => ({
|
||||
revokePreviewsByExecutionId: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/jobPreviewStore', () => ({
|
||||
useJobPreviewStore: () => ({
|
||||
clearPreview: vi.fn()
|
||||
})
|
||||
}))
|
||||
|
||||
// Mock the workflowStore
|
||||
vi.mock('@/platform/workflow/management/stores/workflowStore', async () => {
|
||||
const { ComfyWorkflow } = await vi.importActual<typeof WorkflowStoreModule>(
|
||||
@@ -694,3 +716,174 @@ describe('useExecutionErrorStore - setMissingNodeTypes', () => {
|
||||
expect(store.missingNodesError?.nodeTypes).toEqual(input)
|
||||
})
|
||||
})
|
||||
|
||||
describe('useExecutionStore - RAF batching', () => {
|
||||
let store: ReturnType<typeof useExecutionStore>
|
||||
|
||||
function getRegisteredHandler(eventName: string) {
|
||||
const calls = vi.mocked(api.addEventListener).mock.calls
|
||||
const call = calls.find(([name]) => name === eventName)
|
||||
return call?.[1] as (e: CustomEvent) => void
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
vi.clearAllMocks()
|
||||
setActivePinia(createTestingPinia({ stubActions: false }))
|
||||
store = useExecutionStore()
|
||||
store.bindExecutionEvents()
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
describe('handleProgress', () => {
|
||||
function makeProgressEvent(
|
||||
value: number,
|
||||
max: number
|
||||
): CustomEvent {
|
||||
return new CustomEvent('progress', {
|
||||
detail: { value, max, prompt_id: 'job-1', node: '1' }
|
||||
})
|
||||
}
|
||||
|
||||
it('batches multiple progress events into one reactive update per frame', () => {
|
||||
const handler = getRegisteredHandler('progress')
|
||||
|
||||
handler(makeProgressEvent(1, 10))
|
||||
handler(makeProgressEvent(5, 10))
|
||||
handler(makeProgressEvent(9, 10))
|
||||
|
||||
expect(store._executingNodeProgress).toBeNull()
|
||||
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(store._executingNodeProgress).toEqual({
|
||||
value: 9,
|
||||
max: 10,
|
||||
prompt_id: 'job-1',
|
||||
node: '1'
|
||||
})
|
||||
})
|
||||
|
||||
it('does not update reactive state before RAF fires', () => {
|
||||
const handler = getRegisteredHandler('progress')
|
||||
|
||||
handler(makeProgressEvent(3, 10))
|
||||
|
||||
expect(store._executingNodeProgress).toBeNull()
|
||||
})
|
||||
|
||||
it('allows a new batch after the previous RAF fires', () => {
|
||||
const handler = getRegisteredHandler('progress')
|
||||
|
||||
handler(makeProgressEvent(1, 10))
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(store._executingNodeProgress).toEqual(
|
||||
expect.objectContaining({ value: 1 })
|
||||
)
|
||||
|
||||
handler(makeProgressEvent(7, 10))
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(store._executingNodeProgress).toEqual(
|
||||
expect.objectContaining({ value: 7 })
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('handleProgressState', () => {
|
||||
function makeProgressStateEvent(
|
||||
nodeId: string,
|
||||
state: string,
|
||||
value = 0,
|
||||
max = 10
|
||||
): CustomEvent {
|
||||
return new CustomEvent('progress_state', {
|
||||
detail: {
|
||||
prompt_id: 'job-1',
|
||||
nodes: {
|
||||
[nodeId]: {
|
||||
value,
|
||||
max,
|
||||
state,
|
||||
node_id: nodeId,
|
||||
prompt_id: 'job-1',
|
||||
display_node_id: nodeId
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
it('batches multiple progress_state events into one reactive update per frame', () => {
|
||||
const handler = getRegisteredHandler('progress_state')
|
||||
|
||||
handler(makeProgressStateEvent('1', 'running', 1))
|
||||
handler(makeProgressStateEvent('1', 'running', 5))
|
||||
handler(makeProgressStateEvent('1', 'running', 9))
|
||||
|
||||
expect(Object.keys(store.nodeProgressStates)).toHaveLength(0)
|
||||
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(store.nodeProgressStates['1']).toEqual(
|
||||
expect.objectContaining({ value: 9, state: 'running' })
|
||||
)
|
||||
})
|
||||
|
||||
it('does not update reactive state before RAF fires', () => {
|
||||
const handler = getRegisteredHandler('progress_state')
|
||||
|
||||
handler(makeProgressStateEvent('1', 'running'))
|
||||
|
||||
expect(Object.keys(store.nodeProgressStates)).toHaveLength(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('unbindExecutionEvents cancels pending RAFs', () => {
|
||||
it('cancels pending progress RAF on unbind', () => {
|
||||
const handler = getRegisteredHandler('progress')
|
||||
|
||||
handler(
|
||||
new CustomEvent('progress', {
|
||||
detail: { value: 5, max: 10, prompt_id: 'job-1', node: '1' }
|
||||
})
|
||||
)
|
||||
|
||||
store.unbindExecutionEvents()
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(store._executingNodeProgress).toBeNull()
|
||||
})
|
||||
|
||||
it('cancels pending progress_state RAF on unbind', () => {
|
||||
const handler = getRegisteredHandler('progress_state')
|
||||
|
||||
handler(
|
||||
new CustomEvent('progress_state', {
|
||||
detail: {
|
||||
prompt_id: 'job-1',
|
||||
nodes: {
|
||||
'1': {
|
||||
value: 0,
|
||||
max: 10,
|
||||
state: 'running',
|
||||
node_id: '1',
|
||||
prompt_id: 'job-1',
|
||||
display_node_id: '1'
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
store.unbindExecutionEvents()
|
||||
vi.advanceTimersByTime(16)
|
||||
|
||||
expect(Object.keys(store.nodeProgressStates)).toHaveLength(0)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -242,6 +242,17 @@ export const useExecutionStore = defineStore('execution', () => {
|
||||
api.removeEventListener('status', handleStatus)
|
||||
api.removeEventListener('execution_error', handleExecutionError)
|
||||
api.removeEventListener('progress_text', handleProgressText)
|
||||
|
||||
if (_progressRafId !== null) {
|
||||
cancelAnimationFrame(_progressRafId)
|
||||
_progressRafId = null
|
||||
_pendingProgress = null
|
||||
}
|
||||
if (_progressStateRafId !== null) {
|
||||
cancelAnimationFrame(_progressStateRafId)
|
||||
_progressStateRafId = null
|
||||
_pendingProgressState = null
|
||||
}
|
||||
}
|
||||
|
||||
function handleExecutionStart(e: CustomEvent<ExecutionStartWsMessage>) {
|
||||
@@ -332,8 +343,24 @@ export const useExecutionStore = defineStore('execution', () => {
|
||||
nodeProgressStatesByJob.value = pruned
|
||||
}
|
||||
|
||||
let _pendingProgressState: ProgressStateWsMessage | null = null
|
||||
let _progressStateRafId: number | null = null
|
||||
|
||||
function handleProgressState(e: CustomEvent<ProgressStateWsMessage>) {
|
||||
const { nodes, prompt_id: jobId } = e.detail
|
||||
_pendingProgressState = e.detail
|
||||
if (_progressStateRafId === null) {
|
||||
_progressStateRafId = requestAnimationFrame(() => {
|
||||
_progressStateRafId = null
|
||||
if (_pendingProgressState) {
|
||||
_applyProgressState(_pendingProgressState)
|
||||
_pendingProgressState = null
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
function _applyProgressState(detail: ProgressStateWsMessage) {
|
||||
const { nodes, prompt_id: jobId } = detail
|
||||
|
||||
// Revoke previews for nodes that are starting to execute
|
||||
const previousForJob = nodeProgressStatesByJob.value[jobId] || {}
|
||||
@@ -369,8 +396,20 @@ export const useExecutionStore = defineStore('execution', () => {
|
||||
}
|
||||
}
|
||||
|
||||
let _pendingProgress: ProgressWsMessage | null = null
|
||||
let _progressRafId: number | null = null
|
||||
|
||||
function handleProgress(e: CustomEvent<ProgressWsMessage>) {
|
||||
_executingNodeProgress.value = e.detail
|
||||
_pendingProgress = e.detail
|
||||
if (_progressRafId === null) {
|
||||
_progressRafId = requestAnimationFrame(() => {
|
||||
_progressRafId = null
|
||||
if (_pendingProgress) {
|
||||
_executingNodeProgress.value = _pendingProgress
|
||||
_pendingProgress = null
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
function handleStatus() {
|
||||
|
||||
Reference in New Issue
Block a user