feat: add VRAM requirement estimation for workflow templates

Add a frontend heuristic that estimates peak VRAM consumption by
detecting model-loading nodes in the workflow graph and summing
approximate memory costs per model category (checkpoints, LoRAs,
ControlNets, VAEs, etc.). The estimate uses only the largest base
model (checkpoint or diffusion_model) since ComfyUI offloads others,
plus all co-resident models and a flat runtime overhead.

Surfaces the estimate in three places:

1. Template publishing wizard (metadata step) — auto-detects VRAM on
   mount using the same graph traversal pattern as custom node
   detection, with a manual GB override input for fine-tuning.

2. Template marketplace cards — displays a VRAM badge in the top-left
   corner of template thumbnails using the existing SquareChip and
   CardTop slot infrastructure.

3. Workflow editor — floating indicator in the bottom-right of the
   graph canvas showing estimated VRAM for the current workflow.

Bumps version to 1.46.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
John Haugeland
2026-02-24 17:09:32 -08:00
parent 8361122586
commit bf63a5cc71
10 changed files with 489 additions and 3 deletions

View File

@@ -1,6 +1,6 @@
{
"name": "@comfyorg/comfyui-frontend",
"version": "1.45.0",
"version": "1.46.0",
"private": true,
"description": "Official front-end implementation of ComfyUI",
"homepage": "https://comfy.org",

View File

@@ -267,6 +267,16 @@
/>
</div>
</template>
<template v-if="template.vram" #top-left>
<SquareChip
:label="formatSize(template.vram)"
:title="t('templateWorkflows.vramEstimateTooltip')"
>
<template #icon>
<i class="icon-[lucide--cpu] h-3 w-3" />
</template>
</SquareChip>
</template>
<template #bottom-right>
<template v-if="template.tags && template.tags.length > 0">
<SquareChip
@@ -387,6 +397,7 @@
<script setup lang="ts">
import { useAsyncState } from '@vueuse/core'
import { formatSize } from '@/utils/formatUtil'
import ProgressSpinner from 'primevue/progressspinner'
import { computed, onBeforeUnmount, onMounted, provide, ref, watch } from 'vue'
import { useI18n } from 'vue-i18n'

View File

@@ -0,0 +1,31 @@
<!--
Floating indicator that displays the estimated VRAM requirement
for the currently loaded workflow graph. Renders nothing when the
estimate is 0 (no model-loading nodes detected).
-->
<template>
<div
v-if="vramEstimate > 0"
class="pointer-events-auto absolute bottom-3 right-3 z-10 inline-flex items-center gap-1.5 rounded-lg bg-zinc-500/40 px-2.5 py-1.5 text-xs font-medium text-white/90 backdrop-blur-sm"
:title="t('templateWorkflows.vramEstimateTooltip')"
>
<i class="icon-[lucide--cpu] h-3.5 w-3.5" />
{{ formatSize(vramEstimate) }}
</div>
</template>
<script setup lang="ts">
import { formatSize } from '@/utils/formatUtil'
import { ref, watchEffect } from 'vue'
import { useI18n } from 'vue-i18n'
import { estimateWorkflowVram } from '@/composables/useVramEstimation'
import { app } from '@/scripts/app'
const { t } = useI18n()
// Estimated peak VRAM for the current workflow, in bytes.
const vramEstimate = ref(0)
// NOTE(review): watchEffect only re-runs when a *reactive* dependency
// changes. It is not visible from this file whether `app.rootGraph`
// (or its node list) is reactive — confirm the estimate refreshes as
// the user edits the workflow and not only on component mount.
watchEffect(() => {
vramEstimate.value = estimateWorkflowVram(app.rootGraph)
})
</script>

View File

@@ -51,6 +51,10 @@ vi.mock('@/utils/graphTraversalUtil', () => ({
)
}))
vi.mock('@/composables/useVramEstimation', () => ({
estimateWorkflowVram: vi.fn(() => 5_000_000_000)
}))
vi.mock('@/stores/nodeDefStore', () => ({
useNodeDefStore: () => ({
nodeDefsByName: {
@@ -100,6 +104,9 @@ const i18n = createI18n({
requiredNodesDetected: 'Detected from workflow',
requiredNodesManualPlaceholder: 'Add custom node name…',
requiredNodesManualLabel: 'Additional custom nodes',
vramLabel: 'Estimated VRAM Requirement',
vramAutoDetected: 'Auto-detected from workflow:',
vramManualOverride: 'Manual override (GB):',
difficulty: {
beginner: 'Beginner',
intermediate: 'Intermediate',

View File

@@ -136,15 +136,46 @@
</div>
</div>
</div>
<div class="flex flex-col gap-2">
<span id="tpl-vram-label" class="text-sm text-muted">
{{ t('templatePublishing.steps.metadata.vramLabel') }}
</span>
<div class="flex items-center gap-3">
<i class="icon-[lucide--cpu] h-3.5 w-3.5 text-muted-foreground" />
<span class="text-xs text-muted-foreground">
{{ t('templatePublishing.steps.metadata.vramAutoDetected') }}
</span>
<span class="text-sm font-medium">
{{ formatSize(autoDetectedVram) }}
</span>
</div>
<div class="flex items-center gap-2">
<input
id="tpl-vram-override"
v-model.number="manualVramGb"
type="number"
min="0"
step="0.5"
class="h-8 w-24 rounded border border-border-default bg-secondary-background px-2 text-sm focus:outline-none"
aria-labelledby="tpl-vram-label"
/>
<span class="text-xs text-muted-foreground">
{{ t('templatePublishing.steps.metadata.vramManualOverride') }}
</span>
</div>
</div>
</div>
</template>
<script setup lang="ts">
import { computed, inject, onMounted, ref } from 'vue'
import { watchDebounced } from '@vueuse/core'
import { formatSize } from '@/utils/formatUtil'
import { useI18n } from 'vue-i18n'
import FormItem from '@/components/common/FormItem.vue'
import { estimateWorkflowVram } from '@/composables/useVramEstimation'
import type { FormItem as FormItemType } from '@/platform/settings/types'
import { app } from '@/scripts/app'
import { useNodeDefStore } from '@/stores/nodeDefStore'
@@ -260,6 +291,29 @@ function detectCustomNodePackages(): string[] {
}
const detectedCustomNodes = ref<string[]>([])
// Auto-detected VRAM estimate (bytes), captured once on mount and used
// as the fallback / "no override" sentinel below.
const autoDetectedVram = ref(0)
// One gibibyte in bytes, for converting the GB override input.
const GB = 1_073_741_824
/**
 * Manual VRAM override in GB. When set to a positive number, this
 * value (converted to bytes) takes precedence over the auto-detected
 * estimate for `vramRequirement`. Clearing the input (or a non-positive
 * value) restores the auto-detected estimate.
 *
 * NOTE(review): with `v-model.number` an emptied input emits '' — the
 * falsy check in the setter covers that, but confirm non-numeric input
 * cannot deliver NaN here (NaN is falsy for `> 0` but truthy-adjacent
 * bugs are easy in this pattern).
 */
const manualVramGb = computed({
get: () => {
const stored = ctx.template.value.vramRequirement
// A stored value equal to the auto estimate means "no manual override".
if (!stored || stored === autoDetectedVram.value) return undefined
// bytes -> GB, rounded to one decimal for display in the input.
return Math.round((stored / GB) * 10) / 10
},
set: (gb: number | undefined) => {
if (gb && gb > 0) {
ctx.template.value.vramRequirement = Math.round(gb * GB)
} else {
// Falsy / non-positive input falls back to the auto-detected value.
ctx.template.value.vramRequirement = autoDetectedVram.value
}
}
})
onMounted(() => {
detectedCustomNodes.value = detectCustomNodes()
@@ -273,6 +327,11 @@ onMounted(() => {
if (existingPackages.length === 0) {
ctx.template.value.requiresCustomNodes = detectCustomNodePackages()
}
autoDetectedVram.value = estimateWorkflowVram(app.rootGraph)
if (!ctx.template.value.vramRequirement) {
ctx.template.value.vramRequirement = autoDetectedVram.value
}
})
const manualNodes = computed(() => {

View File

@@ -43,6 +43,10 @@
{{ t('templatePublishing.steps.preview.noneDetected') }}
</span>
</PreviewField>
<PreviewField
:label="t('templatePublishing.steps.preview.vramLabel')"
:value="vramLabel"
/>
</PreviewSection>
<!-- Description -->
@@ -220,6 +224,7 @@
</template>
<script setup lang="ts">
import { formatSize } from '@/utils/formatUtil'
import { computed, inject } from 'vue'
import { useI18n } from 'vue-i18n'
@@ -284,4 +289,10 @@ const difficultyLabel = computed(() => {
if (!difficulty) return t('templatePublishing.steps.preview.notProvided')
return t(`templatePublishing.steps.metadata.difficulty.${difficulty}`)
})
// Human-readable VRAM requirement for the preview step, or the
// localized "not provided" placeholder when no estimate is stored.
const vramLabel = computed(() => {
  const bytes = tpl.value.vramRequirement
  return bytes
    ? formatSize(bytes)
    : t('templatePublishing.steps.preview.notProvided')
})
</script>

View File

@@ -0,0 +1,220 @@
import { createPinia, setActivePinia } from 'pinia'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
detectModelNodes,
estimateWorkflowVram,
MODEL_VRAM_ESTIMATES,
RUNTIME_OVERHEAD
} from './useVramEstimation'
// Mock fns use the `mock` prefix so the hoisted vi.mock factories below
// are permitted to reference them (vitest hoisting rule).
const mockGetCategoryForNodeType = vi.fn<(type: string) => string | undefined>()
const mockGetAllNodeProviders = vi.fn()
// Replace the model-to-node store with a controllable stub.
vi.mock('@/stores/modelToNodeStore', () => ({
useModelToNodeStore: () => ({
getCategoryForNodeType: mockGetCategoryForNodeType,
getAllNodeProviders: mockGetAllNodeProviders
})
}))
// Flat traversal stub: maps over `graph.nodes` directly and drops
// undefined results, mirroring mapAllNodes' filtering contract.
vi.mock('@/utils/graphTraversalUtil', () => ({
mapAllNodes: vi.fn(
(
graph: { nodes: Array<Record<string, unknown>> },
mapFn: (node: Record<string, unknown>) => unknown
) => graph.nodes.map(mapFn).filter((r) => r !== undefined)
)
}))
/**
 * Builds a minimal graph-node stub for the estimator tests: a type
 * string, an optional widget list, and a subgraph check that is
 * always false.
 */
function makeNode(
  type: string,
  widgets: Array<{ name: string; value: unknown }> = []
) {
  const stub = {
    type,
    widgets,
    isSubgraphNode: () => false
  }
  return stub
}
/** Wraps node stubs in the minimal graph shape the mocked traversal reads. */
function makeGraph(nodes: ReturnType<typeof makeNode>[]) {
  const graphStub = { nodes }
  return graphStub as never
}
// Unit tests for the VRAM estimation heuristic. The model-to-node store
// and the graph traversal helper are mocked above, so these tests drive
// detectModelNodes / estimateWorkflowVram purely through stub graphs.
describe('useVramEstimation', () => {
beforeEach(() => {
// Fresh Pinia and mock state per test so category/provider mappings
// configured in one test cannot leak into another.
setActivePinia(createPinia())
mockGetCategoryForNodeType.mockReset()
mockGetAllNodeProviders.mockReset()
mockGetAllNodeProviders.mockReturnValue([])
})
describe('detectModelNodes', () => {
it('returns empty array for graph with no model nodes', () => {
mockGetCategoryForNodeType.mockReturnValue(undefined)
const graph = makeGraph([makeNode('KSampler'), makeNode('SaveImage')])
expect(detectModelNodes(graph)).toEqual([])
})
it('detects checkpoint loader nodes', () => {
mockGetCategoryForNodeType.mockImplementation((type: string) =>
type === 'CheckpointLoaderSimple' ? 'checkpoints' : undefined
)
const graph = makeGraph([
makeNode('CheckpointLoaderSimple'),
makeNode('KSampler')
])
const result = detectModelNodes(graph)
expect(result).toHaveLength(1)
expect(result[0].category).toBe('checkpoints')
})
it('deduplicates models with same category and filename', () => {
mockGetCategoryForNodeType.mockImplementation((type: string) =>
type === 'CheckpointLoaderSimple' ? 'checkpoints' : undefined
)
// Provider mapping lets the detector read the ckpt_name widget, so
// two loaders pointing at the same file collapse to one entry.
mockGetAllNodeProviders.mockReturnValue([
{
nodeDef: { name: 'CheckpointLoaderSimple' },
key: 'ckpt_name'
}
])
const graph = makeGraph([
makeNode('CheckpointLoaderSimple', [
{ name: 'ckpt_name', value: 'model.safetensors' }
]),
makeNode('CheckpointLoaderSimple', [
{ name: 'ckpt_name', value: 'model.safetensors' }
])
])
expect(detectModelNodes(graph)).toHaveLength(1)
})
it('keeps models with same category but different filenames', () => {
mockGetCategoryForNodeType.mockImplementation((type: string) =>
type === 'LoraLoader' ? 'loras' : undefined
)
mockGetAllNodeProviders.mockReturnValue([
{ nodeDef: { name: 'LoraLoader' }, key: 'lora_name' }
])
const graph = makeGraph([
makeNode('LoraLoader', [
{ name: 'lora_name', value: 'lora_a.safetensors' }
]),
makeNode('LoraLoader', [
{ name: 'lora_name', value: 'lora_b.safetensors' }
])
])
expect(detectModelNodes(graph)).toHaveLength(2)
})
})
describe('estimateWorkflowVram', () => {
it('returns 0 for null/undefined graph', () => {
expect(estimateWorkflowVram(null)).toBe(0)
expect(estimateWorkflowVram(undefined)).toBe(0)
})
it('returns 0 for graph with no model nodes', () => {
mockGetCategoryForNodeType.mockReturnValue(undefined)
expect(estimateWorkflowVram(makeGraph([makeNode('KSampler')]))).toBe(0)
})
it('estimates checkpoint-only workflow as base + overhead', () => {
mockGetCategoryForNodeType.mockImplementation((type: string) =>
type === 'CheckpointLoaderSimple' ? 'checkpoints' : undefined
)
const result = estimateWorkflowVram(
makeGraph([makeNode('CheckpointLoaderSimple'), makeNode('KSampler')])
)
expect(result).toBe(MODEL_VRAM_ESTIMATES.checkpoints + RUNTIME_OVERHEAD)
})
it('uses only the largest base model when multiple checkpoints exist', () => {
mockGetCategoryForNodeType.mockImplementation((type: string) => {
if (type === 'CheckpointLoaderSimple') return 'checkpoints'
if (type === 'UNETLoader') return 'diffusion_models'
return undefined
})
const result = estimateWorkflowVram(
makeGraph([makeNode('CheckpointLoaderSimple'), makeNode('UNETLoader')])
)
// Base models are mutually exclusive in VRAM (ComfyUI offloads all
// but one), so the estimate takes the max rather than the sum.
const largestBase = Math.max(
MODEL_VRAM_ESTIMATES.checkpoints,
MODEL_VRAM_ESTIMATES.diffusion_models
)
expect(result).toBe(largestBase + RUNTIME_OVERHEAD)
})
it('sums checkpoint + lora + controlnet correctly', () => {
mockGetCategoryForNodeType.mockImplementation((type: string) => {
const map: Record<string, string> = {
CheckpointLoaderSimple: 'checkpoints',
LoraLoader: 'loras',
ControlNetLoader: 'controlnet'
}
return map[type]
})
const result = estimateWorkflowVram(
makeGraph([
makeNode('CheckpointLoaderSimple'),
makeNode('LoraLoader'),
makeNode('ControlNetLoader')
])
)
expect(result).toBe(
MODEL_VRAM_ESTIMATES.checkpoints +
MODEL_VRAM_ESTIMATES.loras +
MODEL_VRAM_ESTIMATES.controlnet +
RUNTIME_OVERHEAD
)
})
it('handles unknown model categories with default estimate', () => {
mockGetCategoryForNodeType.mockReturnValue('some_unknown_category')
const result = estimateWorkflowVram(
makeGraph([makeNode('UnknownModelLoader')])
)
// Unknown category uses 500 MB default + runtime overhead
expect(result).toBe(500_000_000 + RUNTIME_OVERHEAD)
})
it('counts multiple unique loras separately', () => {
mockGetCategoryForNodeType.mockImplementation((type: string) =>
type === 'LoraLoader' ? 'loras' : undefined
)
mockGetAllNodeProviders.mockReturnValue([
{ nodeDef: { name: 'LoraLoader' }, key: 'lora_name' }
])
// LoRAs are co-resident: distinct filenames each contribute their
// full per-category estimate.
const result = estimateWorkflowVram(
makeGraph([
makeNode('LoraLoader', [
{ name: 'lora_name', value: 'lora_a.safetensors' }
]),
makeNode('LoraLoader', [
{ name: 'lora_name', value: 'lora_b.safetensors' }
])
])
)
expect(result).toBe(MODEL_VRAM_ESTIMATES.loras * 2 + RUNTIME_OVERHEAD)
})
})
})

View File

@@ -0,0 +1,139 @@
import type {
LGraph,
LGraphNode,
Subgraph
} from '@/lib/litegraph/src/litegraph'
import { useModelToNodeStore } from '@/stores/modelToNodeStore'
import { mapAllNodes } from '@/utils/graphTraversalUtil'
/**
 * A model detected in a workflow graph, identified by the directory
 * category it belongs to and the filename selected in its widget.
 */
export interface DetectedModel {
/** Model directory category (e.g. 'checkpoints', 'loras'). */
category: string
/** Selected model filename from the node's widget, if available. */
filename: string | undefined
}
/**
 * Approximate VRAM consumption in bytes per model directory category.
 * Values represent typical fp16 model sizes loaded into GPU memory.
 * NOTE(review): these are hand-picked heuristics, not measured per
 * file — actual usage varies with model family and quantization.
 */
export const MODEL_VRAM_ESTIMATES: Record<string, number> = {
checkpoints: 4_500_000_000,
diffusion_models: 4_500_000_000,
loras: 200_000_000,
controlnet: 1_500_000_000,
vae: 350_000_000,
clip_vision: 600_000_000,
text_encoders: 1_200_000_000,
upscale_models: 200_000_000,
style_models: 500_000_000,
gligen: 500_000_000
}
/** Default VRAM estimate (500 MB) for unrecognised model categories. */
const DEFAULT_MODEL_VRAM = 500_000_000
/** Flat overhead (500 MB) for intermediate tensors and activations. */
export const RUNTIME_OVERHEAD = 500_000_000
/**
 * Categories whose models act as the "base" diffusion backbone.
 * Only the single largest base model is counted because ComfyUI
 * does not keep multiple base models resident simultaneously.
 */
const BASE_MODEL_CATEGORIES = new Set(['checkpoints', 'diffusion_models'])
/**
 * Extracts the widget value for the model input key from a graph node.
 *
 * Scans every provider registered for the category whose node
 * definition matches this node's type; the first provider whose named
 * widget holds a string value wins.
 *
 * @param node - The graph node to inspect
 * @param category - The model category, used to look up the expected input key
 * @returns The string widget value, or undefined if not found
 */
function getModelWidgetValue(
  node: LGraphNode,
  category: string
): string | undefined {
  const store = useModelToNodeStore()
  for (const provider of store.getAllNodeProviders(category)) {
    if (provider.nodeDef?.name !== node.type) continue
    // Fix: a matching provider without a key cannot name a widget —
    // skip it instead of aborting the whole scan, so a later provider
    // for the same node type with a valid key is still consulted.
    if (!provider.key) continue
    const widget = node.widgets?.find((w) => w.name === provider.key)
    if (widget?.value && typeof widget.value === 'string') {
      return widget.value
    }
  }
  return undefined
}
/**
 * Detects all model-loading nodes in a graph hierarchy and returns
 * a deduplicated list of models with their category and filename.
 *
 * @param graph - The root graph (or subgraph) to traverse
 * @returns Array of unique detected models, in first-seen order
 */
export function detectModelNodes(graph: LGraph | Subgraph): DetectedModel[] {
  const modelToNode = useModelToNodeStore()
  // One candidate entry per node that maps to a model directory category.
  const candidates = mapAllNodes(graph, (node) => {
    if (!node.type) return undefined
    const category = modelToNode.getCategoryForNodeType(node.type)
    if (!category) return undefined
    return {
      category,
      filename: getModelWidgetValue(node, category)
    } satisfies DetectedModel
  })
  // Deduplicate on category + filename, keeping the first occurrence.
  const uniqueByKey = new Map<string, DetectedModel>()
  for (const model of candidates) {
    const key = `${model.category}::${model.filename ?? ''}`
    if (!uniqueByKey.has(key)) uniqueByKey.set(key, model)
  }
  return [...uniqueByKey.values()]
}
/**
 * Estimates peak VRAM consumption (in bytes) for a workflow graph.
 *
 * Heuristic: detect every model-loading node, then count only the
 * single largest "base" model (checkpoints / diffusion_models —
 * ComfyUI offloads the rest), sum every other detected model since
 * those can be co-resident, and finally add a flat runtime overhead
 * for activations and intermediate tensors.
 *
 * @param graph - The root graph to analyse
 * @returns Estimated VRAM in bytes, or 0 if no models detected
 */
export function estimateWorkflowVram(
  graph: LGraph | Subgraph | null | undefined
): number {
  if (!graph) return 0
  const detected = detectModelNodes(graph)
  if (detected.length === 0) return 0
  let largestBase = 0
  let coResidentTotal = 0
  for (const { category } of detected) {
    const bytes = MODEL_VRAM_ESTIMATES[category] ?? DEFAULT_MODEL_VRAM
    if (BASE_MODEL_CATEGORIES.has(category)) {
      if (bytes > largestBase) largestBase = bytes
    } else {
      coResidentTotal += bytes
    }
  }
  return largestBase + coResidentTotal + RUNTIME_OVERHEAD
}

View File

@@ -1007,6 +1007,7 @@
"default": "Default",
"similarToCurrent": "Similar to Current"
},
"vramEstimateTooltip": "Estimated GPU memory required to run this workflow",
"error": {
"templateNotFound": "Template \"{templateName}\" not found"
}
@@ -1082,6 +1083,9 @@
"requiredNodesDetected": "Detected from workflow",
"requiredNodesManualPlaceholder": "Add custom node name…",
"requiredNodesManualLabel": "Additional custom nodes",
"vramLabel": "Estimated VRAM Requirement",
"vramAutoDetected": "Auto-detected from workflow:",
"vramManualOverride": "Manual override (GB):",
"difficulty": {
"beginner": "Beginner",
"intermediate": "Intermediate",
@@ -1130,7 +1134,8 @@
"galleryLabel": "Example Gallery",
"galleryHint": "Up to {max} example output images",
"uploadPrompt": "Click to upload",
"removeFile": "Remove"
"removeFile": "Remove",
"uploadingProgress": "Uploading… {percent}%"
},
"categoryAndTagging": {
"title": "Categories & Tags",
@@ -1148,6 +1153,7 @@
"workflowPreviewLabel": "Workflow Graph",
"videoPreviewLabel": "Video Preview",
"galleryLabel": "Gallery",
"vramLabel": "VRAM Requirement",
"notProvided": "Not provided",
"noneDetected": "None detected",
"correct": "Correct",

View File

@@ -8,9 +8,10 @@
v-show="!linearMode"
id="graph-canvas-container"
ref="graphCanvasContainerRef"
class="graph-canvas-container"
class="graph-canvas-container relative"
>
<GraphCanvas @ready="onGraphReady" />
<VramEstimateIndicator />
</div>
<LinearView v-if="linearMode" />
<BuilderToolbar v-if="appModeStore.isBuilderMode" />
@@ -46,6 +47,7 @@ import { runWhenGlobalIdle } from '@/base/common/async'
import MenuHamburger from '@/components/MenuHamburger.vue'
import UnloadWindowConfirmDialog from '@/components/dialog/UnloadWindowConfirmDialog.vue'
import GraphCanvas from '@/components/graph/GraphCanvas.vue'
import VramEstimateIndicator from '@/components/graph/VramEstimateIndicator.vue'
import GlobalToast from '@/components/toast/GlobalToast.vue'
import InviteAcceptedToast from '@/platform/workspace/components/toasts/InviteAcceptedToast.vue'
import RerouteMigrationToast from '@/components/toast/RerouteMigrationToast.vue'