From edf039661911b582ff0ebbd29b84d2402936b76a Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Sun, 11 Aug 2024 18:05:14 -0400 Subject: [PATCH] Type WS messages (#375) --- src/scripts/app.ts | 11 ++- src/scripts/ui.ts | 2 +- src/scripts/ui/menu/interruptButton.ts | 12 ++- src/scripts/ui/menu/queueButton.ts | 32 +++---- src/types/apiTypes.ts | 111 +++++++++++++++++++------ 5 files changed, 118 insertions(+), 50 deletions(-) diff --git a/src/scripts/app.ts b/src/scripts/app.ts index a647a1348d..661e042e13 100644 --- a/src/scripts/app.ts +++ b/src/scripts/app.ts @@ -19,7 +19,7 @@ import { type ComfyWorkflowJSON, validateComfyWorkflow } from '../types/comfyWorkflow' -import { ComfyNodeDef } from '@/types/apiTypes' +import { ComfyNodeDef, StatusWsMessageStatus } from '@/types/apiTypes' import { lightenColor } from '@/utils/colorUtil' import { ComfyAppMenu } from './ui/menu/index' import { getStorageValue } from './utils' @@ -1586,9 +1586,12 @@ export class ComfyApp { * Handles updates from the API socket */ #addApiUpdateHandlers() { - api.addEventListener('status', ({ detail }) => { - this.ui.setStatus(detail) - }) + api.addEventListener( + 'status', + ({ detail }: CustomEvent) => { + this.ui.setStatus(detail) + } + ) api.addEventListener('reconnecting', () => { this.ui.dialog.show('Reconnecting...') diff --git a/src/scripts/ui.ts b/src/scripts/ui.ts index 4e81caaaa9..e9c2550131 100644 --- a/src/scripts/ui.ts +++ b/src/scripts/ui.ts @@ -3,7 +3,7 @@ import { ComfyDialog as _ComfyDialog } from './ui/dialog' import { toggleSwitch } from './ui/toggleSwitch' import { ComfySettingsDialog } from './ui/settings' import { ComfyApp, app } from './app' -import { TaskItem } from '@/types/apiTypes' +import { StatusWsMessageStatus, TaskItem } from '@/types/apiTypes' export const ComfyDialog = _ComfyDialog diff --git a/src/scripts/ui/menu/interruptButton.ts b/src/scripts/ui/menu/interruptButton.ts index 7e00182049..e48245b0b8 100644 --- a/src/scripts/ui/menu/interruptButton.ts +++ b/src/scripts/ui/menu/interruptButton.ts @@ -1,3 +1,4 @@ +import { StatusWsMessageStatus } from '@/types/apiTypes' import { api } from '../../api' import { ComfyButton } from '../components/button' @@ -12,10 +13,13 @@ export function getInteruptButton(visibility: string) { classList: ['comfyui-button', 'comfyui-interrupt-button', visibility] }) - api.addEventListener('status', ({ detail }) => { - const sz = detail?.exec_info?.queue_remaining - btn.enabled = sz > 0 - }) + api.addEventListener( + 'status', + ({ detail }: CustomEvent) => { + const sz = detail.exec_info.queue_remaining + btn.enabled = sz > 0 + } + ) return btn } diff --git a/src/scripts/ui/menu/queueButton.ts b/src/scripts/ui/menu/queueButton.ts index 1778dd7eb9..d1e8bc7332 100644 --- a/src/scripts/ui/menu/queueButton.ts +++ b/src/scripts/ui/menu/queueButton.ts @@ -5,6 +5,7 @@ import { ComfySplitButton } from '../components/splitButton' import { ComfyQueueOptions } from './queueOptions' import { prop } from '../../utils' import type { ComfyApp } from '@/scripts/app' +import { StatusWsMessageStatus } from '@/types/apiTypes' export class ComfyQueueButton { element = $el('div.comfyui-queue-button') @@ -86,22 +87,25 @@ export class ComfyQueueButton { } }) - api.addEventListener('status', ({ detail }) => { - this.#internalQueueSize = detail?.exec_info?.queue_remaining - if (this.#internalQueueSize != null) { - this.queueSizeElement.textContent = - this.#internalQueueSize > 99 ? '99+' : this.#internalQueueSize + '' - this.queueSizeElement.title = `${this.#internalQueueSize} prompts in queue` - if (!this.#internalQueueSize && !app.lastExecutionError) { - if ( - this.autoQueueMode === 'instant' || - (this.autoQueueMode === 'change' && this.graphHasChanged) - ) { - this.graphHasChanged = false - this.queuePrompt() + api.addEventListener( + 'status', + ({ detail }: CustomEvent) => { + this.#internalQueueSize = detail.exec_info.queue_remaining + if (this.#internalQueueSize != null) { + this.queueSizeElement.textContent = + this.#internalQueueSize > 99 ? '99+' : this.#internalQueueSize + '' + this.queueSizeElement.title = `${this.#internalQueueSize} prompts in queue` + if (!this.#internalQueueSize && !app.lastExecutionError) { + if ( + this.autoQueueMode === 'instant' || + (this.autoQueueMode === 'change' && this.graphHasChanged) + ) { + this.graphHasChanged = false + this.queuePrompt() + } } } } - }) + ) } } diff --git a/src/types/apiTypes.ts b/src/types/apiTypes.ts index 4246f5b2f1..4c93dd3079 100644 --- a/src/types/apiTypes.ts +++ b/src/types/apiTypes.ts @@ -5,6 +5,85 @@ import { fromZodError } from 'zod-validation-error' const zNodeType = z.string() const zQueueIndex = z.number() const zPromptId = z.string() +const zImageResult = z.object({ + filename: z.string(), + subfolder: z.string().optional(), + type: z.string() +}) + +// WS messages +const zStatusWsMessageStatus = z.object({ + exec_info: z.object({ + queue_remaining: z.number().int() + }) +}) + +const zStatusWsMessage = z.object({ + status: zStatusWsMessageStatus +}) + +const zProgressWsMessage = z.object({ + value: z.number().int(), + max: z.number().int(), + prompt_id: zPromptId, + node: zNodeId +}) + +const zExecutingWsMessage = z.object({ + node: zNodeId, + display_node: zNodeId, + prompt_id: zPromptId +}) + +const zExecutedWsMessage = zExecutingWsMessage.extend({ + outputs: z + .object({ + images: z.array(zImageResult) + }) + .passthrough() +}) + +const zExecutionWsMessageBase = z.object({ + prompt_id: zPromptId, + timestamp: z.number().int() +}) + +const zExecutionStartWsMessage = zExecutionWsMessageBase +const zExecutionSuccessWsMessage = zExecutionWsMessageBase +const zExecutionCachedWsMessage = zExecutionWsMessageBase.extend({ + nodes: z.array(zNodeId) +}) +const zExecutionInterruptedWsMessage = zExecutionWsMessageBase.extend({ + node_id: zNodeId, + node_type: zNodeType, + executed: z.array(zNodeId) +}) +const zExecutionErrorWsMessage = zExecutionWsMessageBase.extend({ + node_id: zNodeId, + node_type: zNodeType, + executed: z.array(zNodeId), + exception_message: z.string(), + exception_type: z.string(), + traceback: z.string(), + current_inputs: z.any(), + current_outputs: z.any() +}) + +export type StatusWsMessageStatus = z.infer +export type StatusWsMessage = z.infer +export type ProgressWsMessage = z.infer +export type ExecutingWsMessage = z.infer +export type ExecutedWsMessage = z.infer +export type ExecutionStartWsMessage = z.infer +export type ExecutionSuccessWsMessage = z.infer< + typeof zExecutionSuccessWsMessage +> +export type ExecutionCachedWsMessage = z.infer +export type ExecutionInterruptedWsMessage = z.infer< + typeof zExecutionInterruptedWsMessage +> +export type ExecutionErrorWsMessage = z.infer +// End of ws messages const zPromptInputItem = z.object({ inputs: z.record(z.string(), z.any()), @@ -25,51 +104,29 @@ const zExtraData = z.object({ }) const zOutputsToExecute = z.array(zNodeId) -const zMessageDetailBase = z.object({ - prompt_id: zPromptId, - timestamp: z.number() -}) - const zExecutionStartMessage = z.tuple([ z.literal('execution_start'), - zMessageDetailBase + zExecutionStartWsMessage ]) const zExecutionSuccessMessage = z.tuple([ z.literal('execution_success'), - zMessageDetailBase + zExecutionSuccessWsMessage ]) const zExecutionCachedMessage = z.tuple([ z.literal('execution_cached'), - zMessageDetailBase.extend({ - nodes: z.array(zNodeId) - }) + zExecutionCachedWsMessage ]) const zExecutionInterruptedMessage = z.tuple([ z.literal('execution_interrupted'), - zMessageDetailBase.extend({ - // InterruptProcessingException - node_id: zNodeId, - node_type: zNodeType, - executed: z.array(zNodeId) - }) + zExecutionInterruptedWsMessage ]) const zExecutionErrorMessage = z.tuple([ z.literal('execution_error'), - zMessageDetailBase.extend({ - node_id: zNodeId, - node_type: zNodeType, - executed: z.array(zNodeId), - - exception_message: z.string(), - exception_type: z.string(), - traceback: z.string(), - current_inputs: z.any(), - current_outputs: z.any() - }) + zExecutionErrorWsMessage ]) const zStatusMessage = z.union([