Add front end support for type matching (#6582)

This PR implements front end logic to handle MatchType inputs and
outputs.
See  comfyanonymous/ComfyUI#10644

This allows for the implementation of nodes such as a "switch node"
where input types change based on the connections made.

![switch-node](https://github.com/user-attachments/assets/090515ba-484c-4295-b7b3-204b0c72fc4a)

As part of this implementation, significant cleanup is being performed
in the reroute code. Extra testing will be required to make sure these
changes don't introduce regressions.

┆Issue is synchronized with this [Notion
page](https://www.notion.so/PR-6582-Add-front-end-support-for-type-matching-2a16d73d36508189b042cd23f82a332e)
by [Unito](https://www.unito.io)
This commit is contained in:
AustinMroz
2025-11-12 12:30:58 -08:00
committed by GitHub
parent cfbd5361d3
commit 23b0d2eb7f
9 changed files with 393 additions and 245 deletions

View File

@@ -10,6 +10,7 @@ import './groupNodeManage'
import './groupOptions'
import './load3d'
import './maskeditor'
import './matchType'
import './nodeTemplates'
import './noteNode'
import './previewAny'

View File

@@ -0,0 +1,155 @@
import { without } from 'es-toolkit'
import { useChainCallback } from '@/composables/functional/useChainCallback'
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
import type { LLink } from '@/lib/litegraph/src/LLink'
import type { ISlotType } from '@/lib/litegraph/src/interfaces'
import { app } from '@/scripts/app'
const MATCH_TYPE = 'COMFY_MATCHTYPE_V3'
app.registerExtension({
name: 'Comfy.MatchType',
beforeRegisterNodeDef(nodeType, nodeData) {
const inputs = {
...nodeData.input?.required,
...nodeData.input?.optional
}
if (!Object.values(inputs).some((w) => w[0] === MATCH_TYPE)) return
nodeType.prototype.onNodeCreated = useChainCallback(
nodeType.prototype.onNodeCreated,
function (this: LGraphNode) {
const inputGroups: Record<string, [string, ISlotType][]> = {}
const outputGroups: Record<string, number[]> = {}
for (const input of this.inputs) {
if (input.type !== MATCH_TYPE) continue
const template = inputs[input.name][1]?.template
if (!template) continue
input.type = template.allowed_types ?? '*'
inputGroups[template.template_id] ??= []
inputGroups[template.template_id].push([input.name, input.type])
}
this.outputs.forEach((output, i) => {
if (output.type !== MATCH_TYPE) return
const id = nodeData.output_matchtypes?.[i]
if (id == undefined) return
outputGroups[id] ??= []
outputGroups[id].push(i)
})
for (const groupId in inputGroups) {
addConnectionGroup(this, inputGroups[groupId], outputGroups[groupId])
}
}
)
}
})
function addConnectionGroup(
node: LGraphNode,
inputPairs: [string, ISlotType][],
outputs?: number[]
) {
const connectedTypes: ISlotType[] = new Array(inputPairs.length).fill('*')
node.onConnectionsChange = useChainCallback(
node.onConnectionsChange,
function (
this: LGraphNode,
contype: ISlotType,
slot: number,
iscon: boolean,
linf: LLink | null | undefined
) {
const input = this.inputs[slot]
if (contype !== LiteGraph.INPUT || !this.graph || !input) return
const pairIndex = inputPairs.findIndex(([name]) => name === input.name)
if (pairIndex == -1) return
connectedTypes[pairIndex] = inputPairs[pairIndex][1]
if (iscon && linf) {
const { output, subgraphInput } = linf.resolve(this.graph)
const connectingType = (output ?? subgraphInput)?.type
if (connectingType)
linf.type = connectedTypes[pairIndex] = connectingType
}
//An input slot can accept a connection that is
// - Compatible with original type
// - Compatible with all other input types
//An output slot can output
// - Only what every input can output
for (let i = 0; i < inputPairs.length; i++) {
//NOTE: This isn't great. Originally, I kept direct references to each
//input, but these were becoming orphaned
const input = this.inputs.find((inp) => inp.name === inputPairs[i][0])
if (!input) continue
const otherConnected = [...connectedTypes]
otherConnected.splice(i, 1)
const validType = combineTypes(...otherConnected, inputPairs[i][1])
if (!validType) throw new Error('invalid connection')
input.type = validType
}
if (outputs) {
const outputType = combineTypes(...connectedTypes)
if (!outputType) throw new Error('invalid connection')
changeOutputType(this, outputType, outputs)
}
}
)
}
function changeOutputType(
node: LGraphNode,
combinedType: ISlotType,
outputs: number[]
) {
if (!node.graph) return
for (const index of outputs) {
if (node.outputs[index].type === combinedType) continue
node.outputs[index].type = combinedType
//check and potentially remove links
for (let link_id of node.outputs[index].links ?? []) {
let link = node.graph.links[link_id]
if (!link) continue
const { input, inputNode, subgraphOutput } = link.resolve(node.graph)
const inputType = (input ?? subgraphOutput)?.type
if (!inputType) continue
const keep = LiteGraph.isValidConnection(combinedType, inputType)
if (!keep && subgraphOutput) subgraphOutput.disconnect()
else if (!keep && inputNode) inputNode.disconnectInput(link.target_slot)
if (input && inputNode?.onConnectionsChange)
inputNode.onConnectionsChange(
LiteGraph.INPUT,
link.target_slot,
keep,
link,
input
)
}
app.canvas.setDirty(true, true)
}
}
function isStrings(types: ISlotType[]): types is string[] {
return !types.some((t) => typeof t !== 'string')
}
function combineTypes(...types: ISlotType[]): ISlotType | undefined {
if (!isStrings(types)) return undefined
const withoutWildcards = without(types, '*')
if (withoutWildcards.length === 0) return '*'
const typeLists: string[][] = withoutWildcards.map((type) => type.split(','))
const combinedTypes = intersection(...typeLists)
if (combinedTypes.length === 0) return undefined
return combinedTypes.join(',')
}
function intersection(...sets: string[][]): string[] {
const itemCounts: Record<string, number> = {}
for (const set of sets)
for (const item of new Set(set))
itemCounts[item] = (itemCounts[item] ?? 0) + 1
return Object.entries(itemCounts)
.filter(([, count]) => count == sets.length)
.map(([key]) => key)
}

View File

@@ -4,6 +4,7 @@ import {
LGraphNode,
LiteGraph
} from '@/lib/litegraph/src/litegraph'
import type { ISlotType } from '@/lib/litegraph/src/interfaces'
import { app } from '../../scripts/app'
import { getWidgetConfig, mergeIfValid, setWidgetConfig } from './widgetInputs'
@@ -14,7 +15,7 @@ app.registerExtension({
name: 'Comfy.RerouteNode',
registerCustomNodes(app) {
interface RerouteNode extends LGraphNode {
__outputType?: string
__outputType?: string | number
}
class RerouteNode extends LGraphNode {
@@ -22,8 +23,7 @@ app.registerExtension({
static defaultVisibility = false
constructor(title?: string) {
// @ts-expect-error fixme ts strict error
super(title)
super(title ?? '')
if (!this.properties) {
this.properties = {}
}
@@ -33,225 +33,198 @@ app.registerExtension({
this.addInput('', '*')
this.addOutput(this.properties.showOutputText ? '*' : '', '*')
this.onAfterGraphConfigured = function () {
requestAnimationFrame(() => {
// @ts-expect-error fixme ts strict error
this.onConnectionsChange(LiteGraph.INPUT, null, true, null)
})
}
this.onConnectionsChange = (type, _index, connected) => {
if (app.configuringGraph) return
// Prevent multiple connections to different types when we have no input
if (connected && type === LiteGraph.OUTPUT) {
// Ignore wildcard nodes as these will be updated to real types
const types = new Set(
// @ts-expect-error fixme ts strict error
this.outputs[0].links
.map((l) => app.graph.links[l].type)
.filter((t) => t !== '*')
)
if (types.size > 1) {
const linksToDisconnect = []
// @ts-expect-error fixme ts strict error
for (let i = 0; i < this.outputs[0].links.length - 1; i++) {
// @ts-expect-error fixme ts strict error
const linkId = this.outputs[0].links[i]
const link = app.graph.links[linkId]
linksToDisconnect.push(link)
}
for (const link of linksToDisconnect) {
const node = app.graph.getNodeById(link.target_id)
// @ts-expect-error fixme ts strict error
node.disconnectInput(link.target_slot)
}
}
}
// Find root input
let currentNode: LGraphNode | null = this
let updateNodes = []
let inputType = null
let inputNode = null
while (currentNode) {
updateNodes.unshift(currentNode)
const linkId = currentNode.inputs[0].link
if (linkId !== null) {
const link = app.graph.links[linkId]
if (!link) return
const node = app.graph.getNodeById(link.origin_id)
// @ts-expect-error fixme ts strict error
const type = node.constructor.type
if (type === 'Reroute') {
if (node === this) {
// We've found a circle
currentNode.disconnectInput(link.target_slot)
currentNode = null
} else {
// Move the previous node
currentNode = node
}
} else {
// We've found the end
inputNode = currentNode
// @ts-expect-error fixme ts strict error
inputType = node.outputs[link.origin_slot]?.type ?? null
break
}
} else {
// This path has no input node
currentNode = null
break
}
}
// Find all outputs
const nodes: LGraphNode[] = [this]
let outputType = null
while (nodes.length) {
// @ts-expect-error fixme ts strict error
currentNode = nodes.pop()
const outputs =
// @ts-expect-error fixme ts strict error
(currentNode.outputs ? currentNode.outputs[0].links : []) || []
if (outputs.length) {
for (const linkId of outputs) {
const link = app.graph.links[linkId]
// When disconnecting sometimes the link is still registered
if (!link) continue
const node = app.graph.getNodeById(link.target_id)
// @ts-expect-error fixme ts strict error
const type = node.constructor.type
if (type === 'Reroute') {
// Follow reroute nodes
// @ts-expect-error fixme ts strict error
nodes.push(node)
updateNodes.push(node)
} else {
// We've found an output
const nodeOutType =
// @ts-expect-error fixme ts strict error
node.inputs &&
// @ts-expect-error fixme ts strict error
node.inputs[link?.target_slot] &&
// @ts-expect-error fixme ts strict error
node.inputs[link.target_slot].type
? // @ts-expect-error fixme ts strict error
node.inputs[link.target_slot].type
: null
if (
inputType &&
// @ts-expect-error fixme ts strict error
!LiteGraph.isValidConnection(inputType, nodeOutType)
) {
// The output doesnt match our input so disconnect it
// @ts-expect-error fixme ts strict error
node.disconnectInput(link.target_slot)
} else {
outputType = nodeOutType
}
}
}
} else {
// No more outputs for this path
}
}
const displayType = inputType || outputType || '*'
const color = LGraphCanvas.link_type_colors[displayType]
let widgetConfig
let widgetType
// Update the types of each node
for (const node of updateNodes) {
// If we dont have an input type we are always wildcard but we'll show the output type
// This lets you change the output link to a different type and all nodes will update
// @ts-expect-error fixme ts strict error
node.outputs[0].type = inputType || '*'
// @ts-expect-error fixme ts strict error
node.__outputType = displayType
// @ts-expect-error fixme ts strict error
node.outputs[0].name = node.properties.showOutputText
? displayType
: ''
// @ts-expect-error fixme ts strict error
node.setSize(node.computeSize())
// @ts-expect-error fixme ts strict error
for (const l of node.outputs[0].links || []) {
const link = app.graph.links[l]
if (link) {
link.color = color
if (app.configuringGraph) continue
const targetNode = app.graph.getNodeById(link.target_id)
// @ts-expect-error fixme ts strict error
const targetInput = targetNode.inputs?.[link.target_slot]
if (targetInput?.widget) {
const config = getWidgetConfig(targetInput)
if (!widgetConfig) {
widgetConfig = config[1] ?? {}
widgetType = config[0]
}
const merged = mergeIfValid(targetInput, [
config[0],
widgetConfig
])
if (merged.customConfig) {
widgetConfig = merged.customConfig
}
}
}
}
}
for (const node of updateNodes) {
if (widgetConfig && outputType) {
// @ts-expect-error fixme ts strict error
node.inputs[0].widget = { name: 'value' }
// @ts-expect-error fixme ts strict error
setWidgetConfig(node.inputs[0], [
// @ts-expect-error fixme ts strict error
widgetType ?? displayType,
widgetConfig
])
} else {
// @ts-expect-error fixme ts strict error
setWidgetConfig(node.inputs[0], null)
}
}
if (inputNode) {
// @ts-expect-error fixme ts strict error
const link = app.graph.links[inputNode.inputs[0].link]
if (link) {
link.color = color
}
}
}
this.clone = function () {
const cloned = RerouteNode.prototype.clone.apply(this)
// @ts-expect-error fixme ts strict error
cloned.removeOutput(0)
// @ts-expect-error fixme ts strict error
cloned.addOutput(this.properties.showOutputText ? '*' : '', '*')
// @ts-expect-error fixme ts strict error
cloned.setSize(cloned.computeSize())
return cloned
}
// This node is purely frontend and does not impact the resulting prompt so should not be serialized
this.isVirtualNode = true
}
override onAfterGraphConfigured() {
requestAnimationFrame(() => {
this.onConnectionsChange(LiteGraph.INPUT, undefined, true)
})
}
override clone(): LGraphNode | null {
const cloned = super.clone()
if (!cloned) return cloned
cloned.removeOutput(0)
cloned.addOutput(this.properties.showOutputText ? '*' : '', '*')
cloned.setSize(cloned.computeSize())
return cloned
}
override onConnectionsChange(
type: ISlotType,
_index: number | undefined,
connected: boolean
) {
const { graph } = this
if (!graph) return
if (app.configuringGraph) return
// @ts-expect-error fixme ts strict error
getExtraMenuOptions(_, options): IContextMenuValue[] {
// Prevent multiple connections to different types when we have no input
if (connected && type === LiteGraph.OUTPUT) {
// Ignore wildcard nodes as these will be updated to real types
const types = new Set(
this.outputs[0].links
?.map((l) => graph.links[l]?.type)
?.filter((t) => t && t !== '*') ?? []
)
if (types.size > 1) {
const linksToDisconnect = []
for (const linkId of this.outputs[0].links ?? []) {
const link = graph.links[linkId]
linksToDisconnect.push(link)
}
linksToDisconnect.pop()
for (const link of linksToDisconnect) {
const node = graph.getNodeById(link.target_id)
node?.disconnectInput(link.target_slot)
}
}
}
// Find root input
let currentNode: RerouteNode | null = this
let updateNodes: RerouteNode[] = []
let inputType = null
let inputNode = null
while (currentNode) {
updateNodes.unshift(currentNode)
const linkId = currentNode.inputs[0].link
if (linkId !== null) {
const link = graph.links[linkId]
if (!link) return
const node = graph.getNodeById(link.origin_id)
if (!node) return
if (node instanceof RerouteNode) {
if (node === this) {
// We've found a circle
currentNode.disconnectInput(link.target_slot)
currentNode = null
} else {
// Move the previous node
currentNode = node
}
} else {
// We've found the end
inputNode = currentNode
inputType = node.outputs[link.origin_slot]?.type ?? null
break
}
} else {
// This path has no input node
currentNode = null
break
}
}
// Find all outputs
const nodes: RerouteNode[] = [this]
let outputType = null
while (nodes.length) {
currentNode = nodes.pop()!
const outputs = currentNode.outputs?.[0]?.links ?? []
for (const linkId of outputs) {
const link = graph.links[linkId]
// When disconnecting sometimes the link is still registered
if (!link) continue
const node = graph.getNodeById(link.target_id)
if (!node) continue
if (node instanceof RerouteNode) {
// Follow reroute nodes
nodes.push(node)
updateNodes.push(node)
} else {
// We've found an output
const nodeInput = node.inputs[link.target_slot]
const nodeOutType = nodeInput.type
const keep =
!inputType ||
!nodeOutType ||
LiteGraph.isValidConnection(inputType, nodeOutType)
if (!keep) {
// The output doesnt match our input so disconnect it
node.disconnectInput(link.target_slot)
continue
}
node.onConnectionsChange?.(
LiteGraph.INPUT,
link.target_slot,
keep,
link,
nodeInput
)
outputType = node.inputs[link.target_slot].type
}
}
}
const displayType = inputType || outputType || '*'
const color = LGraphCanvas.link_type_colors[displayType]
let widgetConfig
let widgetType
// Update the types of each node
for (const node of updateNodes) {
// If we dont have an input type we are always wildcard but we'll show the output type
// This lets you change the output link to a different type and all nodes will update
node.outputs[0].type = inputType || '*'
node.__outputType = displayType
node.outputs[0].name = node.properties.showOutputText
? `${displayType}`
: ''
node.setSize(node.computeSize())
for (const l of node.outputs[0].links || []) {
const link = graph.links[l]
if (!link) continue
link.color = color
if (app.configuringGraph) continue
const targetNode = graph.getNodeById(link.target_id)
if (!targetNode) continue
const targetInput = targetNode.inputs?.[link.target_slot]
if (targetInput?.widget) {
const config = getWidgetConfig(targetInput)
if (!widgetConfig) {
widgetConfig = config[1] ?? {}
widgetType = config[0]
}
const merged = mergeIfValid(targetInput, [
config[0],
widgetConfig
])
if (merged.customConfig) {
widgetConfig = merged.customConfig
}
}
}
}
for (const node of updateNodes) {
if (widgetConfig && outputType) {
node.inputs[0].widget = { name: 'value' }
setWidgetConfig(node.inputs[0], [
widgetType ?? `${displayType}`,
widgetConfig
])
} else {
setWidgetConfig(node.inputs[0], undefined)
}
}
if (inputNode?.inputs?.[0]?.link) {
const link = graph.links[inputNode.inputs[0].link]
if (link) {
link.color = color
}
}
}
override getExtraMenuOptions(
_: unknown,
options: (IContextMenuValue | null)[]
): IContextMenuValue[] {
options.unshift(
{
content:
@@ -259,13 +232,12 @@ app.registerExtension({
callback: () => {
this.properties.showOutputText = !this.properties.showOutputText
if (this.properties.showOutputText) {
this.outputs[0].name =
this.__outputType || (this.outputs[0].type as string)
this.outputs[0].name = `${this.__outputType || this.outputs[0].type}`
} else {
this.outputs[0].name = ''
}
this.setSize(this.computeSize())
app.graph.setDirtyCanvas(true, true)
app.canvas.setDirty(true, true)
}
},
{
@@ -294,8 +266,7 @@ app.registerExtension({
]
}
// @ts-expect-error fixme ts strict error
static setDefaultTextVisibility(visible) {
static setDefaultTextVisibility(visible: boolean) {
RerouteNode.defaultVisibility = visible
if (visible) {
localStorage['Comfy.RerouteNode.DefaultVisibility'] = 'true'

View File

@@ -443,7 +443,7 @@ function getWidgetType(config: InputSpec) {
export function setWidgetConfig(
slot: INodeInputSlot | INodeOutputSlot,
config: InputSpec
config?: InputSpec
) {
if (!slot.widget) return
if (config) {

View File

@@ -2,6 +2,8 @@ import {
SUBGRAPH_INPUT_ID,
SUBGRAPH_OUTPUT_ID
} from '@/lib/litegraph/src/constants'
import type { SubgraphInput } from '@/lib/litegraph/src/subgraph/SubgraphInput'
import type { SubgraphOutput } from '@/lib/litegraph/src/subgraph/SubgraphOutput'
import { useLayoutMutations } from '@/renderer/core/layout/operations/layoutMutations'
import { LayoutSource } from '@/renderer/core/layout/types'
@@ -17,11 +19,7 @@ import type {
Point,
ReadonlyLinkNetwork
} from './interfaces'
import type {
Serialisable,
SerialisableLLink,
SubgraphIO
} from './types/serialisation'
import type { Serialisable, SerialisableLLink } from './types/serialisation'
const layoutMutations = useLayoutMutations()
@@ -55,9 +53,9 @@ interface BaseResolvedConnection {
/** The output the link is connected to (mutually exclusive with {@link subgraphInput}) */
output?: INodeOutputSlot
/** The subgraph output the link is connected to (mutually exclusive with {@link input}) */
subgraphOutput?: SubgraphIO
subgraphOutput?: SubgraphOutput
/** The subgraph input the link is connected to (mutually exclusive with {@link output}) */
subgraphInput?: SubgraphIO
subgraphInput?: SubgraphInput
}
interface ResolvedNormalInput {
@@ -76,13 +74,13 @@ interface ResolvedSubgraphInput {
inputNode?: undefined
/** The actual input slot the link is connected to (mutually exclusive with {@link subgraphOutput}) */
input?: undefined
subgraphOutput: SubgraphIO
subgraphOutput: SubgraphOutput
}
interface ResolvedSubgraphOutput {
outputNode?: undefined
output?: undefined
subgraphInput: SubgraphIO
subgraphInput: SubgraphInput
}
type BasicReadonlyNetwork = Pick<

View File

@@ -287,11 +287,7 @@ export class ExecutableNodeDTO implements ExecutableLGraphNode {
if (node.isSubgraphNode())
return this.#resolveSubgraphOutput(slot, type, visited)
// Upstreamed: Other virtual nodes are bypassed using the same input/output index (slots must match)
if (node.isVirtualNode) {
if (this.inputs.at(slot)) return this.resolveInput(slot, visited, type)
// Fallback check for nodes performing link redirection
const virtualLink = this.node.getInputLink(slot)
if (virtualLink) {
const { inputNode } = virtualLink.resolve(this.graph)

View File

@@ -206,14 +206,16 @@ export class SubgraphInputNode
link.id
)
}
node.onConnectionsChange?.(
NodeSlotType.OUTPUT,
index,
false,
link,
subgraphInput
)
const slotIndex = node.inputs.findIndex((inp) => inp === input)
if (slotIndex !== -1) {
node.onConnectionsChange?.(
NodeSlotType.INPUT,
slotIndex,
false,
link,
subgraphInput
)
}
}
override drawProtected(

View File

@@ -153,4 +153,23 @@ export class SubgraphOutput extends SubgraphSlot {
return false
}
override disconnect() {
const { subgraph } = this.parent
//should never have more than one connection
for (const linkId of this.linkIds) {
const link = subgraph.links[linkId]
subgraph.removeLink(linkId)
const { output, outputNode } = link.resolve(subgraph)
if (output)
output.links = output.links?.filter((id) => id !== linkId) ?? null
outputNode?.onConnectionsChange?.(
NodeSlotType.OUTPUT,
link.origin_slot,
false,
link,
this
)
}
this.linkIds.length = 0
}
}

View File

@@ -14,6 +14,10 @@ const zRemoteWidgetConfig = z.object({
timeout: z.number().gte(0).optional(),
max_retries: z.number().gte(0).optional()
})
const zWidgetTemplate = z.object({
template_id: z.string(),
allowed_types: z.string().optional()
})
const zMultiSelectOption = z.object({
placeholder: z.string().optional(),
chip: z.boolean().optional()
@@ -28,6 +32,7 @@ export const zBaseInputOptions = z
hidden: z.boolean().optional(),
advanced: z.boolean().optional(),
widgetType: z.string().optional(),
template: zWidgetTemplate.optional(),
/** Backend-only properties. */
rawLink: z.boolean().optional(),
lazy: z.boolean().optional()
@@ -201,6 +206,7 @@ export const zComfyNodeDef = z.object({
output_is_list: z.array(z.boolean()).optional(),
output_name: z.array(z.string()).optional(),
output_tooltips: z.array(z.string()).optional(),
output_matchtypes: z.array(z.string().optional()).optional(),
name: z.string(),
display_name: z.string(),
description: z.string(),