Files
ComfyUI_frontend/src/services/litegraphService.ts

802 lines
24 KiB
TypeScript

// @ts-strict-ignore
import {
type IContextMenuValue,
type INodeInputSlot,
LGraphEventMode,
LGraphNode,
LiteGraph
} from '@comfyorg/litegraph'
import { Vector2 } from '@comfyorg/litegraph'
import { IBaseWidget, IWidget } from '@comfyorg/litegraph/dist/types/widgets'
import { st } from '@/i18n'
import { api } from '@/scripts/api'
import { ANIM_PREVIEW_WIDGET, ComfyApp, app } from '@/scripts/app'
import { $el } from '@/scripts/ui'
import { calculateImageGrid, createImageHost } from '@/scripts/ui/imagePreview'
import { useToastStore } from '@/stores/toastStore'
import { ComfyNodeDef, ExecutedWsMessage } from '@/types/apiTypes'
import type { NodeId } from '@/types/comfyWorkflow'
import { normalizeI18nKey } from '@/utils/formatUtil'
import { isImageNode } from '@/utils/litegraphUtil'
import { useExtensionService } from './extensionService'
/**
* Service that augments litegraph with ComfyUI specific functionality.
*/
export const useLitegraphService = () => {
const extensionService = useExtensionService()
const toastStore = useToastStore()
async function registerNodeDef(nodeId: string, nodeData: ComfyNodeDef) {
  /**
   * Node class generated from `nodeData`. Its constructor creates the input
   * sockets / widgets and output sockets described by the definition; its
   * `configure` keeps slot metadata from the definition when loading data.
   */
  const node = class ComfyNode extends LGraphNode {
    static comfyClass? = nodeData.name
    // TODO: change to "title?" once litegraph.d.ts has been updated
    static title = nodeData.display_name || nodeData.name
    static nodeData? = nodeData
    static category?: string

    constructor(title?: string) {
      super(title)
      const requiredInputs = nodeData.input.required
      // Required inputs first, followed by optional ones when present.
      let inputs = nodeData['input']['required']
      if (nodeData['input']['optional'] != undefined) {
        inputs = Object.assign(
          {},
          nodeData['input']['required'],
          nodeData['input']['optional']
        )
      }
      // Accumulates minWidth/minHeight reported by widget constructors.
      const config: {
        minWidth: number
        minHeight: number
        widget?: IBaseWidget
      } = { minWidth: 1, minHeight: 1 }
      for (const inputName in inputs) {
        const _inputData = inputs[inputName]
        const type = _inputData[0]
        const options = _inputData[1] ?? {}
        const inputData = [type, options]
        const nameKey = `nodeDefs.${normalizeI18nKey(nodeData.name)}.inputs.${normalizeI18nKey(inputName)}.name`
        const inputIsRequired = requiredInputs && inputName in requiredInputs
        const widgetType = app.getWidgetType(inputData, inputName)
        if (!widgetType) {
          // Node connection inputs; optional ones are drawn as hollow circles.
          const shapeOptions = inputIsRequired
            ? {}
            : { shape: LiteGraph.SlotShape.HollowCircle }
          this.addInput(inputName, type, {
            ...shapeOptions,
            localized_name: st(nameKey, inputName)
          })
          continue
        }
        // `config` outlives this iteration: clear the previous input's widget
        // so a constructor that creates no widget cannot cause the previous
        // widget to be relabelled / flagged with this input's data.
        config.widget = undefined
        const created =
          widgetType === 'COMBO'
            ? app.widgets.COMBO(this, inputName, inputData, app)
            : app.widgets[widgetType](this, inputName, inputData, app)
        Object.assign(config, created || {})
        const widget = config.widget
        if (!widget) continue
        widget.label = st(nameKey, inputName)
        widget.options ??= {}
        if (!inputIsRequired) {
          widget.options.inputIsOptional = true
        }
        if (inputData[1]?.forceInput) {
          widget.options.forceInput = true
        }
        if (inputData[1]?.defaultInput) {
          widget.options.defaultInput = true
        }
        if (inputData[1]?.advanced) {
          widget.advanced = true
        }
        if (inputData[1]?.hidden) {
          widget.hidden = true
        }
      }
      // Output sockets; list outputs use the grid shape.
      const outputs = nodeData['output'] ?? []
      for (let o = 0; o < outputs.length; o++) {
        let output = outputs[o]
        if (Array.isArray(output)) output = 'COMBO'
        const outputName = nodeData['output_name']?.[o] || output
        const outputIsList = nodeData['output_is_list']?.[o]
        const shapeOptions = outputIsList
          ? { shape: LiteGraph.GRID_SHAPE }
          : {}
        const nameKey = `nodeDefs.${normalizeI18nKey(nodeData.name)}.outputs.${o}.name`
        const typeKey = `dataTypes.${normalizeI18nKey(output)}`
        const outputOptions = {
          ...shapeOptions,
          // If the output name is different from the output type, use the output name.
          // e.g.
          // - type ("INT"); name ("Positive") => translate name
          // - type ("FLOAT"); name ("FLOAT") => translate type
          localized_name:
            output !== outputName
              ? st(nameKey, outputName)
              : st(typeKey, outputName)
        }
        this.addOutput(outputName, output, outputOptions)
      }
      // Widen the computed size by 1.5x, but respect widget minimums.
      const s = this.computeSize()
      s[0] = Math.max(config.minWidth, s[0] * 1.5)
      s[1] = Math.max(config.minHeight, s[1])
      this.size = s
      this.serialize_widgets = true
      // Constructors cannot await; extensions run in the background.
      void extensionService.invokeExtensionsAsync('nodeCreated', this)
    }

    configure(data: any) {
      // Keep 'name', 'type', 'shape', and 'localized_name' information from the original node definition.
      const merge = (
        current: Record<string, any>,
        incoming: Record<string, any>
      ) => {
        const result = { ...incoming }
        if (current.widget === undefined && incoming.widget !== undefined) {
          // Field must be input as only inputs can be converted.
          // The current slot is re-appended so a later incoming index can
          // still be merged with it.
          this.inputs.push(current as INodeInputSlot)
          return incoming
        }
        for (const key of ['name', 'type', 'shape', 'localized_name']) {
          if (current[key] !== undefined) {
            result[key] = current[key]
          }
        }
        return result
      }
      for (const field of ['inputs', 'outputs']) {
        const slots = data[field] ?? []
        data[field] = slots.map((slot, i) =>
          merge(this[field][i] ?? {}, slot)
        )
      }
      super.configure(data)
    }
  }
  node.prototype.comfyClass = nodeData.name
  addNodeContextMenuHandler(node)
  addDrawBackgroundHandler(node)
  addNodeKeyHandler(node)
  await extensionService.invokeExtensionsAsync(
    'beforeRegisterNodeDef',
    node,
    nodeData
  )
  LiteGraph.registerNodeType(nodeId, node)
  // Note: Do not move this to the class definition, it will be overwritten
  node.category = nodeData.category
}
/**
* Adds special context menu handling for nodes
* e.g. this adds Open Image functionality for nodes that show images
* @param {*} node The node to add the menu handler
*/
function addNodeContextMenuHandler(node: typeof LGraphNode) {
  /** URL behind `img` with the `preview` (format) query parameter removed. */
  function getFullImageUrl(img: HTMLImageElement): URL {
    const url = new URL(img.src)
    url.searchParams.delete('preview')
    return url
  }

  /**
   * Builds the "Copy Image" menu entry, or no entry when `ClipboardItem` is
   * unavailable. Any failure is reported to the user via a toast alert.
   */
  function getCopyImageOption(img: HTMLImageElement): IContextMenuValue[] {
    if (typeof window.ClipboardItem === 'undefined') return []

    const writeImage = async (blob: Blob) => {
      await navigator.clipboard.write([
        new ClipboardItem({
          [blob.type]: blob
        })
      ])
    }

    /** Re-encodes `blob` as PNG by drawing it on a canvas of the image's natural size. */
    const convertToPng = async (blob: Blob): Promise<Blob> => {
      const canvas = $el('canvas', {
        width: img.naturalWidth,
        height: img.naturalHeight
      }) as HTMLCanvasElement
      const ctx = canvas.getContext('2d')
      let image: HTMLImageElement | ImageBitmap
      if (typeof window.createImageBitmap === 'undefined') {
        const fallback = new Image()
        const loaded = new Promise((resolve, reject) => {
          fallback.onload = resolve
          fallback.onerror = reject
        }).finally(() => {
          URL.revokeObjectURL(fallback.src)
        })
        fallback.src = URL.createObjectURL(blob)
        await loaded
        image = fallback
      } else {
        image = await createImageBitmap(blob)
      }
      try {
        ctx.drawImage(image, 0, 0)
      } finally {
        const closable = image as ImageBitmap
        if (typeof closable.close === 'function') {
          closable.close()
        }
      }
      // Await the encoding so errors reach the caller's try/catch instead of
      // becoming an unhandled rejection inside a toBlob callback.
      const png = await new Promise<Blob | null>((resolve) =>
        canvas.toBlob(resolve, 'image/png')
      )
      if (!png) throw new Error('Failed to encode image as PNG')
      return png
    }

    return [
      {
        content: 'Copy Image',
        // @ts-expect-error: async callback is not accepted by litegraph
        callback: async () => {
          try {
            const response = await fetch(getFullImageUrl(img))
            if (!response.ok) {
              throw new Error(`${response.status} ${response.statusText}`)
            }
            const blob = await response.blob()
            try {
              await writeImage(blob)
            } catch (error) {
              // Chrome seems to only support PNG on write, convert and try again
              if (blob.type === 'image/png') throw error
              await writeImage(await convertToPng(blob))
            }
          } catch (error) {
            toastStore.addAlert(
              'Error copying image: ' + (error.message ?? error)
            )
          }
        }
      }
    ]
  }

  node.prototype.getExtraMenuOptions = function (_, options) {
    if (this.imgs) {
      // Image entries target the selected image, else the hovered thumbnail.
      let img
      if (this.imageIndex != null) {
        img = this.imgs[this.imageIndex]
      } else if (this.overIndex != null) {
        img = this.imgs[this.overIndex]
      }
      if (img) {
        options.unshift(
          {
            content: 'Open Image',
            callback: () => {
              window.open(getFullImageUrl(img), '_blank')
            }
          },
          ...getCopyImageOption(img),
          {
            content: 'Save Image',
            callback: () => {
              const url = getFullImageUrl(img)
              const a = document.createElement('a')
              a.href = url.toString()
              // Without a `filename` query param (e.g. blob: preview URLs) let
              // the browser pick a name rather than saving a file named "null".
              a.setAttribute('download', url.searchParams.get('filename') ?? '')
              document.body.append(a)
              a.click()
              requestAnimationFrame(() => a.remove())
            }
          }
        )
      }
    }
    options.push({
      content: 'Bypass',
      callback: () => {
        // Toggle relative to this node's mode, applied to every selected node.
        const mode =
          this.mode === LGraphEventMode.BYPASS
            ? LGraphEventMode.ALWAYS
            : LGraphEventMode.BYPASS
        for (const item of app.canvas.selectedItems) {
          if (item instanceof LGraphNode) item.mode = mode
        }
        this.graph.change()
      }
    })
    // prevent conflict of clipspace content
    if (!ComfyApp.clipspace_return_node) {
      options.push({
        content: 'Copy (Clipspace)',
        callback: () => {
          ComfyApp.copyToClipspace(this)
        }
      })
      if (ComfyApp.clipspace != null) {
        options.push({
          content: 'Paste (Clipspace)',
          callback: () => {
            ComfyApp.pasteFromClipspace(this)
          }
        })
      }
      if (isImageNode(this)) {
        options.push({
          content: 'Open in MaskEditor',
          callback: () => {
            ComfyApp.copyToClipspace(this)
            ComfyApp.clipspace_return_node = this
            ComfyApp.open_maskeditor()
          }
        })
      }
    }
    return []
  }
}
/**
* Adds Custom drawing logic for nodes
* e.g. Draws images and handles thumbnail navigation on nodes that output images
* @param {*} node The node to add the draw handler
*/
function addDrawBackgroundHandler(node: typeof LGraphNode) {
/**
* Returns the y offset (node-local) where image content starts:
* `node.imageOffset` when set, otherwise just below the last widget,
* otherwise the node's computed height.
*/
function getImageTop(node: LGraphNode) {
let shiftY: number
if (node.imageOffset != null) {
return node.imageOffset
} else if (node.widgets?.length) {
// Below the last widget: its last_y plus the best height estimate available.
const w = node.widgets[node.widgets.length - 1]
shiftY = w.last_y
if (w.computeSize) {
shiftY += w.computeSize()[1] + 4
// @ts-expect-error computedHeight only exists for DOMWidget
} else if (w.computedHeight) {
// @ts-expect-error computedHeight only exists for DOMWidget
shiftY += w.computedHeight
} else {
shiftY += LiteGraph.NODE_WIDGET_HEIGHT + 4
}
} else {
return node.computeSize()[1]
}
return shiftY
}
/**
* Makes room for image previews. Does nothing for animated previews unless
* `force` is set. If the node has `inputHeight` or more than 210 of
* `freeWidgetSpace`, the current size is just re-applied via setSize;
* otherwise the height is raised (never lowered) to getImageTop + 220.
*/
node.prototype.setSizeForImage = function (
this: LGraphNode,
force: boolean
) {
if (!force && this.animatedImages) return
if (this.inputHeight || this.freeWidgetSpace > 210) {
this.setSize(this.size)
return
}
const minHeight = getImageTop(this) + 220
if (this.size[1] < minHeight) {
this.setSize([this.size[0], minHeight])
}
}
/**
* Draws the node's output images / live preview on its background.
* May throw; it is only invoked through the try/catch in onDrawBackground
* below. Besides drawing it updates node state (images, preview, imgs,
* animatedImages, imageIndex, overIndex, pointerDown, imageRects).
*/
function unsafeDrawBackground(
this: LGraphNode,
ctx: CanvasRenderingContext2D
) {
if (this.flags.collapsed) return
const imgURLs: (string[] | string)[] = []
let imagesChanged = false
// Rebuild view URLs only when this node's output images array is a new object.
const output: ExecutedWsMessage['output'] = app.nodeOutputs[this.id + '']
if (output?.images && this.images !== output.images) {
this.animatedImages = output?.animated?.find(Boolean)
this.images = output.images
imagesChanged = true
// Animated outputs are requested without the preview-format parameter.
const preview = this.animatedImages ? '' : app.getPreviewFormatParam()
for (const params of output.images) {
const imgUrlPart = new URLSearchParams(params).toString()
const rand = app.getRandParam()
const imgUrl = api.apiURL(`/view?${imgUrlPart}${preview}${rand}`)
imgURLs.push(imgUrl)
}
}
// A live preview image (if any) is appended after the output images.
const preview = app.nodePreviewImages[this.id + '']
if (this.preview !== preview) {
this.preview = preview
imagesChanged = true
if (preview != null) {
imgURLs.push(preview)
}
}
if (imagesChanged) {
this.imageIndex = null
if (imgURLs.length > 0) {
// Load asynchronously; failed loads resolve to null and are filtered out.
Promise.all(
imgURLs.flat().map((src) => {
return new Promise<HTMLImageElement | null>((r) => {
const img = new Image()
img.onload = () => r(img)
img.onerror = () => r(null)
img.src = src
})
})
).then((imgs) => {
// Discard the result if outputs/preview changed again while loading.
if (
(!output || this.images === output.images) &&
(!preview || this.preview === preview)
) {
this.imgs = imgs.filter(Boolean)
this.setSizeForImage?.()
app.graph.setDirtyCanvas(true)
}
})
} else {
this.imgs = null
}
}
const is_all_same_aspect_ratio = (imgs: HTMLImageElement[]) => {
// assume: imgs.length >= 2
const ratio = imgs[0].naturalWidth / imgs[0].naturalHeight
for (let i = 1; i < imgs.length; i++) {
const this_ratio = imgs[i].naturalWidth / imgs[i].naturalHeight
if (ratio != this_ratio) return false
}
return true
}
// Nothing to do
if (!this.imgs?.length) return
const widgetIdx = this.widgets?.findIndex(
(w) => w.name === ANIM_PREVIEW_WIDGET
)
if (this.animatedImages) {
// Instead of using the canvas we'll use an IMG element (DOM widget)
if (widgetIdx > -1) {
// Replace content
const widget = this.widgets[widgetIdx] as IWidget & {
options: { host: ReturnType<typeof createImageHost> }
}
widget.options.host.updateImages(this.imgs)
} else {
const host = createImageHost(this)
this.setSizeForImage(true)
const widget = this.addDOMWidget(
ANIM_PREVIEW_WIDGET,
'img',
host.el,
{
host,
getHeight: host.getHeight,
onDraw: host.onDraw,
hideOnZoom: false
}
)
// Preview-only widget: serializeValue yields undefined.
widget.serializeValue = () => undefined
widget.options.host.updateImages(this.imgs)
}
return
}
// Static images are drawn on the canvas; drop any leftover animated-preview widget.
if (widgetIdx > -1) {
this.widgets[widgetIdx].onRemove?.()
this.widgets.splice(widgetIdx, 1)
}
const canvas = app.graph.list_of_graphcanvas[0]
const mouse = canvas.graph_mouse
// Pointer released: if the mouse has not moved since the recorded press,
// treat it as a click and select the recorded index (null = back to grid).
if (!canvas.pointer_is_down && this.pointerDown) {
if (
mouse[0] === this.pointerDown.pos[0] &&
mouse[1] === this.pointerDown.pos[1]
) {
this.imageIndex = this.pointerDown.index
}
this.pointerDown = null
}
let { imageIndex } = this
const numImages = this.imgs.length
if (numImages === 1 && !imageIndex) {
// This skips the thumbnail render section below
this.imageIndex = imageIndex = 0
}
// Image area: full node width, from shiftY down to the bottom of the node.
const shiftY = getImageTop(this)
const dw = this.size[0]
const dh = this.size[1] - shiftY
if (imageIndex == null) {
// No image selected; draw thumbnails of all
let cellWidth: number
let cellHeight: number
let shiftX: number
let cell_padding: number
let cols: number
const compact_mode = is_all_same_aspect_ratio(this.imgs)
if (!compact_mode) {
// use rectangle cell style and border line
cell_padding = 2
// Prevent infinite canvas2d scale-up
// The grid is computed from a single square placeholder sized to the largest
// image dimension (presumably calculateImageGrid reads only the first entry
// and the array length — TODO confirm).
const largestDimension = this.imgs.reduce(
(acc, current) =>
Math.max(acc, current.naturalWidth, current.naturalHeight),
0
)
const fakeImgs = []
fakeImgs.length = this.imgs.length
fakeImgs[0] = {
naturalWidth: largestDimension,
naturalHeight: largestDimension
}
;({ cellWidth, cellHeight, cols, shiftX } = calculateImageGrid(
fakeImgs,
dw,
dh
))
} else {
cell_padding = 0
;({ cellWidth, cellHeight, cols, shiftX } = calculateImageGrid(
this.imgs,
dw,
dh
))
}
let anyHovered = false
this.imageRects = []
for (let i = 0; i < numImages; i++) {
const img = this.imgs[i]
const row = Math.floor(i / cols)
const col = i % cols
const x = col * cellWidth + shiftX
const y = row * cellHeight + shiftY
// Hit-test cells (in graph coordinates) until the first hovered one is found.
if (!anyHovered) {
anyHovered = LiteGraph.isInsideRectangle(
mouse[0],
mouse[1],
x + this.pos[0],
y + this.pos[1],
cellWidth,
cellHeight
)
if (anyHovered) {
// Highlight the hovered cell (stronger while pressed) and record the press.
this.overIndex = i
let value = 110
if (canvas.pointer_is_down) {
if (!this.pointerDown || this.pointerDown.index !== i) {
this.pointerDown = { index: i, pos: [...mouse] }
}
value = 125
}
ctx.filter = `contrast(${value}%) brightness(${value}%)`
canvas.canvas.style.cursor = 'pointer'
}
}
this.imageRects.push([x, y, cellWidth, cellHeight])
// Fit the image inside its cell, preserving aspect ratio, centered.
const wratio = cellWidth / img.width
const hratio = cellHeight / img.height
const ratio = Math.min(wratio, hratio)
const imgHeight = ratio * img.height
const imgY = row * cellHeight + shiftY + (cellHeight - imgHeight) / 2
const imgWidth = ratio * img.width
const imgX = col * cellWidth + shiftX + (cellWidth - imgWidth) / 2
ctx.drawImage(
img,
imgX + cell_padding,
imgY + cell_padding,
imgWidth - cell_padding * 2,
imgHeight - cell_padding * 2
)
if (!compact_mode) {
// rectangle cell and border line style
ctx.strokeStyle = '#8F8F8F'
ctx.lineWidth = 1
ctx.strokeRect(
x + cell_padding,
y + cell_padding,
cellWidth - cell_padding * 2,
cellHeight - cell_padding * 2
)
}
ctx.filter = 'none'
}
// Nothing hovered: clear hover and pending-press state.
if (!anyHovered) {
this.pointerDown = null
this.overIndex = null
}
return
}
// Draw individual
// Scale down (never up) to fit the image area, centered.
let w = this.imgs[imageIndex].naturalWidth
let h = this.imgs[imageIndex].naturalHeight
const scaleX = dw / w
const scaleY = dh / h
const scale = Math.min(scaleX, scaleY, 1)
w *= scale
h *= scale
const x = (dw - w) / 2
const y = (dh - h) / 2 + shiftY
ctx.drawImage(this.imgs[imageIndex], x, y, w, h)
/**
* Draws a rounded square button of size `sz` at node-local (x, y) with a
* text label. Returns true while it is hovered and the pointer is down.
* NOTE(review): the label offsets (+15, +20) assume sz === 30.
*/
const drawButton = (
x: number,
y: number,
sz: number,
text: string
): boolean => {
const hovered = LiteGraph.isInsideRectangle(
mouse[0],
mouse[1],
x + this.pos[0],
y + this.pos[1],
sz,
sz
)
let fill = '#333'
let textFill = '#fff'
let isClicking = false
if (hovered) {
canvas.canvas.style.cursor = 'pointer'
if (canvas.pointer_is_down) {
fill = '#1e90ff'
isClicking = true
} else {
fill = '#eee'
textFill = '#000'
}
}
ctx.fillStyle = fill
ctx.beginPath()
ctx.roundRect(x, y, sz, sz, [4])
ctx.fill()
ctx.fillStyle = textFill
ctx.font = '12px Arial'
ctx.textAlign = 'center'
ctx.fillText(text, x + 15, y + 20)
return isClicking
}
// Gallery controls are only shown when there is more than one image.
if (!(numImages > 1)) return
const imageNum = this.imageIndex + 1
// Bottom-right "n/total" button: on click, advance to the next image (wrapping).
if (
drawButton(dw - 40, dh + shiftY - 40, 30, `${imageNum}/${numImages}`)
) {
const i = imageNum >= numImages ? 0 : imageNum
if (!this.pointerDown || this.pointerDown.index !== i) {
this.pointerDown = { index: i, pos: [...mouse] }
}
}
// Top-right "x" button: on click, return to the thumbnail grid (index null).
if (drawButton(dw - 40, shiftY + 10, 30, `x`)) {
if (!this.pointerDown || this.pointerDown.index !== null) {
this.pointerDown = { index: null, pos: [...mouse] }
}
}
}
// Drawing errors are logged rather than propagated out of the render call.
node.prototype.onDrawBackground = function (ctx) {
try {
unsafeDrawBackground.call(this, ctx)
} catch (error) {
console.error('Error drawing node background', error)
}
}
}
function addNodeKeyHandler(node: typeof LGraphNode) {
  const origNodeOnKeyDown = node.prototype.onKeyDown
  /**
   * Keyboard navigation while a single image is shown full-size:
   * ArrowLeft/ArrowRight cycle through `imgs` (wrapping around), Escape goes
   * back to the thumbnail grid. A pre-existing `onKeyDown` runs first and can
   * veto handling by returning `false`. Handled keys are consumed
   * (preventDefault + stopImmediatePropagation) and `false` is returned.
   */
  node.prototype.onKeyDown = function (e) {
    // Pass the event itself to the previous handler.
    if (origNodeOnKeyDown && origNodeOnKeyDown.call(this, e) === false) {
      return false
    }
    // `== null` also covers a never-initialised (undefined) index, and an
    // empty gallery would otherwise produce NaN via `% 0`.
    if (this.flags.collapsed || !this.imgs?.length || this.imageIndex == null) {
      return
    }
    switch (e.key) {
      case 'ArrowLeft':
      case 'ArrowRight': {
        const step = e.key === 'ArrowLeft' ? -1 : 1
        const count = this.imgs.length
        // Double modulo keeps the result in [0, count) for negative values.
        this.imageIndex = (((this.imageIndex + step) % count) + count) % count
        break
      }
      case 'Escape':
        this.imageIndex = null
        break
      default:
        return
    }
    e.preventDefault()
    e.stopImmediatePropagation()
    return false
  }
}
/**
 * Creates a node of type `nodeDef.name` (titled `nodeDef.display_name`) and
 * adds it to the current graph.
 *
 * @param nodeDef Definition of the node type to create.
 * @param options Properties forwarded to `LiteGraph.createNode`. `pos`
 *   defaults to the canvas center. The caller's object is not modified.
 * @returns The created node.
 */
function addNodeOnGraph(
  nodeDef: ComfyNodeDef,
  options: Record<string, any> = {}
): LGraphNode {
  // Copy rather than assigning into `options` so caller-owned objects stay untouched.
  const nodeOptions = { ...options, pos: options.pos ?? getCanvasCenter() }
  const node = LiteGraph.createNode(
    nodeDef.name,
    nodeDef.display_name,
    nodeOptions
  )
  app.graph.add(node)
  return node
}
/**
 * Center of the canvas's visible area in graph coordinates. The visible
 * width/height are divided by the device pixel ratio (clamped to >= 1).
 */
function getCanvasCenter(): Vector2 {
  const pixelRatio = Math.max(window.devicePixelRatio ?? 1, 1)
  const visible = app.canvas.ds.visible_area
  const centerX = visible[0] + visible[2] / pixelRatio / 2
  const centerY = visible[1] + visible[3] / pixelRatio / 2
  return [centerX, centerY]
}
/** Animates the canvas to the bounds of node `nodeId`; no-op if it does not exist. */
function goToNode(nodeId: NodeId) {
  const target = app.graph.getNodeById(nodeId)
  if (target) {
    app.canvas.animateToBounds(target.boundingRect)
  }
}
return {
registerNodeDef,
addNodeOnGraph,
getCanvasCenter,
goToNode
}
}