Color links as common type (#7211)

Previously the color of a link would simply use the type of the target
slot and fallback to the type of the origin slot. When a connection is
made to a node that accepts the any type ('*'), the link has the green
color of an unknown type.

Instead, when a connection is made, the type of a link is now calculated
as the greatest common type of the source and destination. This means
that connections to reroutes are correctly colored.

| Before | After |
| ------ | ----- |
| <img width="360" alt="before"
src="https://github.com/user-attachments/assets/a5544730-e69a-4c85-af33-b303bb30ae71"
/>| <img width="360" alt="after"
src="https://github.com/user-attachments/assets/7d7b59fd-1b79-440b-a97d-a1657313c484"
/>|

The code for calculating common types already exists, it has simply been
moved into litegraph and given a more descriptive name.

Resolves #7196

┆Issue is synchronized with this [Notion
page](https://www.notion.so/PR-7211-Color-links-as-common-type-2c16d73d365081188460f6b5973db962)
by [Unito](https://www.unito.io)

---------

Co-authored-by: Alexander Brown <drjkl@comfy.org>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
This commit is contained in:
AustinMroz
2025-12-06 12:00:07 -08:00
committed by GitHub
parent 1d18583e42
commit a8f6bea371
4 changed files with 39 additions and 38 deletions

View File

@@ -1,5 +1,3 @@
import { without } from 'es-toolkit'
import { useChainCallback } from '@/composables/functional/useChainCallback'
import { NodeSlotType } from '@/lib/litegraph/src/types/globalEnums'
import type {
@@ -10,6 +8,7 @@ import type {
import type { LGraphNode } from '@/lib/litegraph/src/LGraphNode'
import { LiteGraph } from '@/lib/litegraph/src/litegraph'
import type { LLink } from '@/lib/litegraph/src/LLink'
import { commonType } from '@/lib/litegraph/src/utils/type'
import { transformInputSpecV1ToV2 } from '@/schemas/nodeDef/migration'
import type { ComboInputSpec, InputSpec } from '@/schemas/nodeDefSchema'
import type { InputSpec as InputSpecV2 } from '@/schemas/nodeDef/nodeDefSchemaV2'
@@ -20,7 +19,6 @@ import {
import { useLitegraphService } from '@/services/litegraphService'
import { app } from '@/scripts/app'
import type { ComfyApp } from '@/scripts/app'
import { isStrings } from '@/utils/typeGuardUtil'
const INLINE_INPUTS = false
@@ -244,30 +242,6 @@ function changeOutputType(
}
}
function combineTypes(...types: ISlotType[]): ISlotType | undefined {
if (!isStrings(types)) return undefined
const withoutWildcards = without(types, '*')
if (withoutWildcards.length === 0) return '*'
const typeLists: string[][] = withoutWildcards.map((type) => type.split(','))
const combinedTypes = intersection(...typeLists)
if (combinedTypes.length === 0) return undefined
return combinedTypes.join(',')
}
function intersection(...sets: string[][]): string[] {
const itemCounts: Record<string, number> = {}
for (const set of sets)
for (const item of new Set(set))
itemCounts[item] = (itemCounts[item] ?? 0) + 1
return Object.entries(itemCounts)
.filter(([, count]) => count == sets.length)
.map(([key]) => key)
}
function withComfyMatchType(node: LGraphNode): asserts node is MatchTypeNode {
if (node.comfyMatchType) return
node.comfyMatchType = {}
@@ -290,8 +264,6 @@ function withComfyMatchType(node: LGraphNode): asserts node is MatchTypeNode {
if (!matchGroup) return
if (iscon && linf) {
const { output, subgraphInput } = linf.resolve(this.graph)
//TODO: fix this bug globally. A link type (and therefore color)
//should be the combinedType of origin and target type
const connectingType = (output ?? subgraphInput)?.type
if (connectingType) linf.type = connectingType
}
@@ -316,14 +288,14 @@ function withComfyMatchType(node: LGraphNode): asserts node is MatchTypeNode {
...connectedTypes.slice(0, idx),
...connectedTypes.slice(idx + 1)
]
const combinedType = combineTypes(
const combinedType = commonType(
...otherConnected,
matchGroup[input.name]
)
if (!combinedType) throw new Error('invalid connection')
input.type = combinedType
})
const outputType = combineTypes(...connectedTypes)
const outputType = commonType(...connectedTypes)
if (!outputType) throw new Error('invalid connection')
this.outputs.forEach((output, idx) => {
if (!(outputGroups?.[idx] == matchKey)) return

View File

@@ -10,6 +10,7 @@ import { useLayoutMutations } from '@/renderer/core/layout/operations/layoutMuta
import { LayoutSource } from '@/renderer/core/layout/types'
import { adjustColor } from '@/utils/colorUtil'
import type { ColorAdjustOptions } from '@/utils/colorUtil'
import { commonType, toClass } from '@/lib/litegraph/src/utils/type'
import { SUBGRAPH_OUTPUT_ID } from '@/lib/litegraph/src/constants'
import type { DragAndScale } from './DragAndScale'
@@ -84,7 +85,6 @@ import { findFreeSlotOfType } from './utils/collections'
import { warnDeprecated } from './utils/feedback'
import { distributeSpace } from './utils/spaceDistribution'
import { truncateText } from './utils/textUtils'
import { toClass } from './utils/type'
import { BaseWidget } from './widgets/BaseWidget'
import { toConcreteWidget } from './widgets/widgetMap'
import type { WidgetTypeMap } from './widgets/widgetMap'
@@ -2832,9 +2832,12 @@ export class LGraphNode
inputNode.disconnectInput(inputIndex, true)
}
const maybeCommonType =
input.type && output.type && commonType(input.type, output.type)
const link = new LLink(
++graph.state.lastLinkId,
input.type || output.type,
maybeCommonType || input.type || output.type,
this.id,
outputIndex,
inputNode.id,

View File

@@ -1,4 +1,6 @@
import type { IColorable } from '@/lib/litegraph/src/interfaces'
import { without } from 'es-toolkit'
import type { IColorable, ISlotType } from '@/lib/litegraph/src/interfaces'
/**
* Converts a plain object to a class instance if it is not already an instance of the class.
@@ -26,3 +28,31 @@ export function isColorable(obj: unknown): obj is IColorable {
'getColorOption' in obj
)
}
export function commonType(...types: ISlotType[]): ISlotType | undefined {
if (!isStrings(types)) return undefined
const withoutWildcards = without(types, '*')
if (withoutWildcards.length === 0) return '*'
const typeLists: string[][] = withoutWildcards.map((type) => type.split(','))
const combinedTypes = intersection(...typeLists)
if (combinedTypes.length === 0) return undefined
return combinedTypes.join(',')
}
function intersection(...sets: string[][]): string[] {
const itemCounts: Record<string, number> = {}
for (const set of sets)
for (const item of new Set(set))
itemCounts[item] = (itemCounts[item] ?? 0) + 1
return Object.entries(itemCounts)
.filter(([, count]) => count === sets.length)
.map(([key]) => key)
}
function isStrings(types: unknown[]): types is string[] {
return types.every((t) => typeof t === 'string')
}

View File

@@ -60,7 +60,3 @@ export const isResultItemType = (
): value is ResultItemType => {
return value === 'input' || value === 'output' || value === 'temp'
}
export function isStrings(types: unknown[]): types is string[] {
return types.every((t) => typeof t === 'string')
}