mirror of
https://github.com/Comfy-Org/ComfyUI_frontend.git
synced 2026-02-11 02:20:08 +00:00
Add support for growable inputs (#6830)
 Also fixes connections to widget inputs created by a dynamic combo breaking on reload. Performs some refactoring to group the prior dynamic inputs code. See also, the overarching frontend PR: comfyanonymous/ComfyUI#10832 ┆Issue is synchronized with this [Notion page](https://www.notion.so/PR-6830-Add-support-for-growable-inputs-2b36d73d365081c484ebc251a10aa6dd) by [Unito](https://www.unito.io)
This commit is contained in:
@@ -1,10 +1,50 @@
|
||||
import { without } from 'es-toolkit'
|
||||
|
||||
import { useChainCallback } from '@/composables/functional/useChainCallback'
|
||||
import { NodeSlotType } from '@/lib/litegraph/src/types/globalEnums'
|
||||
import type {
|
||||
ISlotType,
|
||||
INodeInputSlot,
|
||||
INodeOutputSlot
|
||||
} from '@/lib/litegraph/src/interfaces'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { LLink } from '@/lib/litegraph/src/LLink'
|
||||
import { transformInputSpecV1ToV2 } from '@/schemas/nodeDef/migration'
|
||||
import type { ComboInputSpec, InputSpec } from '@/schemas/nodeDefSchema'
|
||||
import { zDynamicComboInputSpec } from '@/schemas/nodeDefSchema'
|
||||
import type { InputSpec as InputSpecV2 } from '@/schemas/nodeDef/nodeDefSchemaV2'
|
||||
import {
|
||||
zAutogrowOptions,
|
||||
zDynamicComboInputSpec
|
||||
} from '@/schemas/nodeDefSchema'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
import { app } from '@/scripts/app'
|
||||
import type { ComfyApp } from '@/scripts/app'
|
||||
import { isStrings } from '@/utils/typeGuardUtil'
|
||||
|
||||
const INLINE_INPUTS = false
|
||||
|
||||
type MatchTypeNode = LGraphNode &
|
||||
Pick<Required<LGraphNode>, 'comfyMatchType' | 'onConnectionsChange'>
|
||||
|
||||
function ensureWidgetForInput(node: LGraphNode, input: INodeInputSlot) {
|
||||
if (input.widget?.name) return
|
||||
node.widgets ??= []
|
||||
node.widgets.push({
|
||||
name: input.name,
|
||||
y: 0,
|
||||
type: 'shim',
|
||||
options: {},
|
||||
draw(ctx, _n, _w, y) {
|
||||
ctx.save()
|
||||
ctx.fillStyle = LiteGraph.NODE_TEXT_COLOR
|
||||
ctx.fillText(input.label ?? input.name, 20, y + 15)
|
||||
ctx.restore()
|
||||
}
|
||||
})
|
||||
input.alwaysVisible = true
|
||||
input.widget = { name: input.name }
|
||||
}
|
||||
|
||||
function dynamicComboWidget(
|
||||
node: LGraphNode,
|
||||
@@ -32,11 +72,10 @@ function dynamicComboWidget(
|
||||
const updateWidgets = (value?: string) => {
|
||||
if (!node.widgets) throw new Error('Not Reachable')
|
||||
const newSpec = value ? options[value] : undefined
|
||||
//TODO: Calculate intersection for widgets that persist across options
|
||||
//This would potentially allow links to be retained
|
||||
const inputsToRemove: Record<string, INodeInputSlot> = {}
|
||||
for (const name of currentDynamicNames) {
|
||||
const inputIndex = node.inputs.findIndex((input) => input.name === name)
|
||||
if (inputIndex !== -1) node.removeInput(inputIndex)
|
||||
const input = node.inputs.find((input) => input.name === name)
|
||||
if (input) inputsToRemove[input.name] = input
|
||||
const widgetIndex = node.widgets.findIndex(
|
||||
(widget) => widget.name === name
|
||||
)
|
||||
@@ -45,13 +84,20 @@ function dynamicComboWidget(
|
||||
node.widgets.splice(widgetIndex, 1)
|
||||
}
|
||||
currentDynamicNames = []
|
||||
if (!newSpec) return
|
||||
if (!newSpec) {
|
||||
for (const input of Object.values(inputsToRemove)) {
|
||||
const inputIndex = node.inputs.findIndex((inp) => inp === input)
|
||||
if (inputIndex === -1) continue
|
||||
node.removeInput(inputIndex)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const insertionPoint = node.widgets.findIndex((w) => w === widget) + 1
|
||||
const startingLength = node.widgets.length
|
||||
const inputInsertionPoint =
|
||||
const initialInputIndex =
|
||||
node.inputs.findIndex((i) => i.name === widget.name) + 1
|
||||
const startingInputLength = node.inputs.length
|
||||
let startingInputLength = node.inputs.length
|
||||
if (insertionPoint === 0)
|
||||
throw new Error("Dynamic widget doesn't exist on node")
|
||||
const inputTypes: [Record<string, InputSpec> | undefined, boolean][] = [
|
||||
@@ -59,17 +105,37 @@ function dynamicComboWidget(
|
||||
[newSpec.optional, true]
|
||||
]
|
||||
for (const [inputType, isOptional] of inputTypes)
|
||||
for (const name in inputType ?? {}) {
|
||||
addNodeInput(
|
||||
node,
|
||||
transformInputSpecV1ToV2(inputType![name], {
|
||||
name,
|
||||
isOptional
|
||||
})
|
||||
)
|
||||
for (const key in inputType ?? {}) {
|
||||
const name = `${widget.name}.${key}`
|
||||
const specToAdd = transformInputSpecV1ToV2(inputType![key], {
|
||||
name,
|
||||
isOptional
|
||||
})
|
||||
specToAdd.display_name = key
|
||||
addNodeInput(node, specToAdd)
|
||||
currentDynamicNames.push(name)
|
||||
if (INLINE_INPUTS) ensureWidgetForInput(node, node.inputs.at(-1)!)
|
||||
if (
|
||||
!inputsToRemove[name] ||
|
||||
Array.isArray(inputType![key][0]) ||
|
||||
!LiteGraph.isValidConnection(
|
||||
inputsToRemove[name].type,
|
||||
inputType![key][0]
|
||||
)
|
||||
)
|
||||
continue
|
||||
node.inputs.at(-1)!.link = inputsToRemove[name].link
|
||||
inputsToRemove[name].link = null
|
||||
}
|
||||
|
||||
for (const input of Object.values(inputsToRemove)) {
|
||||
const inputIndex = node.inputs.findIndex((inp) => inp === input)
|
||||
if (inputIndex === -1) continue
|
||||
if (inputIndex < initialInputIndex) startingInputLength--
|
||||
node.removeInput(inputIndex)
|
||||
}
|
||||
const inputInsertionPoint =
|
||||
node.inputs.findIndex((i) => i.name === widget.name) + 1
|
||||
const addedWidgets = node.widgets.splice(startingLength)
|
||||
node.widgets.splice(insertionPoint, 0, ...addedWidgets)
|
||||
if (inputInsertionPoint === 0) {
|
||||
@@ -81,19 +147,23 @@ function dynamicComboWidget(
|
||||
throw new Error('Failed to find input socket for ' + widget.name)
|
||||
return
|
||||
}
|
||||
const addedInputs = node
|
||||
.spliceInputs(startingInputLength)
|
||||
.map((addedInput) => {
|
||||
const addedInputs = spliceInputs(node, startingInputLength).map(
|
||||
(addedInput) => {
|
||||
const existingInput = node.inputs.findIndex(
|
||||
(existingInput) => addedInput.name === existingInput.name
|
||||
)
|
||||
return existingInput === -1
|
||||
? addedInput
|
||||
: node.spliceInputs(existingInput, 1)[0]
|
||||
})
|
||||
: spliceInputs(node, existingInput, 1)[0]
|
||||
}
|
||||
)
|
||||
//assume existing inputs are in correct order
|
||||
node.spliceInputs(inputInsertionPoint, 0, ...addedInputs)
|
||||
spliceInputs(node, inputInsertionPoint, 0, ...addedInputs)
|
||||
node.size[1] = node.computeSize([...node.size])[1]
|
||||
if (!node.graph) return
|
||||
node._setConcreteSlots()
|
||||
node.arrange()
|
||||
app.canvas?.setDirty(true, true)
|
||||
}
|
||||
//A little hacky, but onConfigure won't work.
|
||||
//It fires too late and is overly disruptive
|
||||
@@ -112,3 +182,335 @@ function dynamicComboWidget(
|
||||
}
|
||||
|
||||
export const dynamicWidgets = { COMFY_DYNAMICCOMBO_V3: dynamicComboWidget }
|
||||
const dynamicInputs: Record<
|
||||
string,
|
||||
(node: LGraphNode, inputSpec: InputSpecV2) => void
|
||||
> = {
|
||||
COMFY_AUTOGROW_V3: applyAutogrow,
|
||||
COMFY_MATCHTYPE_V3: applyMatchType
|
||||
}
|
||||
|
||||
export function applyDynamicInputs(
|
||||
node: LGraphNode,
|
||||
inputSpec: InputSpecV2
|
||||
): boolean {
|
||||
if (!(inputSpec.type in dynamicInputs)) return false
|
||||
//TODO: move parsing/validation of inputSpec here?
|
||||
dynamicInputs[inputSpec.type](node, inputSpec)
|
||||
return true
|
||||
}
|
||||
function spliceInputs(
|
||||
node: LGraphNode,
|
||||
startIndex: number,
|
||||
deleteCount = -1,
|
||||
...toAdd: INodeInputSlot[]
|
||||
): INodeInputSlot[] {
|
||||
if (deleteCount < 0) return node.inputs.splice(startIndex)
|
||||
const ret = node.inputs.splice(startIndex, deleteCount, ...toAdd)
|
||||
node.inputs.slice(startIndex).forEach((input, index) => {
|
||||
const link = input.link && node.graph?.links?.get(input.link)
|
||||
if (link) link.target_slot = startIndex + index
|
||||
})
|
||||
return ret
|
||||
}
|
||||
|
||||
function changeOutputType(
|
||||
node: LGraphNode,
|
||||
output: INodeOutputSlot,
|
||||
combinedType: ISlotType
|
||||
) {
|
||||
if (output.type === combinedType) return
|
||||
output.type = combinedType
|
||||
|
||||
//check and potentially remove links
|
||||
if (!node.graph) return
|
||||
for (const link_id of output.links ?? []) {
|
||||
const link = node.graph.links[link_id]
|
||||
if (!link) continue
|
||||
const { input, inputNode, subgraphOutput } = link.resolve(node.graph)
|
||||
const inputType = (input ?? subgraphOutput)?.type
|
||||
if (!inputType) continue
|
||||
const keep = LiteGraph.isValidConnection(combinedType, inputType)
|
||||
if (!keep && subgraphOutput) subgraphOutput.disconnect()
|
||||
else if (!keep && inputNode) inputNode.disconnectInput(link.target_slot)
|
||||
if (input && inputNode?.onConnectionsChange)
|
||||
inputNode.onConnectionsChange(
|
||||
LiteGraph.INPUT,
|
||||
link.target_slot,
|
||||
keep,
|
||||
link,
|
||||
input
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
function combineTypes(...types: ISlotType[]): ISlotType | undefined {
|
||||
if (!isStrings(types)) return undefined
|
||||
|
||||
const withoutWildcards = without(types, '*')
|
||||
if (withoutWildcards.length === 0) return '*'
|
||||
|
||||
const typeLists: string[][] = withoutWildcards.map((type) => type.split(','))
|
||||
|
||||
const combinedTypes = intersection(...typeLists)
|
||||
if (combinedTypes.length === 0) return undefined
|
||||
|
||||
return combinedTypes.join(',')
|
||||
}
|
||||
|
||||
function intersection(...sets: string[][]): string[] {
|
||||
const itemCounts: Record<string, number> = {}
|
||||
for (const set of sets)
|
||||
for (const item of new Set(set))
|
||||
itemCounts[item] = (itemCounts[item] ?? 0) + 1
|
||||
return Object.entries(itemCounts)
|
||||
.filter(([, count]) => count == sets.length)
|
||||
.map(([key]) => key)
|
||||
}
|
||||
|
||||
function withComfyMatchType(node: LGraphNode): asserts node is MatchTypeNode {
|
||||
if (node.comfyMatchType) return
|
||||
node.comfyMatchType = {}
|
||||
|
||||
const outputGroups = node.constructor.nodeData?.output_matchtypes
|
||||
node.onConnectionsChange = useChainCallback(
|
||||
node.onConnectionsChange,
|
||||
function (
|
||||
this: MatchTypeNode,
|
||||
contype: ISlotType,
|
||||
slot: number,
|
||||
iscon: boolean,
|
||||
linf: LLink | null | undefined
|
||||
) {
|
||||
const input = this.inputs[slot]
|
||||
if (contype !== LiteGraph.INPUT || !this.graph || !input) return
|
||||
const [matchKey, matchGroup] = Object.entries(this.comfyMatchType).find(
|
||||
([, group]) => input.name in group
|
||||
) ?? ['', undefined]
|
||||
if (!matchGroup) return
|
||||
if (iscon && linf) {
|
||||
const { output, subgraphInput } = linf.resolve(this.graph)
|
||||
//TODO: fix this bug globally. A link type (and therefore color)
|
||||
//should be the combinedType of origin and target type
|
||||
const connectingType = (output ?? subgraphInput)?.type
|
||||
if (connectingType) linf.type = connectingType
|
||||
}
|
||||
//NOTE: inputs contains input
|
||||
const groupInputs: INodeInputSlot[] = node.inputs.filter(
|
||||
(inp) => inp.name in matchGroup
|
||||
)
|
||||
const connectedTypes = groupInputs.map((inp) => {
|
||||
if (!inp.link) return '*'
|
||||
const link = this.graph!.links[inp.link]
|
||||
if (!link) return '*'
|
||||
const { output, subgraphInput } = link.resolve(this.graph!)
|
||||
return (output ?? subgraphInput)?.type ?? '*'
|
||||
})
|
||||
//An input slot can accept a connection that is
|
||||
// - Compatible with original type
|
||||
// - Compatible with all other input types
|
||||
//An output slot can output
|
||||
// - Only what every input can output
|
||||
groupInputs.forEach((input, idx) => {
|
||||
const otherConnected = [
|
||||
...connectedTypes.slice(0, idx),
|
||||
...connectedTypes.slice(idx + 1)
|
||||
]
|
||||
const combinedType = combineTypes(
|
||||
...otherConnected,
|
||||
matchGroup[input.name]
|
||||
)
|
||||
if (!combinedType) throw new Error('invalid connection')
|
||||
input.type = combinedType
|
||||
})
|
||||
const outputType = combineTypes(...connectedTypes)
|
||||
if (!outputType) throw new Error('invalid connection')
|
||||
this.outputs.forEach((output, idx) => {
|
||||
if (!(outputGroups?.[idx] == matchKey)) return
|
||||
changeOutputType(this, output, outputType)
|
||||
})
|
||||
app.canvas?.setDirty(true, true)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
function applyMatchType(node: LGraphNode, inputSpec: InputSpecV2) {
|
||||
const { addNodeInput } = useLitegraphService()
|
||||
const name = inputSpec.name
|
||||
const { allowed_types, template_id } = (
|
||||
inputSpec as InputSpecV2 & {
|
||||
template: { allowed_types: string; template_id: string }
|
||||
}
|
||||
).template
|
||||
const typedSpec = { ...inputSpec, type: allowed_types }
|
||||
addNodeInput(node, typedSpec)
|
||||
withComfyMatchType(node)
|
||||
node.comfyMatchType[template_id] ??= {}
|
||||
node.comfyMatchType[template_id][name] = allowed_types
|
||||
|
||||
//TODO: instead apply on output add?
|
||||
//ensure outputs get updated
|
||||
const index = node.inputs.length - 1
|
||||
const input = node.inputs.at(-1)!
|
||||
requestAnimationFrame(() =>
|
||||
node.onConnectionsChange(LiteGraph.INPUT, index, false, undefined, input)
|
||||
)
|
||||
}
|
||||
|
||||
function applyAutogrow(node: LGraphNode, untypedInputSpec: InputSpecV2) {
|
||||
const { addNodeInput } = useLitegraphService()
|
||||
|
||||
const parseResult = zAutogrowOptions.safeParse(untypedInputSpec)
|
||||
if (!parseResult.success) throw new Error('invalid Autogrow spec')
|
||||
const inputSpec = parseResult.data
|
||||
|
||||
const { input, min, names, prefix, max } = inputSpec.template
|
||||
const inputTypes: [Record<string, InputSpec> | undefined, boolean][] = [
|
||||
[input.required, false],
|
||||
[input.optional, true]
|
||||
]
|
||||
const inputsV2 = inputTypes.flatMap(([inputType, isOptional]) =>
|
||||
Object.entries(inputType ?? {}).map(([name, v]) =>
|
||||
transformInputSpecV1ToV2(v, { name, isOptional })
|
||||
)
|
||||
)
|
||||
|
||||
function nameToInputIndex(name: string) {
|
||||
const index = node.inputs.findIndex((input) => input.name === name)
|
||||
if (index === -1) throw new Error('Failed to find input')
|
||||
return index
|
||||
}
|
||||
function nameToInput(name: string) {
|
||||
return node.inputs[nameToInputIndex(name)]
|
||||
}
|
||||
|
||||
//In the distance, someone shouting YAGNI
|
||||
const trackedInputs: string[][] = []
|
||||
function addInputGroup(insertionIndex: number) {
|
||||
const ordinal = trackedInputs.length
|
||||
const inputGroup = inputsV2.map((input) => ({
|
||||
...input,
|
||||
name: names
|
||||
? names[ordinal]
|
||||
: ((inputsV2.length == 1 ? prefix : input.name) ?? '') + ordinal,
|
||||
isOptional: ordinal >= (min ?? 0) || input.isOptional
|
||||
}))
|
||||
const newInputs = inputGroup
|
||||
.filter(
|
||||
(namedSpec) => !node.inputs.some((inp) => inp.name === namedSpec.name)
|
||||
)
|
||||
.map((namedSpec) => {
|
||||
addNodeInput(node, namedSpec)
|
||||
const input = spliceInputs(node, node.inputs.length - 1, 1)[0]
|
||||
if (inputsV2.length !== 1) ensureWidgetForInput(node, input)
|
||||
return input
|
||||
})
|
||||
spliceInputs(node, insertionIndex, 0, ...newInputs)
|
||||
trackedInputs.push(inputGroup.map((inp) => inp.name))
|
||||
app.canvas?.setDirty(true, true)
|
||||
}
|
||||
for (let i = 0; i < (min || 1); i++) addInputGroup(node.inputs.length)
|
||||
function removeInputGroup(inputName: string) {
|
||||
const groupIndex = trackedInputs.findIndex((ig) =>
|
||||
ig.some((inpName) => inpName === inputName)
|
||||
)
|
||||
if (groupIndex == -1) throw new Error('Failed to find group')
|
||||
const group = trackedInputs[groupIndex]
|
||||
for (const nameToRemove of group) {
|
||||
const inputIndex = nameToInputIndex(nameToRemove)
|
||||
const input = spliceInputs(node, inputIndex, 1)[0]
|
||||
if (!input.widget?.name) continue
|
||||
const widget = node.widgets?.find((w) => w.name === input.widget!.name)
|
||||
if (!widget) return
|
||||
widget.value = undefined
|
||||
node.removeWidget(widget)
|
||||
}
|
||||
trackedInputs.splice(groupIndex, 1)
|
||||
node.size[1] = node.computeSize([...node.size])[1]
|
||||
app.canvas?.setDirty(true, true)
|
||||
}
|
||||
|
||||
function inputConnected(index: number) {
|
||||
const input = node.inputs[index]
|
||||
const groupIndex = trackedInputs.findIndex((ig) =>
|
||||
ig.some((inputName) => inputName === input.name)
|
||||
)
|
||||
if (groupIndex == -1) throw new Error('Failed to find group')
|
||||
if (
|
||||
groupIndex + 1 === trackedInputs.length &&
|
||||
trackedInputs.length < (max ?? names?.length ?? 100)
|
||||
) {
|
||||
const lastInput = trackedInputs[groupIndex].at(-1)
|
||||
if (!lastInput) return
|
||||
const insertionIndex = nameToInputIndex(lastInput) + 1
|
||||
if (insertionIndex === 0) throw new Error('Failed to find Input')
|
||||
addInputGroup(insertionIndex)
|
||||
}
|
||||
}
|
||||
function inputDisconnected(index: number) {
|
||||
const input = node.inputs[index]
|
||||
if (trackedInputs.length === 1) return
|
||||
const groupIndex = trackedInputs.findIndex((ig) =>
|
||||
ig.some((inputName) => inputName === input.name)
|
||||
)
|
||||
if (groupIndex == -1) throw new Error('Failed to find group')
|
||||
if (
|
||||
trackedInputs[groupIndex].some(
|
||||
(inputName) => nameToInput(inputName).link != null
|
||||
)
|
||||
)
|
||||
return
|
||||
if (groupIndex + 1 < (min ?? 0)) return
|
||||
//For each group from here to last group, bubble swap links
|
||||
for (let column = 0; column < trackedInputs[0].length; column++) {
|
||||
let prevInput = nameToInputIndex(trackedInputs[groupIndex][column])
|
||||
for (let i = groupIndex + 1; i < trackedInputs.length; i++) {
|
||||
const curInput = nameToInputIndex(trackedInputs[i][column])
|
||||
const linkId = node.inputs[curInput].link
|
||||
node.inputs[prevInput].link = linkId
|
||||
const link = linkId && node.graph?.links?.[linkId]
|
||||
if (link) link.target_slot = prevInput
|
||||
prevInput = curInput
|
||||
}
|
||||
node.inputs[prevInput].link = null
|
||||
}
|
||||
if (
|
||||
trackedInputs.at(-2) &&
|
||||
!trackedInputs.at(-2)?.some((name) => !!nameToInput(name).link)
|
||||
)
|
||||
removeInputGroup(trackedInputs.at(-1)![0])
|
||||
}
|
||||
|
||||
let pendingConnection: number | undefined
|
||||
let swappingConnection = false
|
||||
const originalOnConnectInput = node.onConnectInput
|
||||
node.onConnectInput = function (slot: number, ...args) {
|
||||
pendingConnection = slot
|
||||
requestAnimationFrame(() => (pendingConnection = undefined))
|
||||
return originalOnConnectInput?.apply(this, [slot, ...args]) ?? true
|
||||
}
|
||||
node.onConnectionsChange = useChainCallback(
|
||||
node.onConnectionsChange,
|
||||
(
|
||||
type: ISlotType,
|
||||
index: number,
|
||||
iscon: boolean,
|
||||
linf: LLink | null | undefined
|
||||
) => {
|
||||
if (type !== NodeSlotType.INPUT) return
|
||||
const inputName = node.inputs[index].name
|
||||
if (!trackedInputs.flat().some((name) => name === inputName)) return
|
||||
if (iscon) {
|
||||
if (swappingConnection || !linf) return
|
||||
inputConnected(index)
|
||||
} else {
|
||||
if (pendingConnection === index) {
|
||||
swappingConnection = true
|
||||
requestAnimationFrame(() => (swappingConnection = false))
|
||||
return
|
||||
}
|
||||
requestAnimationFrame(() => inputDisconnected(index))
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import './groupNodeManage'
|
||||
import './groupOptions'
|
||||
import './load3d'
|
||||
import './maskeditor'
|
||||
import './matchType'
|
||||
import './nodeTemplates'
|
||||
import './noteNode'
|
||||
import './previewAny'
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
import { without } from 'es-toolkit'
|
||||
|
||||
import { useChainCallback } from '@/composables/functional/useChainCallback'
|
||||
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
|
||||
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
|
||||
import type { LLink } from '@/lib/litegraph/src/LLink'
|
||||
import type { ISlotType } from '@/lib/litegraph/src/interfaces'
|
||||
import { app } from '@/scripts/app'
|
||||
|
||||
const MATCH_TYPE = 'COMFY_MATCHTYPE_V3'
|
||||
|
||||
app.registerExtension({
|
||||
name: 'Comfy.MatchType',
|
||||
beforeRegisterNodeDef(nodeType, nodeData) {
|
||||
const inputs = {
|
||||
...nodeData.input?.required,
|
||||
...nodeData.input?.optional
|
||||
}
|
||||
if (!Object.values(inputs).some((w) => w[0] === MATCH_TYPE)) return
|
||||
nodeType.prototype.onNodeCreated = useChainCallback(
|
||||
nodeType.prototype.onNodeCreated,
|
||||
function (this: LGraphNode) {
|
||||
const inputGroups: Record<string, [string, ISlotType][]> = {}
|
||||
const outputGroups: Record<string, number[]> = {}
|
||||
for (const input of this.inputs) {
|
||||
if (input.type !== MATCH_TYPE) continue
|
||||
const template = inputs[input.name][1]?.template
|
||||
if (!template) continue
|
||||
input.type = template.allowed_types ?? '*'
|
||||
inputGroups[template.template_id] ??= []
|
||||
inputGroups[template.template_id].push([input.name, input.type])
|
||||
}
|
||||
this.outputs.forEach((output, i) => {
|
||||
if (output.type !== MATCH_TYPE) return
|
||||
const id = nodeData.output_matchtypes?.[i]
|
||||
if (id == undefined) return
|
||||
outputGroups[id] ??= []
|
||||
outputGroups[id].push(i)
|
||||
})
|
||||
for (const groupId in inputGroups) {
|
||||
addConnectionGroup(this, inputGroups[groupId], outputGroups[groupId])
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
})
|
||||
function addConnectionGroup(
|
||||
node: LGraphNode,
|
||||
inputPairs: [string, ISlotType][],
|
||||
outputs?: number[]
|
||||
) {
|
||||
const connectedTypes: ISlotType[] = new Array(inputPairs.length).fill('*')
|
||||
node.onConnectionsChange = useChainCallback(
|
||||
node.onConnectionsChange,
|
||||
function (
|
||||
this: LGraphNode,
|
||||
contype: ISlotType,
|
||||
slot: number,
|
||||
iscon: boolean,
|
||||
linf: LLink | null | undefined
|
||||
) {
|
||||
const input = this.inputs[slot]
|
||||
if (contype !== LiteGraph.INPUT || !this.graph || !input) return
|
||||
const pairIndex = inputPairs.findIndex(([name]) => name === input.name)
|
||||
if (pairIndex == -1) return
|
||||
connectedTypes[pairIndex] = inputPairs[pairIndex][1]
|
||||
if (iscon && linf) {
|
||||
const { output, subgraphInput } = linf.resolve(this.graph)
|
||||
const connectingType = (output ?? subgraphInput)?.type
|
||||
if (connectingType)
|
||||
linf.type = connectedTypes[pairIndex] = connectingType
|
||||
}
|
||||
//An input slot can accept a connection that is
|
||||
// - Compatible with original type
|
||||
// - Compatible with all other input types
|
||||
//An output slot can output
|
||||
// - Only what every input can output
|
||||
for (let i = 0; i < inputPairs.length; i++) {
|
||||
//NOTE: This isn't great. Originally, I kept direct references to each
|
||||
//input, but these were becoming orphaned
|
||||
const input = this.inputs.find((inp) => inp.name === inputPairs[i][0])
|
||||
if (!input) continue
|
||||
const otherConnected = [...connectedTypes]
|
||||
otherConnected.splice(i, 1)
|
||||
const validType = combineTypes(...otherConnected, inputPairs[i][1])
|
||||
if (!validType) throw new Error('invalid connection')
|
||||
input.type = validType
|
||||
}
|
||||
if (outputs) {
|
||||
const outputType = combineTypes(...connectedTypes)
|
||||
if (!outputType) throw new Error('invalid connection')
|
||||
changeOutputType(this, outputType, outputs)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
function changeOutputType(
|
||||
node: LGraphNode,
|
||||
combinedType: ISlotType,
|
||||
outputs: number[]
|
||||
) {
|
||||
if (!node.graph) return
|
||||
for (const index of outputs) {
|
||||
if (node.outputs[index].type === combinedType) continue
|
||||
node.outputs[index].type = combinedType
|
||||
|
||||
//check and potentially remove links
|
||||
for (let link_id of node.outputs[index].links ?? []) {
|
||||
let link = node.graph.links[link_id]
|
||||
if (!link) continue
|
||||
const { input, inputNode, subgraphOutput } = link.resolve(node.graph)
|
||||
const inputType = (input ?? subgraphOutput)?.type
|
||||
if (!inputType) continue
|
||||
const keep = LiteGraph.isValidConnection(combinedType, inputType)
|
||||
if (!keep && subgraphOutput) subgraphOutput.disconnect()
|
||||
else if (!keep && inputNode) inputNode.disconnectInput(link.target_slot)
|
||||
if (input && inputNode?.onConnectionsChange)
|
||||
inputNode.onConnectionsChange(
|
||||
LiteGraph.INPUT,
|
||||
link.target_slot,
|
||||
keep,
|
||||
link,
|
||||
input
|
||||
)
|
||||
}
|
||||
app.canvas.setDirty(true, true)
|
||||
}
|
||||
}
|
||||
function isStrings(types: ISlotType[]): types is string[] {
|
||||
return !types.some((t) => typeof t !== 'string')
|
||||
}
|
||||
|
||||
function combineTypes(...types: ISlotType[]): ISlotType | undefined {
|
||||
if (!isStrings(types)) return undefined
|
||||
|
||||
const withoutWildcards = without(types, '*')
|
||||
if (withoutWildcards.length === 0) return '*'
|
||||
|
||||
const typeLists: string[][] = withoutWildcards.map((type) => type.split(','))
|
||||
|
||||
const combinedTypes = intersection(...typeLists)
|
||||
if (combinedTypes.length === 0) return undefined
|
||||
|
||||
return combinedTypes.join(',')
|
||||
}
|
||||
function intersection(...sets: string[][]): string[] {
|
||||
const itemCounts: Record<string, number> = {}
|
||||
for (const set of sets)
|
||||
for (const item of new Set(set))
|
||||
itemCounts[item] = (itemCounts[item] ?? 0) + 1
|
||||
return Object.entries(itemCounts)
|
||||
.filter(([, count]) => count == sets.length)
|
||||
.map(([key]) => key)
|
||||
}
|
||||
@@ -415,6 +415,7 @@ export class LGraphNode
|
||||
selected?: boolean
|
||||
showAdvanced?: boolean
|
||||
|
||||
declare comfyMatchType?: Record<string, Record<string, string>>
|
||||
declare comfyClass?: string
|
||||
declare isVirtualNode?: boolean
|
||||
applyToGraph?(extraLinks?: LLink[]): void
|
||||
@@ -1651,19 +1652,6 @@ export class LGraphNode
|
||||
this.onInputRemoved?.(slot, slot_info[0])
|
||||
this.setDirtyCanvas(true, true)
|
||||
}
|
||||
spliceInputs(
|
||||
startIndex: number,
|
||||
deleteCount = -1,
|
||||
...toAdd: INodeInputSlot[]
|
||||
): INodeInputSlot[] {
|
||||
if (deleteCount < 0) return this.inputs.splice(startIndex)
|
||||
const ret = this.inputs.splice(startIndex, deleteCount, ...toAdd)
|
||||
this.inputs.slice(startIndex).forEach((input, index) => {
|
||||
const link = input.link && this.graph?.links?.get(input.link)
|
||||
if (link) link.target_slot = startIndex + index
|
||||
})
|
||||
return ret
|
||||
}
|
||||
|
||||
/**
|
||||
* computes the minimum size of a node according to its inputs and output slots
|
||||
@@ -4002,7 +3990,8 @@ export class LGraphNode
|
||||
isValidTarget ||
|
||||
!slot.isWidgetInputSlot ||
|
||||
this.#isMouseOverWidget(this.getWidgetFromSlot(slot)) ||
|
||||
slot.isConnected
|
||||
slot.isConnected ||
|
||||
slot.alwaysVisible
|
||||
) {
|
||||
ctx.globalAlpha = isValid ? editorAlpha : 0.4 * editorAlpha
|
||||
slot.draw(ctx, {
|
||||
|
||||
@@ -343,6 +343,7 @@ export interface IWidgetLocator {
|
||||
export interface INodeInputSlot extends INodeSlot {
|
||||
link: LinkId | null
|
||||
widget?: IWidgetLocator
|
||||
alwaysVisible?: boolean
|
||||
|
||||
/**
|
||||
* Internal use only; API is not finalised and may change at any time.
|
||||
|
||||
@@ -17,6 +17,7 @@ import type { IBaseWidget } from '@/lib/litegraph/src/types/widgets'
|
||||
|
||||
export class NodeInputSlot extends NodeSlot implements INodeInputSlot {
|
||||
link: LinkId | null
|
||||
alwaysVisible?: boolean
|
||||
|
||||
get isWidgetInputSlot(): boolean {
|
||||
return !!this.widget
|
||||
|
||||
@@ -14,10 +14,6 @@ const zRemoteWidgetConfig = z.object({
|
||||
timeout: z.number().gte(0).optional(),
|
||||
max_retries: z.number().gte(0).optional()
|
||||
})
|
||||
const zWidgetTemplate = z.object({
|
||||
template_id: z.string(),
|
||||
allowed_types: z.string().optional()
|
||||
})
|
||||
const zMultiSelectOption = z.object({
|
||||
placeholder: z.string().optional(),
|
||||
chip: z.boolean().optional()
|
||||
@@ -34,7 +30,6 @@ export const zBaseInputOptions = z
|
||||
hidden: z.boolean().optional(),
|
||||
advanced: z.boolean().optional(),
|
||||
widgetType: z.string().optional(),
|
||||
template: zWidgetTemplate.optional(),
|
||||
/** Backend-only properties. */
|
||||
rawLink: z.boolean().optional(),
|
||||
lazy: z.boolean().optional()
|
||||
@@ -232,9 +227,21 @@ export const zComfyNodeDef = z.object({
|
||||
input_order: z.record(z.array(z.string())).optional()
|
||||
})
|
||||
|
||||
export const zAutogrowOptions = z.object({
|
||||
...zBaseInputOptions.shape,
|
||||
template: z.object({
|
||||
input: zComfyInputsSpec,
|
||||
names: z.array(z.string()).optional(),
|
||||
max: z.number().optional(),
|
||||
//Backend defines as mandatory with min 1, Frontend is more forgiving
|
||||
min: z.number().optional(),
|
||||
prefix: z.string().optional()
|
||||
})
|
||||
})
|
||||
|
||||
export const zDynamicComboInputSpec = z.tuple([
|
||||
z.literal('COMFY_DYNAMICCOMBO_V3'),
|
||||
zComboInputOptions.extend({
|
||||
zBaseInputOptions.extend({
|
||||
options: z.array(
|
||||
z.object({
|
||||
inputs: zComfyInputsSpec,
|
||||
|
||||
@@ -7,6 +7,7 @@ import { useNodeCanvasImagePreview } from '@/composables/node/useNodeCanvasImage
|
||||
import { useNodeImage, useNodeVideo } from '@/composables/node/useNodeImage'
|
||||
import { addWidgetPromotionOptions } from '@/core/graph/subgraph/proxyWidgetUtils'
|
||||
import { showSubgraphNodeDialog } from '@/core/graph/subgraph/useSubgraphNodeDialog'
|
||||
import { applyDynamicInputs } from '@/core/graph/widgets/dynamicWidgets'
|
||||
import { st, t } from '@/i18n'
|
||||
import {
|
||||
LGraphCanvas,
|
||||
@@ -93,7 +94,11 @@ export const useLitegraphService = () => {
|
||||
const widgetConstructor = widgetStore.widgets.get(
|
||||
inputSpec.widgetType ?? inputSpec.type
|
||||
)
|
||||
if (widgetConstructor && !inputSpec.forceInput) return
|
||||
if (
|
||||
(widgetConstructor && !inputSpec.forceInput) ||
|
||||
applyDynamicInputs(node, inputSpec)
|
||||
)
|
||||
return
|
||||
|
||||
const input = node.addInput(inputName, inputSpec.type, {
|
||||
shape: inputSpec.isOptional ? RenderShape.HollowCircle : undefined,
|
||||
|
||||
@@ -60,3 +60,7 @@ export const isResultItemType = (
|
||||
): value is ResultItemType => {
|
||||
return value === 'input' || value === 'output' || value === 'temp'
|
||||
}
|
||||
|
||||
export function isStrings(types: unknown[]): types is string[] {
|
||||
return types.every((t) => typeof t === 'string')
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { setActivePinia } from 'pinia'
|
||||
import { createTestingPinia } from '@pinia/testing'
|
||||
import { describe, expect, test } from 'vitest'
|
||||
import { LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { LGraph, LGraphNode } from '@/lib/litegraph/src/litegraph'
|
||||
import { transformInputSpecV1ToV2 } from '@/schemas/nodeDef/migration'
|
||||
import type { InputSpec } from '@/schemas/nodeDefSchema'
|
||||
import { useLitegraphService } from '@/services/litegraphService'
|
||||
@@ -12,6 +12,10 @@ type DynamicInputs = ('INT' | 'STRING' | 'IMAGE' | DynamicInputs)[][]
|
||||
|
||||
const { addNodeInput } = useLitegraphService()
|
||||
|
||||
function nextTick() {
|
||||
return new Promise<void>((r) => requestAnimationFrame(() => r()))
|
||||
}
|
||||
|
||||
function addDynamicCombo(node: LGraphNode, inputs: DynamicInputs) {
|
||||
const namePrefix = `${node.widgets?.length ?? 0}`
|
||||
function getSpec(
|
||||
@@ -40,6 +44,21 @@ function addDynamicCombo(node: LGraphNode, inputs: DynamicInputs) {
|
||||
transformInputSpecV1ToV2(inputSpec, { name: namePrefix, isOptional: false })
|
||||
)
|
||||
}
|
||||
function addAutogrow(node: LGraphNode, template: unknown) {
|
||||
addNodeInput(
|
||||
node,
|
||||
transformInputSpecV1ToV2(['COMFY_AUTOGROW_V3', { template }], {
|
||||
name: `${node.inputs.length}`,
|
||||
isOptional: false
|
||||
})
|
||||
)
|
||||
}
|
||||
function connectInput(node: LGraphNode, inputIndex: number, graph: LGraph) {
|
||||
const node2 = testNode()
|
||||
node2.addOutput('out', '*')
|
||||
graph.add(node2)
|
||||
node2.connect(0, node, inputIndex)
|
||||
}
|
||||
function testNode() {
|
||||
const node: LGraphNode & Partial<HasInitialMinSize> = new LGraphNode('test')
|
||||
node.widgets = []
|
||||
@@ -84,7 +103,76 @@ describe('Dynamic Combos', () => {
|
||||
node.widgets[0].value = '1'
|
||||
expect(node.widgets.length).toBe(2)
|
||||
expect(node.inputs.length).toBe(4)
|
||||
expect(node.inputs[1].name).toBe('0.0.0')
|
||||
expect(node.inputs[3].name).toBe('2.0.0')
|
||||
expect(node.inputs[1].name).toBe('0.0.0.0')
|
||||
expect(node.inputs[3].name).toBe('2.2.0.0')
|
||||
})
|
||||
})
|
||||
describe('Autogrow', () => {
|
||||
const inputsSpec = { required: { image: ['IMAGE', {}] } }
|
||||
test('Can name by prefix', () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addAutogrow(node, { input: inputsSpec, prefix: 'test' })
|
||||
connectInput(node, 0, graph)
|
||||
connectInput(node, 1, graph)
|
||||
connectInput(node, 2, graph)
|
||||
expect(node.inputs.length).toBe(4)
|
||||
expect(node.inputs[0].name).toBe('test0')
|
||||
expect(node.inputs[2].name).toBe('test2')
|
||||
})
|
||||
test('Can name by list of names', () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addAutogrow(node, { input: inputsSpec, names: ['a', 'b', 'c'] })
|
||||
connectInput(node, 0, graph)
|
||||
connectInput(node, 1, graph)
|
||||
connectInput(node, 2, graph)
|
||||
expect(node.inputs.length).toBe(3)
|
||||
expect(node.inputs[0].name).toBe('a')
|
||||
expect(node.inputs[2].name).toBe('c')
|
||||
})
|
||||
test('Can add autogrow with min input count', () => {
|
||||
const node = testNode()
|
||||
addAutogrow(node, { min: 4, input: inputsSpec })
|
||||
expect(node.inputs.length).toBe(4)
|
||||
})
|
||||
test('Adding connections will cause growth up to max', () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addAutogrow(node, { min: 1, input: inputsSpec, prefix: 'test', max: 3 })
|
||||
expect(node.inputs.length).toBe(1)
|
||||
|
||||
connectInput(node, 0, graph)
|
||||
expect(node.inputs.length).toBe(2)
|
||||
connectInput(node, 1, graph)
|
||||
expect(node.inputs.length).toBe(3)
|
||||
connectInput(node, 2, graph)
|
||||
expect(node.inputs.length).toBe(3)
|
||||
})
|
||||
test('Removing connections decreases to min', async () => {
|
||||
const graph = new LGraph()
|
||||
const node = testNode()
|
||||
graph.add(node)
|
||||
addAutogrow(node, { min: 4, input: inputsSpec, prefix: 'test' })
|
||||
connectInput(node, 3, graph)
|
||||
connectInput(node, 4, graph)
|
||||
connectInput(node, 5, graph)
|
||||
expect(node.inputs.length).toBe(7)
|
||||
|
||||
node.disconnectInput(4)
|
||||
await nextTick()
|
||||
expect(node.inputs.length).toBe(6)
|
||||
node.disconnectInput(3)
|
||||
await nextTick()
|
||||
expect(node.inputs.length).toBe(5)
|
||||
|
||||
connectInput(node, 0, graph)
|
||||
expect(node.inputs.length).toBe(5)
|
||||
node.disconnectInput(0)
|
||||
await nextTick()
|
||||
expect(node.inputs.length).toBe(5)
|
||||
})
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user