Add TS types - API (#1736)

* nit

* Add TS types - API events

* Replace all API event emits with type-safe variants

* Add missing API type

* nit

* Remove test code, nit
This commit is contained in:
filtered
2024-11-30 05:15:25 +11:00
committed by GitHub
parent 0bf30e7621
commit 2017b9016b
8 changed files with 225 additions and 112 deletions

View File

@@ -1,19 +1,26 @@
import type { ComfyWorkflowJSON } from '@/types/comfyWorkflow'
import {
type HistoryTaskItem,
type PendingTaskItem,
type RunningTaskItem,
type ComfyNodeDef,
type EmbeddingsResponse,
type ExtensionsResponse,
type PromptResponse,
type SystemStats,
type User,
type Settings,
type UserDataFullInfo,
validateComfyNodeDef,
LogsRawResponse
import type { ComfyWorkflowJSON, NodeId } from '@/types/comfyWorkflow'
import type {
HistoryTaskItem,
PendingTaskItem,
RunningTaskItem,
ComfyNodeDef,
EmbeddingsResponse,
ExtensionsResponse,
PromptResponse,
SystemStats,
User,
Settings,
UserDataFullInfo,
LogsRawResponse,
ExecutingWsMessage,
ExecutedWsMessage,
ProgressWsMessage,
ExecutionStartWsMessage,
ExecutionErrorWsMessage,
StatusWsMessage,
StatusWsMessageStatus
} from '@/types/apiTypes'
import { validateComfyNodeDef } from '@/types/apiTypes'
import axios from 'axios'
interface QueuePromptRequestBody {
@@ -30,6 +37,105 @@ interface QueuePromptRequestBody {
number?: number
}
/** Dictionary of Frontend-generated API calls */
interface FrontendApiCalls {
graphChanged: ComfyWorkflowJSON
promptQueued: { number: number; batchCount: number }
graphCleared: never
reconnecting: never
reconnected: never
}
/** Dictionary of calls originating from ComfyUI core */
interface BackendApiCalls {
progress: ProgressWsMessage
executing: ExecutingWsMessage
executed: ExecutedWsMessage
status: StatusWsMessage
execution_start: ExecutionStartWsMessage
execution_success: never
execution_error: ExecutionErrorWsMessage
execution_cached: never
logs: never
/** Mr Blob Preview, I presume? */
b_preview: Blob
}
/** Dictionary of all api calls */
interface ApiCalls extends BackendApiCalls, FrontendApiCalls {}
/** Used to create a discriminating union on type value. */
interface ApiMessage<T extends keyof ApiCalls> {
type: T
data: ApiCalls[T]
}
/** Ensures workers get a fair shake. */
type Unionize<T> = T[keyof T]
/**
* Discriminated union of generic, i.e.:
* ```ts
* // Convert
* type ApiMessageUnion = ApiMessage<'status' | 'executing' | ...>
* // To
* type ApiMessageUnion = ApiMessage<'status'> | ApiMessage<'executing'> | ...
* ```
*/
type ApiMessageUnion = Unionize<{
[Key in keyof ApiCalls]: ApiMessage<Key>
}>
/** Wraps all properties in {@link CustomEvent}. */
type AsCustomEvents<T> = {
readonly [K in keyof T]: CustomEvent<T[K]>
}
/** Handles differing event and API signatures. */
type ApiToEventType<T = ApiCalls> = {
[K in keyof T]: K extends 'status'
? StatusWsMessageStatus
: K extends 'executing'
? NodeId
: T[K]
}
/** Dictionary of types used in the detail for a custom event */
type ApiEventTypes = ApiToEventType<ApiCalls>
/** Dictionary of API events: `[name]: CustomEvent<Type>` */
type ApiEvents = AsCustomEvents<ApiEventTypes>
/** {@link Omit} all properties that evaluate to `never`. */
type NeverNever<T> = {
[K in keyof T as T[K] extends never ? never : K]: T[K]
}
/** {@link Pick} only properties that evaluate to `never`. */
type PickNevers<T> = {
[K in keyof T as T[K] extends never ? K : never]: T[K]
}
/** Keys (names) of API events that _do not_ pass a {@link CustomEvent} `detail` object. */
type SimpleApiEvents = keyof PickNevers<ApiEventTypes>
/** Keys (names) of API events that pass a {@link CustomEvent} `detail` object. */
type ComplexApiEvents = keyof NeverNever<ApiEventTypes>
/** EventTarget typing has no generic capability. This interface enables tsc strict. */
export interface ComfyApi extends EventTarget {
addEventListener<TEvent extends keyof ApiEvents>(
type: TEvent,
callback: ((event: ApiEvents[TEvent]) => void) | null,
options?: AddEventListenerOptions | boolean
): void
removeEventListener<TEvent extends keyof ApiEvents>(
type: TEvent,
callback: ((event: ApiEvents[TEvent]) => void) | null,
options?: EventListenerOptions | boolean
): void
}
export class ComfyApi extends EventTarget {
#registered = new Set()
api_host: string
@@ -92,15 +198,51 @@ export class ComfyApi extends EventTarget {
return fetch(this.apiURL(route), options)
}
addEventListener(
type: string,
callback: any,
options?: AddEventListenerOptions
addEventListener<TEvent extends keyof ApiEvents>(
type: TEvent,
callback: ((event: ApiEvents[TEvent]) => void) | null,
options?: AddEventListenerOptions | boolean
) {
super.addEventListener(type, callback, options)
// Type assertion: strictFunctionTypes. So long as we emit events in a type-safe fashion, this is safe.
super.addEventListener(type, callback as EventListener, options)
this.#registered.add(type)
}
removeEventListener<TEvent extends keyof ApiEvents>(
type: TEvent,
callback: ((event: ApiEvents[TEvent]) => void) | null,
options?: EventListenerOptions | boolean
): void {
super.removeEventListener(type, callback as EventListener, options)
}
/**
* Dispatches a custom event.
* Provides type safety for the contravariance issue with EventTarget (last checked TS 5.6).
* @param type The type of event to emit
* @param detail The detail property used for a custom event ({@link CustomEventInit.detail})
*/
dispatchCustomEvent<T extends SimpleApiEvents>(type: T): boolean
dispatchCustomEvent<T extends ComplexApiEvents>(
type: T,
detail: ApiEventTypes[T] | null
): boolean
dispatchCustomEvent<T extends keyof ApiEventTypes>(
type: T,
detail?: ApiEventTypes[T]
): boolean {
const event =
detail === undefined
? new CustomEvent(type)
: new CustomEvent(type, { detail })
return super.dispatchEvent(event)
}
/** @deprecated Use {@link dispatchCustomEvent}. */
dispatchEvent(event: never): boolean {
return super.dispatchEvent(event)
}
/**
* Poll status for colab and other things that don't support websockets.
*/
@@ -108,10 +250,10 @@ export class ComfyApi extends EventTarget {
setInterval(async () => {
try {
const resp = await this.fetchApi('/prompt')
const status = await resp.json()
this.dispatchEvent(new CustomEvent('status', { detail: status }))
const status = (await resp.json()) as StatusWsMessageStatus
this.dispatchCustomEvent('status', status)
} catch (error) {
this.dispatchEvent(new CustomEvent('status', { detail: null }))
this.dispatchCustomEvent('status', null)
}
}, 1000)
}
@@ -138,7 +280,7 @@ export class ComfyApi extends EventTarget {
this.socket.addEventListener('open', () => {
opened = true
if (isReconnect) {
this.dispatchEvent(new CustomEvent('reconnected'))
this.dispatchCustomEvent('reconnected')
}
})
@@ -155,8 +297,8 @@ export class ComfyApi extends EventTarget {
this.#createSocket(true)
}, 300)
if (opened) {
this.dispatchEvent(new CustomEvent('status', { detail: null }))
this.dispatchEvent(new CustomEvent('reconnecting'))
this.dispatchCustomEvent('status', null)
this.dispatchCustomEvent('reconnecting')
}
})
@@ -182,9 +324,7 @@ export class ComfyApi extends EventTarget {
const imageBlob = new Blob([buffer.slice(4)], {
type: imageMime
})
this.dispatchEvent(
new CustomEvent('b_preview', { detail: imageBlob })
)
this.dispatchCustomEvent('b_preview', imageBlob)
break
default:
throw new Error(
@@ -192,7 +332,7 @@ export class ComfyApi extends EventTarget {
)
}
} else {
const msg = JSON.parse(event.data)
const msg = JSON.parse(event.data) as ApiMessageUnion
switch (msg.type) {
case 'status':
if (msg.data.sid) {
@@ -201,31 +341,32 @@ export class ComfyApi extends EventTarget {
window.name = clientId // use window name so it isnt reused when duplicating tabs
sessionStorage.setItem('clientId', clientId) // store in session storage so duplicate tab can load correct workflow
}
this.dispatchEvent(
new CustomEvent('status', { detail: msg.data.status })
)
this.dispatchCustomEvent('status', msg.data.status ?? null)
break
case 'executing':
this.dispatchEvent(
new CustomEvent('executing', {
detail: msg.data.display_node || msg.data.node
})
this.dispatchCustomEvent(
'executing',
msg.data.display_node || msg.data.node
)
break
case 'execution_start':
case 'execution_error':
case 'progress':
case 'executed':
case 'execution_start':
case 'graphChanged':
case 'promptQueued':
case 'b_preview':
this.dispatchCustomEvent(msg.type, msg.data)
break
case 'execution_success':
case 'execution_error':
case 'execution_cached':
case 'logs':
this.dispatchEvent(
new CustomEvent(msg.type, { detail: msg.data })
)
this.dispatchCustomEvent(msg.type)
break
default:
if (this.#registered.has(msg.type)) {
this.dispatchEvent(
// Fallback for custom types - calls super direct.
super.dispatchEvent(
new CustomEvent(msg.type, { detail: msg.data })
)
} else if (!this.reportedUnknownMessageTypes.has(msg.type)) {