Type WS messages (#375)

This commit is contained in:
Chenlei Hu
2024-08-11 18:05:14 -04:00
committed by GitHub
parent 2a5b2e8d12
commit edf0396619
5 changed files with 118 additions and 50 deletions

View File

@@ -19,7 +19,7 @@ import {
type ComfyWorkflowJSON,
validateComfyWorkflow
} from '../types/comfyWorkflow'
import { ComfyNodeDef } from '@/types/apiTypes'
import { ComfyNodeDef, StatusWsMessageStatus } from '@/types/apiTypes'
import { lightenColor } from '@/utils/colorUtil'
import { ComfyAppMenu } from './ui/menu/index'
import { getStorageValue } from './utils'
@@ -1586,9 +1586,12 @@ export class ComfyApp {
* Handles updates from the API socket
*/
#addApiUpdateHandlers() {
api.addEventListener('status', ({ detail }) => {
this.ui.setStatus(detail)
})
api.addEventListener(
'status',
({ detail }: CustomEvent<StatusWsMessageStatus>) => {
this.ui.setStatus(detail)
}
)
api.addEventListener('reconnecting', () => {
this.ui.dialog.show('Reconnecting...')

View File

@@ -3,7 +3,7 @@ import { ComfyDialog as _ComfyDialog } from './ui/dialog'
import { toggleSwitch } from './ui/toggleSwitch'
import { ComfySettingsDialog } from './ui/settings'
import { ComfyApp, app } from './app'
import { TaskItem } from '@/types/apiTypes'
import { StatusWsMessageStatus, TaskItem } from '@/types/apiTypes'
export const ComfyDialog = _ComfyDialog

View File

@@ -1,3 +1,4 @@
import { StatusWsMessageStatus } from '@/types/apiTypes'
import { api } from '../../api'
import { ComfyButton } from '../components/button'
@@ -12,10 +13,13 @@ export function getInteruptButton(visibility: string) {
classList: ['comfyui-button', 'comfyui-interrupt-button', visibility]
})
api.addEventListener('status', ({ detail }) => {
const sz = detail?.exec_info?.queue_remaining
btn.enabled = sz > 0
})
api.addEventListener(
'status',
({ detail }: CustomEvent<StatusWsMessageStatus>) => {
const sz = detail.exec_info.queue_remaining
btn.enabled = sz > 0
}
)
return btn
}

View File

@@ -5,6 +5,7 @@ import { ComfySplitButton } from '../components/splitButton'
import { ComfyQueueOptions } from './queueOptions'
import { prop } from '../../utils'
import type { ComfyApp } from '@/scripts/app'
import { StatusWsMessageStatus } from '@/types/apiTypes'
export class ComfyQueueButton {
element = $el('div.comfyui-queue-button')
@@ -86,22 +87,25 @@ export class ComfyQueueButton {
}
})
api.addEventListener('status', ({ detail }) => {
this.#internalQueueSize = detail?.exec_info?.queue_remaining
if (this.#internalQueueSize != null) {
this.queueSizeElement.textContent =
this.#internalQueueSize > 99 ? '99+' : this.#internalQueueSize + ''
this.queueSizeElement.title = `${this.#internalQueueSize} prompts in queue`
if (!this.#internalQueueSize && !app.lastExecutionError) {
if (
this.autoQueueMode === 'instant' ||
(this.autoQueueMode === 'change' && this.graphHasChanged)
) {
this.graphHasChanged = false
this.queuePrompt()
api.addEventListener(
'status',
({ detail }: CustomEvent<StatusWsMessageStatus>) => {
this.#internalQueueSize = detail.exec_info.queue_remaining
if (this.#internalQueueSize != null) {
this.queueSizeElement.textContent =
this.#internalQueueSize > 99 ? '99+' : this.#internalQueueSize + ''
this.queueSizeElement.title = `${this.#internalQueueSize} prompts in queue`
if (!this.#internalQueueSize && !app.lastExecutionError) {
if (
this.autoQueueMode === 'instant' ||
(this.autoQueueMode === 'change' && this.graphHasChanged)
) {
this.graphHasChanged = false
this.queuePrompt()
}
}
}
}
})
)
}
}

View File

@@ -5,6 +5,85 @@ import { fromZodError } from 'zod-validation-error'
const zNodeType = z.string()
const zQueueIndex = z.number()
const zPromptId = z.string()
const zImageResult = z.object({
filename: z.string(),
subfolder: z.string().optional(),
type: z.string()
})
// WS messages
const zStatusWsMessageStatus = z.object({
exec_info: z.object({
queue_remaining: z.number().int()
})
})
const zStatusWsMessage = z.object({
status: zStatusWsMessageStatus
})
const zProgressWsMessage = z.object({
value: z.number().int(),
max: z.number().int(),
prompt_id: zPromptId,
node: zNodeId
})
const zExecutingWsMessage = z.object({
node: zNodeId,
display_node: zNodeId,
prompt_id: zPromptId
})
const zExecutedWsMessage = zExecutingWsMessage.extend({
outputs: z
.object({
images: z.array(zImageResult)
})
.passthrough()
})
const zExecutionWsMessageBase = z.object({
prompt_id: zPromptId,
timestamp: z.number().int()
})
const zExecutionStartWsMessage = zExecutionWsMessageBase
const zExecutionSuccessWsMessage = zExecutionWsMessageBase
const zExecutionCachedWsMessage = zExecutionWsMessageBase.extend({
nodes: z.array(zNodeId)
})
const zExecutionInterruptedWsMessage = zExecutionWsMessageBase.extend({
node_id: zNodeId,
node_type: zNodeType,
executed: z.array(zNodeId)
})
const zExecutionErrorWsMessage = zExecutionWsMessageBase.extend({
node_id: zNodeId,
node_type: zNodeType,
executed: z.array(zNodeId),
exception_message: z.string(),
exception_type: z.string(),
traceback: z.string(),
current_inputs: z.any(),
current_outputs: z.any()
})
export type StatusWsMessageStatus = z.infer<typeof zStatusWsMessageStatus>
export type StatusWsMessage = z.infer<typeof zStatusWsMessage>
export type ProgressWsMessage = z.infer<typeof zProgressWsMessage>
export type ExecutingWsMessage = z.infer<typeof zExecutingWsMessage>
export type ExecutedWsMessage = z.infer<typeof zExecutedWsMessage>
export type ExecutionStartWsMessage = z.infer<typeof zExecutionStartWsMessage>
export type ExecutionSuccessWsMessage = z.infer<
typeof zExecutionSuccessWsMessage
>
export type ExecutionCachedWsMessage = z.infer<typeof zExecutionCachedWsMessage>
export type ExecutionInterruptedWsMessage = z.infer<
typeof zExecutionInterruptedWsMessage
>
export type ExecutionErrorWsMessage = z.infer<typeof zExecutionErrorWsMessage>
// End of ws messages
const zPromptInputItem = z.object({
inputs: z.record(z.string(), z.any()),
@@ -25,51 +104,29 @@ const zExtraData = z.object({
})
const zOutputsToExecute = z.array(zNodeId)
const zMessageDetailBase = z.object({
prompt_id: zPromptId,
timestamp: z.number()
})
const zExecutionStartMessage = z.tuple([
z.literal('execution_start'),
zMessageDetailBase
zExecutionStartWsMessage
])
const zExecutionSuccessMessage = z.tuple([
z.literal('execution_success'),
zMessageDetailBase
zExecutionSuccessWsMessage
])
const zExecutionCachedMessage = z.tuple([
z.literal('execution_cached'),
zMessageDetailBase.extend({
nodes: z.array(zNodeId)
})
zExecutionCachedWsMessage
])
const zExecutionInterruptedMessage = z.tuple([
z.literal('execution_interrupted'),
zMessageDetailBase.extend({
// InterruptProcessingException
node_id: zNodeId,
node_type: zNodeType,
executed: z.array(zNodeId)
})
zExecutionInterruptedWsMessage
])
const zExecutionErrorMessage = z.tuple([
z.literal('execution_error'),
zMessageDetailBase.extend({
node_id: zNodeId,
node_type: zNodeType,
executed: z.array(zNodeId),
exception_message: z.string(),
exception_type: z.string(),
traceback: z.string(),
current_inputs: z.any(),
current_outputs: z.any()
})
zExecutionErrorWsMessage
])
const zStatusMessage = z.union([