mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-02-10 01:50:08 +00:00
WIP autogrow rewrite
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import { remove } from 'es-toolkit'
|
||||
|
||||
import { useChainCallback } from '@/composables/functional/useChainCallback'
|
||||
import { NodeSlotType } from '@/lib/litegraph/src/types/globalEnums'
|
||||
import type {
|
||||
ISlotType,
|
||||
INodeInputSlot,
|
||||
@@ -24,6 +25,19 @@ const INLINE_INPUTS = false
|
||||
|
||||
type MatchTypeNode = LGraphNode &
|
||||
Pick<Required<LGraphNode>, 'comfyMatchType' | 'onConnectionsChange'>
|
||||
type AutogrowNode = LGraphNode &
|
||||
Pick<Required<LGraphNode>, 'onConnectionsChange' | 'widgets'> & {
|
||||
comfyAutogrow: Record<
|
||||
string,
|
||||
{
|
||||
min: number
|
||||
max: number
|
||||
inputSpecs: InputSpecV2[]
|
||||
prefix?: string
|
||||
names?: string[]
|
||||
}
|
||||
>
|
||||
}
|
||||
|
||||
function ensureWidgetForInput(node: LGraphNode, input: INodeInputSlot) {
|
||||
if (input.widget?.name) return
|
||||
@@ -329,160 +343,209 @@ function applyMatchType(node: LGraphNode, inputSpec: InputSpecV2) {
|
||||
)
|
||||
}
|
||||
|
||||
function applyAutogrow(node: LGraphNode, untypedInputSpec: InputSpecV2) {
|
||||
function autogrowOrdinalToName(
|
||||
ordinal: number,
|
||||
key: string,
|
||||
groupName: string,
|
||||
node: AutogrowNode
|
||||
) {
|
||||
const { names, prefix = '', inputSpecs } = node.comfyAutogrow[groupName]
|
||||
const baseName = names
|
||||
? names[ordinal]
|
||||
: (inputSpecs.length == 1 ? prefix : key) + ordinal
|
||||
return { name: `${groupName}.${baseName}`, display_name: baseName }
|
||||
}
|
||||
|
||||
function addAutogrowGroup(
|
||||
ordinal: number,
|
||||
groupName: string,
|
||||
node: AutogrowNode
|
||||
) {
|
||||
const { addNodeInput } = useLitegraphService()
|
||||
const { max, min, inputSpecs } = node.comfyAutogrow[groupName]
|
||||
if (ordinal >= max) return
|
||||
|
||||
const parseResult = zAutogrowOptions.safeParse(untypedInputSpec)
|
||||
if (!parseResult.success) throw new Error('invalid Autogrow spec')
|
||||
const inputSpec = parseResult.data
|
||||
const namedSpecs = inputSpecs.map((input) => ({
|
||||
...input,
|
||||
isOptional: ordinal >= (min ?? 0) || input.isOptional,
|
||||
...autogrowOrdinalToName(ordinal, input.name, groupName, node)
|
||||
}))
|
||||
|
||||
const { input, min, names, prefix, max } = inputSpec.template
|
||||
const inputTypes: [Record<string, InputSpec> | undefined, boolean][] = [
|
||||
[input.required, false],
|
||||
[input.optional, true]
|
||||
]
|
||||
const inputsV2 = inputTypes.flatMap(([inputType, isOptional]) =>
|
||||
Object.entries(inputType ?? {}).map(([name, v]) =>
|
||||
transformInputSpecV1ToV2(v, { name, isOptional })
|
||||
const newInputs = namedSpecs
|
||||
.filter(
|
||||
(namedSpec) => !node.inputs.some((inp) => inp.name === namedSpec.name)
|
||||
)
|
||||
.map((namedSpec) => {
|
||||
addNodeInput(node, namedSpec)
|
||||
const input = spliceInputs(node, node.inputs.length - 1, 1)[0]
|
||||
if (inputSpecs.length !== 1) ensureWidgetForInput(node, input)
|
||||
return input
|
||||
})
|
||||
|
||||
const lastIndex = node.inputs.findLastIndex((inp) =>
|
||||
inp.name.startsWith(groupName)
|
||||
)
|
||||
const insertionIndex = lastIndex === -1 ? node.inputs.length : lastIndex + 1
|
||||
spliceInputs(node, insertionIndex, 0, ...newInputs)
|
||||
app.canvas?.setDirty(true, true)
|
||||
}
|
||||
function removeAutogrowGroup(
|
||||
ordinal: number,
|
||||
groupName: string,
|
||||
node: AutogrowNode
|
||||
) {
|
||||
const { inputSpecs } = node.comfyAutogrow[groupName]
|
||||
for (const spec of inputSpecs) {
|
||||
const { name } = autogrowOrdinalToName(ordinal, spec.name, groupName, node)
|
||||
|
||||
const removed = remove(node.inputs, (inp) => inp.name.startsWith(name))
|
||||
for (const input of removed) {
|
||||
const widgetName = input?.widget?.name
|
||||
if (!widgetName) continue
|
||||
remove(node.widgets, (w) => w.name === widgetName)
|
||||
}
|
||||
}
|
||||
|
||||
node.size[1] = node.computeSize([...node.size])[1]
|
||||
}
|
||||
function resolveAutogrowOrdinal(
|
||||
inputName: string,
|
||||
groupName: string,
|
||||
node: AutogrowNode
|
||||
): number | undefined {
|
||||
//TODO preslice groupname?
|
||||
const name = inputName.slice(groupName.length + 1)
|
||||
const { names, prefix } = node.comfyAutogrow[groupName]
|
||||
if (names) {
|
||||
const ordinal = names.findIndex((s) => s === name)
|
||||
return ordinal === -1 ? undefined : ordinal
|
||||
}
|
||||
//FIXME multi input group prefixes?
|
||||
const ordinal = parseInt(name.slice(prefix!.length))
|
||||
return ordinal !== ordinal ? undefined : ordinal
|
||||
}
|
||||
function autogrowInputConnected(index: number, node: AutogrowNode) {
|
||||
const input = node.inputs[index]
|
||||
const groupName = input.name.slice(0, input.name.lastIndexOf('.'))
|
||||
const lastInput = node.inputs.findLast((inp) =>
|
||||
inp.name.startsWith(groupName)
|
||||
)
|
||||
if (lastInput !== input) return
|
||||
const ordinal = resolveAutogrowOrdinal(input.name, groupName, node)
|
||||
if (ordinal == undefined) return //TODO consider warning here
|
||||
addAutogrowGroup(ordinal + 1, groupName, node)
|
||||
}
|
||||
function autogrowInputDisconnected(index: number, node: AutogrowNode) {
|
||||
const input = node.inputs[index]
|
||||
const groupName = input.name.slice(0, input.name.lastIndexOf('.'))
|
||||
const { min } = node.comfyAutogrow[groupName]
|
||||
const ordinal = resolveAutogrowOrdinal(input.name, groupName, node)
|
||||
if (ordinal == undefined || ordinal + 1 < min) return
|
||||
|
||||
//resolve all inputs in group
|
||||
const groupInputs = node.inputs.filter(
|
||||
(inp) =>
|
||||
inp.name.startsWith(groupName + '.') &&
|
||||
inp.name.lastIndexOf('.') === groupName.length
|
||||
)
|
||||
|
||||
function nameToInputIndex(name: string) {
|
||||
const index = node.inputs.findIndex((input) => input.name === name)
|
||||
if (index === -1) throw new Error('Failed to find input')
|
||||
return index
|
||||
//segment groupInputs by ordinal??
|
||||
//FIXME
|
||||
//for each column?
|
||||
for (
|
||||
let bubbleOrdinal = ordinal;
|
||||
bubbleOrdinal < groupInputs.length - 1;
|
||||
bubbleOrdinal++
|
||||
) {
|
||||
const curInput = groupInputs[bubbleOrdinal]
|
||||
curInput.link = groupInputs[bubbleOrdinal + 1].link
|
||||
if (!curInput.link) continue
|
||||
const link = node.graph?.links[curInput.link]
|
||||
if (!link) continue
|
||||
const curIndex = node.inputs.findIndex((inp) => inp === curInput)
|
||||
if (curIndex === -1) throw new Error('missing input')
|
||||
link.target_slot = curIndex
|
||||
}
|
||||
function nameToInput(name: string) {
|
||||
return node.inputs[nameToInputIndex(name)]
|
||||
//if second to last input in group lacks connection, remove the last
|
||||
const penultimateInput = groupInputs.at(-2)
|
||||
if (penultimateInput && penultimateInput.link == null) {
|
||||
const removeOrdinal = resolveAutogrowOrdinal(
|
||||
groupInputs.at(-1)!.name,
|
||||
groupName,
|
||||
node
|
||||
)
|
||||
if (removeOrdinal === undefined) return
|
||||
removeAutogrowGroup(removeOrdinal, groupName, node)
|
||||
}
|
||||
app.canvas?.setDirty(true, true)
|
||||
}
|
||||
|
||||
//In the distance, someone shouting YAGNI
|
||||
const trackedInputs: string[][] = []
|
||||
function addInputGroup(insertionIndex: number) {
|
||||
const ordinal = trackedInputs.length
|
||||
const inputGroup = inputsV2.map((input) => ({
|
||||
...input,
|
||||
name: names
|
||||
? names[ordinal]
|
||||
: ((inputsV2.length == 1 ? prefix : input.name) ?? '') + ordinal,
|
||||
isOptional: ordinal >= (min ?? 0) || input.isOptional
|
||||
}))
|
||||
const newInputs = inputGroup
|
||||
.filter(
|
||||
(namedSpec) => !node.inputs.some((inp) => inp.name === namedSpec.name)
|
||||
)
|
||||
.map((namedSpec) => {
|
||||
addNodeInput(node, namedSpec)
|
||||
const input = spliceInputs(node, node.inputs.length - 1, 1)[0]
|
||||
if (inputsV2.length !== 1) ensureWidgetForInput(node, input)
|
||||
return input
|
||||
})
|
||||
spliceInputs(node, insertionIndex, 0, ...newInputs)
|
||||
trackedInputs.push(inputGroup.map((inp) => inp.name))
|
||||
app.canvas?.setDirty(true, true)
|
||||
}
|
||||
for (let i = 0; i < (min || 1); i++) addInputGroup(node.inputs.length)
|
||||
function removeInputGroup(inputName: string) {
|
||||
const groupIndex = trackedInputs.findIndex((ig) =>
|
||||
ig.some((inpName) => inpName === inputName)
|
||||
)
|
||||
if (groupIndex == -1) throw new Error('Failed to find group')
|
||||
const group = trackedInputs[groupIndex]
|
||||
for (const nameToRemove of group) {
|
||||
const inputIndex = nameToInputIndex(nameToRemove)
|
||||
const input = spliceInputs(node, inputIndex, 1)[0]
|
||||
if (!input.widget?.name) continue
|
||||
const widget = node.widgets?.find((w) => w.name === input.widget!.name)
|
||||
if (!widget) return
|
||||
widget.value = undefined
|
||||
node.removeWidget(widget)
|
||||
}
|
||||
trackedInputs.splice(groupIndex, 1)
|
||||
node.size[1] = node.computeSize([...node.size])[1]
|
||||
app.canvas?.setDirty(true, true)
|
||||
}
|
||||
|
||||
function inputConnected(index: number) {
|
||||
const input = node.inputs[index]
|
||||
const groupIndex = trackedInputs.findIndex((ig) =>
|
||||
ig.some((inputName) => inputName === input.name)
|
||||
)
|
||||
if (groupIndex == -1) throw new Error('Failed to find group')
|
||||
if (
|
||||
groupIndex + 1 === trackedInputs.length &&
|
||||
trackedInputs.length < (max ?? names?.length ?? 100)
|
||||
) {
|
||||
const lastInput = trackedInputs[groupIndex].at(-1)
|
||||
if (!lastInput) return
|
||||
const insertionIndex = nameToInputIndex(lastInput) + 1
|
||||
if (insertionIndex === 0) throw new Error('Failed to find Input')
|
||||
addInputGroup(insertionIndex)
|
||||
}
|
||||
}
|
||||
function inputDisconnected(index: number) {
|
||||
const input = node.inputs[index]
|
||||
if (trackedInputs.length === 1) return
|
||||
const groupIndex = trackedInputs.findIndex((ig) =>
|
||||
ig.some((inputName) => inputName === input.name)
|
||||
)
|
||||
if (groupIndex == -1) throw new Error('Failed to find group')
|
||||
if (
|
||||
trackedInputs[groupIndex].some(
|
||||
(inputName) => nameToInput(inputName).link != null
|
||||
)
|
||||
)
|
||||
return
|
||||
if (groupIndex + 1 < (min ?? 0)) return
|
||||
//For each group from here to last group, bubble swap links
|
||||
for (let column = 0; column < trackedInputs[0].length; column++) {
|
||||
let prevInput = nameToInputIndex(trackedInputs[groupIndex][column])
|
||||
for (let i = groupIndex + 1; i < trackedInputs.length; i++) {
|
||||
const curInput = nameToInputIndex(trackedInputs[i][column])
|
||||
const linkId = node.inputs[curInput].link
|
||||
node.inputs[prevInput].link = linkId
|
||||
const link = linkId && node.graph?.links?.[linkId]
|
||||
if (link) link.target_slot = prevInput
|
||||
prevInput = curInput
|
||||
}
|
||||
node.inputs[prevInput].link = null
|
||||
}
|
||||
if (
|
||||
trackedInputs.at(-2) &&
|
||||
!trackedInputs.at(-2)?.some((name) => !!nameToInput(name).link)
|
||||
)
|
||||
removeInputGroup(trackedInputs.at(-1)![0])
|
||||
}
|
||||
function withComfyAutogrow(node: LGraphNode): asserts node is AutogrowNode {
|
||||
if (node.comfyAutogrow) return
|
||||
node.comfyAutogrow = {}
|
||||
|
||||
let pendingConnection: number | undefined
|
||||
let swappingConnection = false
|
||||
|
||||
const originalOnConnectInput = node.onConnectInput
|
||||
node.onConnectInput = function (slot: number, ...args) {
|
||||
pendingConnection = slot
|
||||
requestAnimationFrame(() => (pendingConnection = undefined))
|
||||
return originalOnConnectInput?.apply(this, [slot, ...args]) ?? true
|
||||
}
|
||||
|
||||
node.onConnectionsChange = useChainCallback(
|
||||
node.onConnectionsChange,
|
||||
(
|
||||
type: ISlotType,
|
||||
index: number,
|
||||
function (
|
||||
this: AutogrowNode,
|
||||
contype: ISlotType,
|
||||
slot: number,
|
||||
iscon: boolean,
|
||||
linf: LLink | null | undefined
|
||||
) => {
|
||||
if (type !== NodeSlotType.INPUT) return
|
||||
const inputName = node.inputs[index].name
|
||||
if (!trackedInputs.flat().some((name) => name === inputName)) return
|
||||
) {
|
||||
const input = this.inputs[slot]
|
||||
if (contype !== LiteGraph.INPUT || !this.graph || !input) return
|
||||
//Return if input isn't known autogrow
|
||||
const key = input.name.slice(0, input.name.lastIndexOf('.'))
|
||||
const autogrowGroup = this.comfyAutogrow[key]
|
||||
if (!autogrowGroup) return
|
||||
if (iscon) {
|
||||
if (swappingConnection || !linf) return
|
||||
inputConnected(index)
|
||||
autogrowInputConnected(slot, this)
|
||||
} else {
|
||||
if (pendingConnection === index) {
|
||||
if (pendingConnection === slot) {
|
||||
swappingConnection = true
|
||||
requestAnimationFrame(() => (swappingConnection = false))
|
||||
return
|
||||
}
|
||||
requestAnimationFrame(() => inputDisconnected(index))
|
||||
requestAnimationFrame(() => autogrowInputDisconnected(slot, this))
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
function applyAutogrow(node: LGraphNode, inputSpecV2: InputSpecV2) {
|
||||
withComfyAutogrow(node)
|
||||
|
||||
const parseResult = zAutogrowOptions.safeParse(inputSpecV2)
|
||||
if (!parseResult.success) throw new Error('invalid Autogrow spec')
|
||||
const inputSpec = parseResult.data
|
||||
const { input, min = 1, names, prefix, max = 100 } = inputSpec.template
|
||||
|
||||
const inputTypes: (Record<string, InputSpec> | undefined)[] = [
|
||||
input.required,
|
||||
input.optional
|
||||
]
|
||||
const inputsV2 = inputTypes.flatMap((inputType, index) =>
|
||||
Object.entries(inputType ?? {}).map(([name, v]) =>
|
||||
transformInputSpecV1ToV2(v, { name, isOptional: index === 1 })
|
||||
)
|
||||
)
|
||||
node.comfyAutogrow[inputSpecV2.name] = {
|
||||
names,
|
||||
min,
|
||||
max,
|
||||
prefix,
|
||||
inputSpecs: inputsV2
|
||||
}
|
||||
for (let i = 0; i < min; i++) addAutogrowGroup(i, inputSpecV2.name, node)
|
||||
}
|
||||
|
||||
@@ -417,6 +417,7 @@ export class LGraphNode
|
||||
showAdvanced?: boolean
|
||||
|
||||
declare comfyMatchType?: Record<string, Record<string, string>>
|
||||
declare comfyAutogrow?: unknown
|
||||
declare comfyClass?: string
|
||||
declare isVirtualNode?: boolean
|
||||
applyToGraph?(extraLinks?: LLink[]): void
|
||||
|
||||
Reference in New Issue
Block a user