From 06f85e2c792c626f2cab3cb4f94cd30d43e9347b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Mon, 9 Mar 2026 22:08:51 +0200 Subject: [PATCH 01/27] Fix text encoder lora loading for wrapped models (#12852) --- comfy/lora.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index f36ddb046..63ee85323 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}): for k in sdk: if k.endswith(".weight"): key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names + tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models + if tp > 0 and not k.startswith("clip_"): + key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" clip_l_present = False From 814dab9f4636df22a36cbbad21e35ac7609a0ef2 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 10 Mar 2026 10:03:22 +0800 Subject: [PATCH 02/27] Update workflow templates to v0.9.18 (#12857) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b1db1cf24..bb58f8d01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.11 +comfyui-workflow-templates==0.9.18 comfyui-embedded-docs==0.4.3 torch torchsde From 740d998c9cc821ca0a72b5b5d4b17aba1aec6b44 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Tue, 10 Mar 2026 11:49:31 +0900 Subject: [PATCH 03/27] fix(manager): improve install guidance when comfyui-manager is not installed (#12810) --- main.py | 13 ++++++++++--- manager_requirements.txt | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 1977f9362..83a7244db 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ comfy.options.enable_args_parsing() import os import importlib.util +import shutil import importlib.metadata import folder_paths import time @@ -64,8 +65,15 @@ if __name__ == "__main__": def handle_comfyui_manager_unavailable(): - if not args.windows_standalone_build: - logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n") + manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt") + uv_available = shutil.which("uv") is not None + + pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}" + msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}" + if uv_available: + msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}" + msg += "\n" + logging.warning(msg) args.enable_manager = False @@ -173,7 +181,6 @@ execute_prestartup_script() # Main code import asyncio -import shutil import threading import gc diff --git a/manager_requirements.txt b/manager_requirements.txt index c420cc48e..6bcc3fb50 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.1b1 +comfyui_manager==4.1b2 \ No newline at end of file From c4fb0271cd7fbddb2381372b1f7c1206d1dd58fc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 9 Mar 2026 20:37:58 -0700 Subject: [PATCH 04/27] Add a way for nodes to add pre attn patches to flux model. (#12861) --- comfy/ldm/flux/layers.py | 15 ++++++++++++++- comfy/ldm/flux/math.py | 2 ++ comfy/ldm/flux/model.py | 2 +- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 8b3f500d7..e20d498f8 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -223,12 +223,19 @@ class DoubleStreamBlock(nn.Module): del txt_k, img_k v = torch.cat((txt_v, img_v), dim=2) del txt_v, img_v + + extra_options["img_slice"] = [txt.shape[1], q.shape[2]] + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options) + q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask) + # run actual attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v if "attn1_output_patch" in transformer_patches: - extra_options["img_slice"] = [txt.shape[1], attn.shape[1]] patch = transformer_patches["attn1_output_patch"] for p in patch: attn = p(attn, extra_options) @@ -321,6 +328,12 @@ class SingleStreamBlock(nn.Module): del qkv q, k = self.norm(q, k, v) + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options) + q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask) + # compute attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 5e764bb46..824daf5e6 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -31,6 +31,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: def _apply_rope1(x: Tensor, freqs_cis: Tensor): x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]: + freqs_cis = freqs_cis[:, :, :x_.shape[2]] x_out = freqs_cis[..., 0] * x_[..., 0] x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index ef4dcf7c5..00f12c031 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -170,7 +170,7 @@ class Flux(nn.Module): if "post_input" in patches: for p in patches["post_input"]: - out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) + out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) img = out["img"] txt = out["txt"] img_ids = out["img_ids"] From a912809c252f5a2d69c8ab4035fc262a578fdcee Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 9 Mar 2026 20:50:10 -0700 Subject: [PATCH 05/27] model_detection: deep clone pre edited edited weights (#12862) Deep clone these weights as needed to avoid segfaulting when it tries to touch the original mmap. --- comfy/model_detection.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 6eace4628..35a6822e3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1,4 +1,5 @@ import json +import comfy.memory_management import comfy.supported_models import comfy.supported_models_base import comfy.utils @@ -1118,8 +1119,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): new[:old_weight.shape[0]] = old_weight old_weight = new + if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled: + old_weight = old_weight.clone() + w = old_weight.narrow(offset[0], offset[1], offset[2]) else: + if comfy.memory_management.aimdo_enabled: + weight = weight.clone() old_weight = weight w = weight w[:] = fun(weight) From 535c16ce6e3d2634d6eb2fd17ecccb8d497e26a0 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 9 Mar 2026 21:41:02 -0700 Subject: [PATCH 06/27] Widen OOM_EXCEPTION to AcceleratorError form (#12835) Pytorch only filters for OOMs in its own allocators however there are paths that can OOM on allocators made outside the pytorch allocators. These manifest as an AllocatorError as pytorch does not have universal error translation to its OOM type on exception. Handle it. A log I have for this also shows a double report of the error async, so call the async discarder to cleanup and make these OOMs look like OOMs. --- comfy/ldm/modules/attention.py | 3 ++- comfy/ldm/modules/diffusionmodules/model.py | 6 ++++-- comfy/ldm/modules/sub_quadratic_attention.py | 3 ++- comfy/model_management.py | 12 ++++++++++++ comfy/sd.py | 6 ++++-- comfy_extras/nodes_upscale_model.py | 3 ++- execution.py | 2 +- 7 files changed, 27 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 10d051325..b193fe5e8 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 break - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) if first_op_done == False: model_management.soft_empty_cache(True) if cleared_cache == False: diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 805592aa5..fcbaa074f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -258,7 +258,8 @@ def slice_attention(q, k, v): r1[:, :, i:end] = torch.bmm(v, s2) del s2 break - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) model_management.soft_empty_cache(True) steps *= 2 if steps > 128: @@ -314,7 +315,8 @@ def pytorch_attention(q, k, v): try: out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") oom_fallback = True if oom_fallback: diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index fab145f1c..f982afc2b 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking( try: attn_probs = attn_scores.softmax(dim=-1) del attn_scores - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined torch.exp(attn_scores, out=attn_scores) diff --git a/comfy/model_management.py b/comfy/model_management.py index 07bc8ad67..81550c790 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -270,6 +270,18 @@ try: except: OOM_EXCEPTION = Exception +def is_oom(e): + if isinstance(e, OOM_EXCEPTION): + return True + if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2: + discard_cuda_async_error() + return True + return False + +def raise_non_oom(e): + if not is_oom(e): + raise e + XFORMERS_VERSION = "" XFORMERS_ENABLED_VAE = True if args.disable_xformers: diff --git a/comfy/sd.py b/comfy/sd.py index 888ef1e77..adcd67767 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -954,7 +954,8 @@ class VAE: if pixel_samples is None: pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples[x:x+batch_number] = out - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") #NOTE: We don't know what tensors were allocated to stack variables at the time of the #exception and the exception itself refs them all until we get out of this except block. @@ -1029,7 +1030,8 @@ class VAE: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device) samples[x:x + batch_number] = out - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") #NOTE: We don't know what tensors were allocated to stack variables at the time of the #exception and the exception itself refs them all until we get out of this except block. diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 97b9e948d..db4f9d231 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -86,7 +86,8 @@ class ImageUpscaleWithModel(io.ComfyNode): pbar = comfy.utils.ProgressBar(steps) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) oom = False - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) tile //= 2 if tile < 128: raise e diff --git a/execution.py b/execution.py index 7ccdbf93e..a7791efed 100644 --- a/execution.py +++ b/execution.py @@ -612,7 +612,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, logging.error(traceback.format_exc()) tips = "" - if isinstance(ex, comfy.model_management.OOM_EXCEPTION): + if comfy.model_management.is_oom(ex): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) logging.error("Got an OOM, unloading all loaded models.") From 8086468d2a1a5a6ed70fea3391e7fb9248ebc7da Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 10 Mar 2026 09:05:31 -0700 Subject: [PATCH 07/27] main: switch on faulthandler (#12868) When we get segfault bug reports we dont get much. Switch on pythons inbuilt tracer for segfault. --- main.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/main.py b/main.py index 83a7244db..8905fd09a 100644 --- a/main.py +++ b/main.py @@ -12,6 +12,7 @@ from app.logger import setup_logger import itertools import utils.extra_config from utils.mime_types import init_mime_types +import faulthandler import logging import sys from comfy_execution.progress import get_progress_state @@ -26,6 +27,8 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +faulthandler.enable(file=sys.stderr, all_threads=False) + import comfy_aimdo.control if enables_dynamic_vram(): From 3ad36d6be66b2af2a7c3dc9ab6936eebc6b98075 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 10 Mar 2026 17:09:12 -0700 Subject: [PATCH 08/27] Allow model patches to have a cleanup function. (#12878) The function gets called after sampling is finished. --- comfy/model_patcher.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 745384271..bc3a8f446 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -599,6 +599,27 @@ class ModelPatcher: return models + def model_patches_call_function(self, function_name="cleanup", arguments={}): + to = self.model_options["transformer_options"] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], function_name): + getattr(patch_list[i], function_name)(**arguments) + if "patches_replace" in to: + patches = to["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], function_name): + getattr(patch_list[k], function_name)(**arguments) + if "model_function_wrapper" in self.model_options: + wrap_func = self.model_options["model_function_wrapper"] + if hasattr(wrap_func, function_name): + getattr(wrap_func, function_name)(**arguments) + def model_dtype(self): if hasattr(self.model, "get_dtype"): return self.model.get_dtype() @@ -1062,6 +1083,7 @@ class ModelPatcher: return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype) def cleanup(self): + self.model_patches_call_function(function_name="cleanup") self.clean_hooks() if hasattr(self.model, "current_patcher"): self.model.current_patcher = None From 9642e4407b60b291744cc1d34501783cff6702e5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 10 Mar 2026 21:09:35 -0700 Subject: [PATCH 09/27] Add pre attention and post input patches to qwen image model. (#12879) --- comfy/ldm/qwen_image/model.py | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 6eb744286..0862f72f7 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -149,6 +149,9 @@ class Attention(nn.Module): seq_img = hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1] + transformer_patches = transformer_options.get("patches", {}) + extra_options = transformer_options.copy() + # Project and reshape to BHND format (batch, heads, seq, dim) img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() @@ -167,15 +170,22 @@ class Attention(nn.Module): joint_key = torch.cat([txt_key, img_key], dim=2) joint_value = torch.cat([txt_value, img_value], dim=2) - joint_query = apply_rope1(joint_query, image_rotary_emb) - joint_key = apply_rope1(joint_key, image_rotary_emb) - if encoder_hidden_states_mask is not None: attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device) attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask else: attn_mask = None + extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]] + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options) + joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask) + + joint_query = apply_rope1(joint_query, image_rotary_emb) + joint_key = apply_rope1(joint_key, image_rotary_emb) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attn_mask, transformer_options=transformer_options, skip_reshape=True) @@ -444,6 +454,7 @@ class QwenImageTransformer2DModel(nn.Module): timestep_zero_index = None if ref_latents is not None: + ref_num_tokens = [] h = 0 w = 0 index = 0 @@ -474,16 +485,16 @@ class QwenImageTransformer2DModel(nn.Module): kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + ref_num_tokens.append(kontext.shape[1]) if timestep_zero: if index > 0: timestep = torch.cat([timestep, timestep * 0], dim=0) timestep_zero_index = num_embeds + transformer_options = transformer_options.copy() + transformer_options["reference_image_num_tokens"] = ref_num_tokens txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) - ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() - del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states) @@ -495,6 +506,18 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) + if "post_input" in patches: + for p in patches["post_input"]: + out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + img_ids = out["img_ids"] + txt_ids = out["txt_ids"] + + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() + del ids, txt_ids, img_ids + transformer_options["total_blocks"] = len(self.transformer_blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.transformer_blocks): From 980621da83267beffcb84839a27101b7092256e7 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 11 Mar 2026 08:49:38 -0700 Subject: [PATCH 10/27] comfy-aimdo 0.2.10 (#12890) Comfy Aimdo 0.2.10 fixes the aimdo allocator hook for legacy cudaMalloc consumers. Some consumers of cudaMalloc assume implicit synchronization built in closed source logic inside cuda. This is preserved by passing through to cuda as-is and accouting after the fact as opposed to integrating these hooks with Aimdos VMA based allocator. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bb58f8d01..89cd994e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy filelock av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.9 +comfy-aimdo>=0.2.10 requests simpleeval>=1.0.0 blake3 From 3365008dfe5a7a46cbe76d8ad0d7efb054617733 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Wed, 11 Mar 2026 18:53:55 +0200 Subject: [PATCH 11/27] feat(api-nodes): add Reve Image nodes (#12848) --- comfy_api_nodes/apis/reve.py | 68 ++++++ comfy_api_nodes/nodes_reve.py | 395 +++++++++++++++++++++++++++++++++ comfy_api_nodes/util/client.py | 12 +- 3 files changed, 474 insertions(+), 1 deletion(-) create mode 100644 comfy_api_nodes/apis/reve.py create mode 100644 comfy_api_nodes/nodes_reve.py diff --git a/comfy_api_nodes/apis/reve.py b/comfy_api_nodes/apis/reve.py new file mode 100644 index 000000000..c6b5a69d8 --- /dev/null +++ b/comfy_api_nodes/apis/reve.py @@ -0,0 +1,68 @@ +from pydantic import BaseModel, Field + + +class RevePostprocessingOperation(BaseModel): + process: str = Field(..., description="The postprocessing operation: upscale or remove_background.") + upscale_factor: int | None = Field( + None, + description="Upscale factor (2, 3, or 4). Only used when process is upscale.", + ge=2, + le=4, + ) + + +class ReveImageCreateRequest(BaseModel): + prompt: str = Field(...) + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageEditRequest(BaseModel): + edit_instruction: str = Field(...) + reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.") + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int | None = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageRemixRequest(BaseModel): + prompt: str = Field(...) + reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.") + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int | None = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageResponse(BaseModel): + image: str | None = Field(None, description="The base64 encoded image data.") + request_id: str | None = Field(None, description="A unique id for the request.") + credits_used: float | None = Field(None, description="The number of credits used for this request.") + version: str | None = Field(None, description="The specific model version used.") + content_violation: bool | None = Field( + None, description="Indicates whether the generated image violates the content policy." + ) diff --git a/comfy_api_nodes/nodes_reve.py b/comfy_api_nodes/nodes_reve.py new file mode 100644 index 000000000..608d9f058 --- /dev/null +++ b/comfy_api_nodes/nodes_reve.py @@ -0,0 +1,395 @@ +from io import BytesIO + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.reve import ( + ReveImageCreateRequest, + ReveImageEditRequest, + ReveImageRemixRequest, + RevePostprocessingOperation, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + bytesio_to_image_tensor, + sync_op_raw, + tensor_to_base64_string, + validate_string, +) + + +def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None: + ops = [] + if upscale["upscale"] == "enabled": + ops.append( + RevePostprocessingOperation( + process="upscale", + upscale_factor=upscale["upscale_factor"], + ) + ) + if remove_background: + ops.append(RevePostprocessingOperation(process="remove_background")) + return ops or None + + +def _postprocessing_inputs(): + return [ + IO.DynamicCombo.Input( + "upscale", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option( + "enabled", + [ + IO.Int.Input( + "upscale_factor", + default=2, + min=2, + max=4, + step=1, + tooltip="Upscale factor (2x, 3x, or 4x).", + ), + ], + ), + ], + tooltip="Upscale the generated image. May add additional cost.", + ), + IO.Boolean.Input( + "remove_background", + default=False, + tooltip="Remove the background from the generated image. May add additional cost.", + ), + ] + + +def _reve_price_extractor(headers: dict) -> float | None: + credits_used = headers.get("x-reve-credits-used") + if credits_used is not None: + return float(credits_used) / 524.48 + return None + + +def _reve_response_header_validator(headers: dict) -> None: + error_code = headers.get("x-reve-error-code") + if error_code: + raise ValueError(f"Reve API error: {error_code}") + if headers.get("x-reve-content-violation", "").lower() == "true": + raise ValueError("The generated image was flagged for content policy violation.") + + +def _model_inputs(versions: list[str], aspect_ratios: list[str]): + return [ + IO.DynamicCombo.Option( + version, + [ + IO.Combo.Input( + "aspect_ratio", + options=aspect_ratios, + tooltip="Aspect ratio of the output image.", + ), + IO.Int.Input( + "test_time_scaling", + default=1, + min=1, + max=5, + step=1, + tooltip="Higher values produce better images but cost more credits.", + advanced=True, + ), + ], + ) + for version in versions + ] + + +class ReveImageCreateNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageCreateNode", + display_name="Reve Image Create", + category="api node/image/Reve", + description="Generate images from text descriptions using Reve.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired image. Maximum 2560 characters.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-create@20250915"], + aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for generation.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""", + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2560) + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/create", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageCreateRequest( + prompt=prompt, + aspect_ratio=model["aspect_ratio"], + version=model["model"], + test_time_scaling=model["test_time_scaling"], + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageEditNode", + display_name="Reve Image Edit", + category="api node/image/Reve", + description="Edit images using natural language instructions with Reve.", + inputs=[ + IO.Image.Input("image", tooltip="The image to edit."), + IO.String.Input( + "edit_instruction", + multiline=True, + default="", + tooltip="Text description of how to edit the image. Maximum 2560 characters.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-edit@20250915", "reve-edit-fast@20251030"], + aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for editing.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model"], + ), + expr=""" + ( + $isFast := $contains(widgets.model, "fast"); + $base := $isFast ? 0.01001 : 0.0572; + {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + edit_instruction: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(edit_instruction, min_length=1, max_length=2560) + tts = model["test_time_scaling"] + ar = model["aspect_ratio"] + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/edit", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageEditRequest( + edit_instruction=edit_instruction, + reference_image=tensor_to_base64_string(image), + aspect_ratio=ar if ar != "auto" else None, + version=model["model"], + test_time_scaling=tts if tts and tts > 1 else None, + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveImageRemixNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageRemixNode", + display_name="Reve Image Remix", + category="api node/image/Reve", + description="Combine reference images with text prompts to create new images using Reve.", + inputs=[ + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplatePrefix( + IO.Image.Input("image"), + prefix="image_", + min=1, + max=6, + ), + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired image. " + "May include XML img tags to reference specific images by index, " + "e.g. 0, 1, etc.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-remix@20250915", "reve-remix-fast@20251030"], + aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for remixing.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model"], + ), + expr=""" + ( + $isFast := $contains(widgets.model, "fast"); + $base := $isFast ? 0.01001 : 0.0572; + {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + reference_images: IO.Autogrow.Type, + prompt: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2560) + if not reference_images: + raise ValueError("At least one reference image is required.") + ref_base64_list = [] + for key in reference_images: + ref_base64_list.append(tensor_to_base64_string(reference_images[key])) + if len(ref_base64_list) > 6: + raise ValueError("Maximum 6 reference images are allowed.") + tts = model["test_time_scaling"] + ar = model["aspect_ratio"] + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/remix", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageRemixRequest( + prompt=prompt, + reference_images=ref_base64_list, + aspect_ratio=ar if ar != "auto" else None, + version=model["model"], + test_time_scaling=tts if tts and tts > 1 else None, + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ReveImageCreateNode, + ReveImageEditNode, + ReveImageRemixNode, + ] + + +async def comfy_entrypoint() -> ReveExtension: + return ReveExtension() diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 79ffb77c1..9d730b81a 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -67,6 +67,7 @@ class _RequestConfig: progress_origin_ts: float | None = None price_extractor: Callable[[dict[str, Any]], float | None] | None = None is_rate_limited: Callable[[int, Any], bool] | None = None + response_header_validator: Callable[[dict[str, str]], None] | None = None @dataclass @@ -202,11 +203,13 @@ async def sync_op_raw( monitor_progress: bool = True, max_retries_on_rate_limit: int = 16, is_rate_limited: Callable[[int, Any], bool] | None = None, + response_header_validator: Callable[[dict[str, str]], None] | None = None, ) -> dict[str, Any] | bytes: """ Make a single network request. - If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). - If as_binary=True: returns bytes. + - response_header_validator: optional callback receiving response headers dict """ if isinstance(data, BaseModel): data = data.model_dump(exclude_none=True) @@ -232,6 +235,7 @@ async def sync_op_raw( price_extractor=price_extractor, max_retries_on_rate_limit=max_retries_on_rate_limit, is_rate_limited=is_rate_limited, + response_header_validator=response_header_validator, ) return await _request_base(cfg, expect_binary=as_binary) @@ -769,6 +773,12 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total ) bytes_payload = bytes(buff) + resp_headers = {k.lower(): v for k, v in resp.headers.items()} + if cfg.price_extractor: + with contextlib.suppress(Exception): + extracted_price = cfg.price_extractor(resp_headers) + if cfg.response_header_validator: + cfg.response_header_validator(resp_headers) operation_succeeded = True final_elapsed_seconds = int(time.monotonic() - start_time) request_logger.log_request_response( @@ -776,7 +786,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): request_method=method, request_url=url, response_status_code=resp.status, - response_headers=dict(resp.headers), + response_headers=resp_headers, response_content=bytes_payload, ) return bytes_payload From 4f4f8659c205069f74da8ac47378a5b1c0e142ca Mon Sep 17 00:00:00 2001 From: Adi Borochov <58855640+adiborochov@users.noreply.github.com> Date: Wed, 11 Mar 2026 19:04:13 +0200 Subject: [PATCH 12/27] fix: guard torch.AcceleratorError for compatibility with torch < 2.8.0 (#12874) * fix: guard torch.AcceleratorError for compatibility with torch < 2.8.0 torch.AcceleratorError was introduced in PyTorch 2.8.0. Accessing it directly raises AttributeError on older versions. Use a try/except fallback at module load time, consistent with the existing pattern used for OOM_EXCEPTION. * fix: address review feedback for AcceleratorError compat - Fall back to RuntimeError instead of type(None) for ACCELERATOR_ERROR, consistent with OOM_EXCEPTION fallback pattern and valid for except clauses - Add "out of memory" message introspection for RuntimeError fallback case - Use RuntimeError directly in discard_cuda_async_error except clause --------- --- comfy/model_management.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 81550c790..81c89b180 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -270,10 +270,15 @@ try: except: OOM_EXCEPTION = Exception +try: + ACCELERATOR_ERROR = torch.AcceleratorError +except AttributeError: + ACCELERATOR_ERROR = RuntimeError + def is_oom(e): if isinstance(e, OOM_EXCEPTION): return True - if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2: + if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()): discard_cuda_async_error() return True return False @@ -1275,7 +1280,7 @@ def discard_cuda_async_error(): b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) _ = a + b synchronize() - except torch.AcceleratorError: + except RuntimeError: #Dump it! We already know about it from the synchronous return pass From f6274c06b4e7bce8adbc1c60ae5a4c168825a614 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 11 Mar 2026 13:37:31 -0700 Subject: [PATCH 13/27] Fix issue with batch_size > 1 on some models. (#12892) --- comfy/ldm/flux/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index e20d498f8..e28d704b4 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): return tensor * m_mult else: for d in modulation_dims: - tensor[:, d[0]:d[1]] *= m_mult[:, d[2]] + tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1] if m_add is not None: - tensor[:, d[0]:d[1]] += m_add[:, d[2]] + tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1] return tensor From abc87d36693b007bdbdab5ee753ccea6326acb34 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Thu, 12 Mar 2026 06:04:51 +0900 Subject: [PATCH 14/27] Bump comfyui-frontend-package to 1.41.15 (#12891) --------- Co-authored-by: Alexander Brown --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 89cd994e9..ffa5fa376 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.39.19 +comfyui-frontend-package==1.41.15 comfyui-workflow-templates==0.9.18 comfyui-embedded-docs==0.4.3 torch From 9ce4c3dd87c9c77dfe0371045fa920ce55e08973 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Thu, 12 Mar 2026 10:16:30 +0900 Subject: [PATCH 15/27] Bump comfyui-frontend-package to 1.41.16 (#12894) Co-authored-by: github-actions[bot] --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ffa5fa376..2272d121a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.41.15 +comfyui-frontend-package==1.41.16 comfyui-workflow-templates==0.9.18 comfyui-embedded-docs==0.4.3 torch From 8f9ea495713d4565dfe564e0c06f362bd627f902 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 11 Mar 2026 21:17:31 -0700 Subject: [PATCH 16/27] Bump comfy-kitchen version to 0.2.8 (#12895) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2272d121a..96cd0254f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy filelock av>=14.2.0 -comfy-kitchen>=0.2.7 +comfy-kitchen>=0.2.8 comfy-aimdo>=0.2.10 requests simpleeval>=1.0.0 From 44f1246c899ed188759f799dbd00c31def289114 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 12 Mar 2026 08:30:50 -0700 Subject: [PATCH 17/27] Support flux 2 klein kv cache model: Use the FluxKVCache node. (#12905) --- comfy/ldm/flux/model.py | 76 ++++++++++++++++++++++++++++++++------ comfy_extras/nodes_flux.py | 64 ++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 11 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 00f12c031..8e7912e6d 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -44,6 +44,22 @@ class FluxParams: txt_norm: bool = False +def invert_slices(slices, length): + sorted_slices = sorted(slices) + result = [] + current = 0 + + for start, end in sorted_slices: + if current < start: + result.append((current, start)) + current = max(current, end) + + if current < length: + result.append((current, length)) + + return result + + class Flux(nn.Module): """ Transformer model for flow matching on sequences. @@ -138,6 +154,7 @@ class Flux(nn.Module): y: Tensor, guidance: Tensor = None, control = None, + timestep_zero_index=None, transformer_options={}, attn_mask: Tensor = None, ) -> Tensor: @@ -164,10 +181,6 @@ class Flux(nn.Module): txt = self.txt_norm(txt) txt = self.txt_in(txt) - vec_orig = vec - if self.params.global_modulation: - vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig)) - if "post_input" in patches: for p in patches["post_input"]: out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) @@ -182,6 +195,24 @@ class Flux(nn.Module): else: pe = None + vec_orig = vec + txt_vec = vec + extra_kwargs = {} + if timestep_zero_index is not None: + modulation_dims = [] + batch = vec.shape[0] // 2 + vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1) + invert = invert_slices(timestep_zero_index, img.shape[1]) + for s in invert: + modulation_dims.append((s[0], s[1], 0)) + for s in timestep_zero_index: + modulation_dims.append((s[0], s[1], 1)) + extra_kwargs["modulation_dims_img"] = modulation_dims + txt_vec = vec[:batch] + + if self.params.global_modulation: + vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec)) + blocks_replace = patches_replace.get("dit", {}) transformer_options["total_blocks"] = len(self.double_blocks) transformer_options["block_type"] = "double" @@ -195,7 +226,8 @@ class Flux(nn.Module): vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask"), - transformer_options=args.get("transformer_options")) + transformer_options=args.get("transformer_options"), + **extra_kwargs) return out out = blocks_replace[("double_block", i)]({"img": img, @@ -213,7 +245,8 @@ class Flux(nn.Module): vec=vec, pe=pe, attn_mask=attn_mask, - transformer_options=transformer_options) + transformer_options=transformer_options, + **extra_kwargs) if control is not None: # Controlnet control_i = control.get("input") @@ -230,6 +263,12 @@ class Flux(nn.Module): if self.params.global_modulation: vec, _ = self.single_stream_modulation(vec_orig) + extra_kwargs = {} + if timestep_zero_index is not None: + lambda a: 0 if a == 0 else a + txt.shape[1] + modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims)) + extra_kwargs["modulation_dims"] = modulation_dims_combined + transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["block_type"] = "single" transformer_options["img_slice"] = [txt.shape[1], img.shape[1]] @@ -242,7 +281,8 @@ class Flux(nn.Module): vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask"), - transformer_options=args.get("transformer_options")) + transformer_options=args.get("transformer_options"), + **extra_kwargs) return out out = blocks_replace[("single_block", i)]({"img": img, @@ -253,7 +293,7 @@ class Flux(nn.Module): {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs) if control is not None: # Controlnet control_o = control.get("output") @@ -264,7 +304,11 @@ class Flux(nn.Module): img = img[:, txt.shape[1] :, ...] - img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels) + extra_kwargs = {} + if timestep_zero_index is not None: + extra_kwargs["modulation_dims"] = modulation_dims + + img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels) return img def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): @@ -312,13 +356,16 @@ class Flux(nn.Module): w_len = ((w_orig + (patch_size // 2)) // patch_size) img, img_ids = self.process_img(x, transformer_options=transformer_options) img_tokens = img.shape[1] + timestep_zero_index = None if ref_latents is not None: + ref_num_tokens = [] h = 0 w = 0 index = 0 ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method) + timestep_zero = ref_latents_method == "index_timestep_zero" for ref in ref_latents: - if ref_latents_method == "index": + if ref_latents_method in ("index", "index_timestep_zero"): index += self.params.ref_index_scale h_offset = 0 w_offset = 0 @@ -342,6 +389,13 @@ class Flux(nn.Module): kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + ref_num_tokens.append(kontext.shape[1]) + if timestep_zero: + if index > 0: + timestep = torch.cat([timestep, timestep * 0], dim=0) + timestep_zero_index = [[img_tokens, img_ids.shape[1]]] + transformer_options = transformer_options.copy() + transformer_options["reference_image_num_tokens"] = ref_num_tokens txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) @@ -349,6 +403,6 @@ class Flux(nn.Module): for i in self.params.txt_ids_dims: txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) - out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig] diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index fe9552022..c366d0d5b 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -6,6 +6,7 @@ import comfy.model_management import torch import math import nodes +import comfy.ldm.flux.math class CLIPTextEncodeFlux(io.ComfyNode): @classmethod @@ -231,6 +232,68 @@ class Flux2Scheduler(io.ComfyNode): sigmas = get_schedule(steps, round(seq_len)) return io.NodeOutput(sigmas) +class KV_Attn_Input: + def __init__(self): + self.cache = {} + + def __call__(self, q, k, v, extra_options, **kwargs): + reference_image_num_tokens = extra_options.get("reference_image_num_tokens", []) + if len(reference_image_num_tokens) == 0: + return {} + + ref_toks = sum(reference_image_num_tokens) + cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"]) + if cache_key in self.cache: + kk, vv = self.cache[cache_key] + self.set_cache = False + return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)} + + self.cache[cache_key] = (k[:, :, -ref_toks:], v[:, :, -ref_toks:]) + self.set_cache = True + return {"q": q, "k": k, "v": v} + + def cleanup(self): + self.cache = {} + + +class FluxKVCache(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="FluxKVCache", + display_name="Flux KV Cache", + description="Enables KV Cache optimization for reference images on Flux family models.", + category="", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to use KV Cache on."), + ], + outputs=[ + io.Model.Output(tooltip="The patched model with KV Cache enabled."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type) -> io.NodeOutput: + m = model.clone() + input_patch_obj = KV_Attn_Input() + + def model_input_patch(inputs): + if len(input_patch_obj.cache) > 0: + ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", [])) + if ref_image_tokens > 0: + img = inputs["img"] + inputs["img"] = img[:, :-ref_image_tokens] + return inputs + + m.set_model_attn1_patch(input_patch_obj) + m.set_model_post_input_patch(model_input_patch) + if hasattr(model.model.diffusion_model, "params"): + m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero") + else: + m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero") + + return io.NodeOutput(m) class FluxExtension(ComfyExtension): @override @@ -243,6 +306,7 @@ class FluxExtension(ComfyExtension): FluxKontextMultiReferenceLatentMethod, EmptyFlux2LatentImage, Flux2Scheduler, + FluxKVCache, ] From 73d9599495e45c22ef3672176f34945deeea5444 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Thu, 12 Mar 2026 09:55:29 -0700 Subject: [PATCH 18/27] add painter node (#12294) * add painter node * use io.Color * code improve --------- Co-authored-by: guill --- comfy_extras/nodes_painter.py | 127 ++++++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 128 insertions(+) create mode 100644 comfy_extras/nodes_painter.py diff --git a/comfy_extras/nodes_painter.py b/comfy_extras/nodes_painter.py new file mode 100644 index 000000000..b9ecdf5ea --- /dev/null +++ b/comfy_extras/nodes_painter.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import hashlib +import os + +import numpy as np +import torch +from PIL import Image + +import folder_paths +import node_helpers +from comfy_api.latest import ComfyExtension, io, UI +from typing_extensions import override + + +def hex_to_rgb(hex_color: str) -> tuple[float, float, float]: + hex_color = hex_color.lstrip("#") + if len(hex_color) != 6: + return (0.0, 0.0, 0.0) + r = int(hex_color[0:2], 16) / 255.0 + g = int(hex_color[2:4], 16) / 255.0 + b = int(hex_color[4:6], 16) / 255.0 + return (r, g, b) + + +class PainterNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Painter", + display_name="Painter", + category="image", + inputs=[ + io.Image.Input( + "image", + optional=True, + tooltip="Optional base image to paint over", + ), + io.String.Input( + "mask", + default="", + socketless=True, + extra_dict={"widgetType": "PAINTER", "image_upload": True}, + ), + io.Int.Input( + "width", + default=512, + min=64, + max=4096, + step=64, + socketless=True, + extra_dict={"hidden": True}, + ), + io.Int.Input( + "height", + default=512, + min=64, + max=4096, + step=64, + socketless=True, + extra_dict={"hidden": True}, + ), + io.Color.Input("bg_color", default="#000000"), + ], + outputs=[ + io.Image.Output("IMAGE"), + io.Mask.Output("MASK"), + ], + ) + + @classmethod + def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput: + if image is not None: + base_image = image[:1] + h, w = base_image.shape[1], base_image.shape[2] + else: + h, w = height, width + r, g, b = hex_to_rgb(bg_color) + base_image = torch.zeros((1, h, w, 3), dtype=torch.float32) + base_image[0, :, :, 0] = r + base_image[0, :, :, 1] = g + base_image[0, :, :, 2] = b + + if mask and mask.strip(): + mask_path = folder_paths.get_annotated_filepath(mask) + painter_img = node_helpers.pillow(Image.open, mask_path) + painter_img = painter_img.convert("RGBA") + + if painter_img.size != (w, h): + painter_img = painter_img.resize((w, h), Image.LANCZOS) + + painter_np = np.array(painter_img).astype(np.float32) / 255.0 + painter_rgb = painter_np[:, :, :3] + painter_alpha = painter_np[:, :, 3:4] + + mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0) + + base_np = base_image[0].cpu().numpy() + composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha) + out_image = torch.from_numpy(composited).unsqueeze(0) + else: + mask_tensor = torch.zeros((1, h, w), dtype=torch.float32) + out_image = base_image + + return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image)) + + @classmethod + def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None): + if mask and mask.strip(): + mask_path = folder_paths.get_annotated_filepath(mask) + if os.path.exists(mask_path): + m = hashlib.sha256() + with open(mask_path, "rb") as f: + m.update(f.read()) + return m.digest().hex() + return "" + + + +class PainterExtension(ComfyExtension): + @override + async def get_node_list(self): + return [PainterNode] + + +async def comfy_entrypoint(): + return PainterExtension() diff --git a/nodes.py b/nodes.py index 0ef23b640..eb63f9d44 100644 --- a/nodes.py +++ b/nodes.py @@ -2450,6 +2450,7 @@ async def init_builtin_extra_nodes(): "nodes_nag.py", "nodes_sdpose.py", "nodes_math.py", + "nodes_painter.py", ] import_failed = [] From 3fa8c5686dc86fe4e63ad3ca84d71524792a17b1 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Thu, 12 Mar 2026 10:14:28 -0700 Subject: [PATCH 19/27] fix: use frontend-compatible format for Float gradient_stops (#12789) Co-authored-by: guill Co-authored-by: Jedrzej Kosinski --- comfy/comfy_types/node_typing.py | 4 ++-- comfy_api/latest/_io.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 92b1acbd5..57126fa4a 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict): """COMBO type only. Specifies the configuration for a multi-select widget. Available after ComfyUI frontend v1.13.4 https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987""" - gradient_stops: NotRequired[list[list[float]]] - """Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``).""" + gradient_stops: NotRequired[list[dict]] + """Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}.""" class HiddenInputTypeDict(TypedDict): diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 050031dc0..7ca8f4e0c 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -297,7 +297,7 @@ class Float(ComfyTypeIO): '''Float input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, - display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None, + display_mode: NumberDisplay=None, gradient_stops: list[dict]=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.min = min From 712411d53919350ae5050cbdf7ed60fcc2b52cda Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Fri, 13 Mar 2026 03:16:54 +0800 Subject: [PATCH 20/27] chore: update workflow templates to v0.9.21 (#12908) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 96cd0254f..a2e53671e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.41.16 -comfyui-workflow-templates==0.9.18 +comfyui-workflow-templates==0.9.21 comfyui-embedded-docs==0.4.3 torch torchsde From 47e1e316c580ce6bf264cb069bffc10a50d3f167 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 12 Mar 2026 13:54:38 -0700 Subject: [PATCH 21/27] Lower kv cache memory usage. (#12909) --- comfy_extras/nodes_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index c366d0d5b..3a23c7d04 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -248,7 +248,7 @@ class KV_Attn_Input: self.set_cache = False return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)} - self.cache[cache_key] = (k[:, :, -ref_toks:], v[:, :, -ref_toks:]) + self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone()) self.set_cache = True return {"q": q, "k": k, "v": v} From 8d9faaa181b9089cf8e4e00284443ef5c3405a12 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Thu, 12 Mar 2026 15:14:59 -0700 Subject: [PATCH 22/27] Update requirements.txt (#12910) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a2e53671e..511c62fee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.41.16 +comfyui-frontend-package==1.41.18 comfyui-workflow-templates==0.9.21 comfyui-embedded-docs==0.4.3 torch From af7b4a921d7abab7c852d7b5febb654be6e57eba Mon Sep 17 00:00:00 2001 From: Deep Mehta <42841935+deepme987@users.noreply.github.com> Date: Thu, 12 Mar 2026 16:09:07 -0700 Subject: [PATCH 23/27] feat: Add CacheProvider API for external distributed caching (#12056) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Add CacheProvider API for external distributed caching Introduces a public API for external cache providers, enabling distributed caching across multiple ComfyUI instances (e.g., Kubernetes pods). New files: - comfy_execution/cache_provider.py: CacheProvider ABC, CacheContext/CacheValue dataclasses, thread-safe provider registry, serialization utilities Modified files: - comfy_execution/caching.py: Add provider hooks to BasicCache (_notify_providers_store, _check_providers_lookup), subcache exclusion, prompt ID propagation - execution.py: Add prompt lifecycle hooks (on_prompt_start/on_prompt_end) to PromptExecutor, set _current_prompt_id on caches Key features: - Local-first caching (check local before external for performance) - NaN detection to prevent incorrect external cache hits - Subcache exclusion (ephemeral subgraph results not cached externally) - Thread-safe provider snapshot caching - Graceful error handling (provider errors logged, never break execution) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * fix: use deterministic hash for cache keys instead of pickle Pickle serialization is NOT deterministic across Python sessions due to hash randomization affecting frozenset iteration order. This causes distributed caching to fail because different pods compute different hashes for identical cache keys. Fix: Use _canonicalize() + JSON serialization which ensures deterministic ordering regardless of Python's hash randomization. This is critical for cross-pod cache key consistency in Kubernetes deployments. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * test: add unit tests for CacheProvider API - Add comprehensive tests for _canonicalize deterministic ordering - Add tests for serialize_cache_key hash consistency - Add tests for contains_nan utility - Add tests for estimate_value_size - Add tests for provider registry (register, unregister, clear) - Move json import to top-level (fix inline import) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * style: remove unused imports in test_cache_provider.py 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * fix: move _torch_available before usage and use importlib.util.find_spec Fixes ruff F821 (undefined name) and F401 (unused import) errors. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * fix: use hashable types in frozenset test and add dict test Frozensets can only contain hashable types, so use nested frozensets instead of dicts. Added separate test for dict handling via serialize_cache_key. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * refactor: expose CacheProvider API via comfy_api.latest.Caching - Add Caching class to comfy_api/latest/__init__.py that re-exports from comfy_execution.cache_provider (source of truth) - Fix docstring: "Skip large values" instead of "Skip small values" (small compute-heavy values are good cache targets) - Maintain backward compatibility: comfy_execution.cache_provider imports still work Usage: from comfy_api.latest import Caching class MyProvider(Caching.CacheProvider): def on_lookup(self, context): ... def on_store(self, context, value): ... Caching.register_provider(MyProvider()) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * docs: clarify should_cache filtering criteria Change docstring from "Skip large values" to "Skip if download time > compute time" which better captures the cost/benefit tradeoff for external caching. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * docs: make should_cache docstring implementation-agnostic Remove prescriptive filtering suggestions - let implementations decide their own caching logic based on their use case. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * feat: add optional ui field to CacheValue - Add ui field to CacheValue dataclass (default None) - Pass ui when creating CacheValue for external providers - Use result.ui (or default {}) when returning from external cache lookup This allows external cache implementations to store/retrieve UI data if desired, while remaining optional for implementations that skip it. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * refactor: rename _is_cacheable_value to _is_external_cacheable_value Clearer name since objects are also cached locally - this specifically checks for external caching eligibility. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude * refactor: async CacheProvider API + reduce public surface - Make on_lookup/on_store async on CacheProvider ABC - Simplify CacheContext: replace cache_key + cache_key_bytes with cache_key_hash (str hex digest) - Make registry/utility functions internal (_prefix) - Trim comfy_api.latest.Caching exports to core API only - Make cache get/set async throughout caching.py hierarchy - Use asyncio.create_task for fire-and-forget on_store - Add NaN gating before provider calls in Core - Add await to 5 cache call sites in execution.py Co-Authored-By: Claude Opus 4.6 * fix: remove unused imports (ruff) and update tests for internal API - Remove unused CacheContext and _serialize_cache_key imports from caching.py (now handled by _build_context helper) - Update test_cache_provider.py to use _-prefixed internal names - Update tests for new CacheContext.cache_key_hash field (str) - Make MockCacheProvider methods async to match ABC Co-Authored-By: Claude Opus 4.6 * fix: address coderabbit review feedback - Add try/except to _build_context, return None when hash fails - Return None from _serialize_cache_key on total failure (no id()-based fallback) - Replace hex-like test literal with non-secret placeholder Co-Authored-By: Claude Opus 4.6 * fix: use _-prefixed imports in _notify_prompt_lifecycle The lifecycle notification method was importing the old non-prefixed names (has_cache_providers, get_cache_providers, logger) which no longer exist after the API cleanup. Co-Authored-By: Claude Opus 4.6 * fix: add sync get_local/set_local for graph traversal ExecutionList in graph.py calls output_cache.get() and .set() from sync methods (is_cached, cache_link, get_cache). These cannot await the now-async get/set. Add get_local/set_local that bypass external providers and only access the local dict — which is all graph traversal needs. Co-Authored-By: Claude Opus 4.6 * chore: remove cloud-specific language from cache provider API Make all docstrings and comments generic for the OSS codebase. Remove references to Kubernetes, Redis, GCS, pods, and other infrastructure-specific terminology. Co-Authored-By: Claude Opus 4.6 * style: align documentation with codebase conventions Strip verbose docstrings and section banners to match existing minimal documentation style used throughout the codebase. Co-Authored-By: Claude Opus 4.6 * fix: add usage example to Caching class, remove pickle fallback - Add docstring with usage example to Caching class matching the convention used by sibling APIs (Execution.set_progress, ComfyExtension) - Remove non-deterministic pickle fallback from _serialize_cache_key; return None on JSON failure instead of producing unretrievable hashes - Move cache_provider imports to top of execution.py (no circular dep) Co-Authored-By: Claude Opus 4.6 * refactor: move public types to comfy_api, eager provider snapshot Address review feedback: - Move CacheProvider/CacheContext/CacheValue definitions to comfy_api/latest/_caching.py (source of truth for public API) - comfy_execution/cache_provider.py re-exports types from there - Build _providers_snapshot eagerly on register/unregister instead of lazy memoization in _get_cache_providers Co-Authored-By: Claude Opus 4.6 * fix: generalize self-inequality check, fail-closed canonicalization Address review feedback from guill: - Rename _contains_nan to _contains_self_unequal, use not (x == x) instead of math.isnan to catch any self-unequal value - Remove Unhashable and repr() fallbacks from _canonicalize; raise ValueError for unknown types so _serialize_cache_key returns None and external caching is skipped (fail-closed) - Update tests for renamed function and new fail-closed behavior Co-Authored-By: Claude Opus 4.6 * fix: suppress ruff F401 for re-exported CacheContext CacheContext is imported from _caching and re-exported for use by caching.py. Add noqa comment to satisfy the linter. Co-Authored-By: Claude Opus 4.6 * fix: enable external caching for subcache (expanded) nodes Subcache nodes (from node expansion) now participate in external provider store/lookup. Previously skipped to avoid duplicates, but the cost of missing partial-expansion cache hits outweighs redundant stores — especially with looping behavior on the horizon. Co-Authored-By: Claude Opus 4.6 * fix: wrap register/unregister as explicit static methods Define register_provider and unregister_provider as wrapper functions in the Caching class instead of re-importing. This locks the public API signature in comfy_api/ so internal changes can't accidentally break it. Co-Authored-By: Claude Opus 4.6 * fix: use debug-level logging for provider registration Co-Authored-By: Claude Opus 4.6 * fix: follow ProxiedSingleton pattern for Caching class Add Caching as a nested class inside ComfyAPI_latest inheriting from ProxiedSingleton with async instance methods, matching the Execution and NodeReplacement patterns. Retains standalone Caching class for direct import convenience. Co-Authored-By: Claude Opus 4.6 * fix: inline registration logic in Caching class Follow the Execution/NodeReplacement pattern — the public API methods contain the actual logic operating on cache_provider module state, not wrapper functions delegating to free functions. Co-Authored-By: Claude Opus 4.6 * fix: single Caching definition inside ComfyAPI_latest Remove duplicate standalone Caching class. Define it once as a nested class in ComfyAPI_latest (matching Execution/NodeReplacement pattern), with a module-level alias for import convenience. Co-Authored-By: Claude Opus 4.6 * fix: remove prompt_id from CacheContext, type-safe canonicalization Remove prompt_id from CacheContext — it's not relevant for cache matching and added unnecessary plumbing (_current_prompt_id on every cache). Lifecycle hooks still receive prompt_id directly. Include type name in canonicalized primitives so that int 7 and str "7" produce distinct hashes. Also canonicalize dict keys properly instead of str() coercion. Co-Authored-By: Claude Opus 4.6 * fix: address review feedback on cache provider API - Hold references to pending store tasks to prevent "Task was destroyed but it is still pending" warnings (bigcat88) - Parallel cache lookups with asyncio.gather instead of sequential awaits for better performance (bigcat88) - Delegate Caching.register/unregister_provider to existing functions in cache_provider.py instead of reimplementing (bigcat88) Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude --- comfy_api/latest/__init__.py | 35 ++ comfy_api/latest/_caching.py | 42 ++ comfy_execution/cache_provider.py | 138 ++++++ comfy_execution/caching.py | 177 +++++++- comfy_execution/graph.py | 6 +- execution.py | 141 +++--- .../execution_test/test_cache_provider.py | 403 ++++++++++++++++++ 7 files changed, 859 insertions(+), 83 deletions(-) create mode 100644 comfy_api/latest/_caching.py create mode 100644 comfy_execution/cache_provider.py create mode 100644 tests-unit/execution_test/test_cache_provider.py diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index f2399422b..04973fea0 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -25,6 +25,7 @@ class ComfyAPI_latest(ComfyAPIBase): super().__init__() self.node_replacement = self.NodeReplacement() self.execution = self.Execution() + self.caching = self.Caching() class NodeReplacement(ProxiedSingleton): async def register(self, node_replace: io.NodeReplace) -> None: @@ -84,6 +85,36 @@ class ComfyAPI_latest(ComfyAPIBase): image=to_display, ) + class Caching(ProxiedSingleton): + """ + External cache provider API for sharing cached node outputs + across ComfyUI instances. + + Example:: + + from comfy_api.latest import Caching + + class MyCacheProvider(Caching.CacheProvider): + async def on_lookup(self, context): + ... # check external storage + + async def on_store(self, context, value): + ... # store to external storage + + Caching.register_provider(MyCacheProvider()) + """ + from ._caching import CacheProvider, CacheContext, CacheValue + + async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None: + """Register an external cache provider. Providers are called in registration order.""" + from comfy_execution.cache_provider import register_cache_provider + register_cache_provider(provider) + + async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None: + """Unregister a previously registered cache provider.""" + from comfy_execution.cache_provider import unregister_cache_provider + unregister_cache_provider(provider) + class ComfyExtension(ABC): async def on_load(self) -> None: """ @@ -116,6 +147,9 @@ class Types: VOXEL = VOXEL File3D = File3D + +Caching = ComfyAPI_latest.Caching + ComfyAPI = ComfyAPI_latest # Create a synchronous version of the API @@ -135,6 +169,7 @@ __all__ = [ "Input", "InputImpl", "Types", + "Caching", "ComfyExtension", "io", "IO", diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py new file mode 100644 index 000000000..30c8848cd --- /dev/null +++ b/comfy_api/latest/_caching.py @@ -0,0 +1,42 @@ +from abc import ABC, abstractmethod +from typing import Optional +from dataclasses import dataclass + + +@dataclass +class CacheContext: + node_id: str + class_type: str + cache_key_hash: str # SHA256 hex digest + + +@dataclass +class CacheValue: + outputs: list + ui: dict = None + + +class CacheProvider(ABC): + """Abstract base class for external cache providers. + Exceptions from provider methods are caught by the caller and never break execution. + """ + + @abstractmethod + async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + """Called on local cache miss. Return CacheValue if found, None otherwise.""" + pass + + @abstractmethod + async def on_store(self, context: CacheContext, value: CacheValue) -> None: + """Called after local store. Dispatched via asyncio.create_task.""" + pass + + def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool: + """Return False to skip external caching for this node. Default: True.""" + return True + + def on_prompt_start(self, prompt_id: str) -> None: + pass + + def on_prompt_end(self, prompt_id: str) -> None: + pass diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py new file mode 100644 index 000000000..d455d08e8 --- /dev/null +++ b/comfy_execution/cache_provider.py @@ -0,0 +1,138 @@ +from typing import Any, Optional, Tuple, List +import hashlib +import json +import logging +import threading + +# Public types — source of truth is comfy_api.latest._caching +from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported) + +_logger = logging.getLogger(__name__) + + +_providers: List[CacheProvider] = [] +_providers_lock = threading.Lock() +_providers_snapshot: Tuple[CacheProvider, ...] = () + + +def register_cache_provider(provider: CacheProvider) -> None: + """Register an external cache provider. Providers are called in registration order.""" + global _providers_snapshot + with _providers_lock: + if provider in _providers: + _logger.warning(f"Provider {provider.__class__.__name__} already registered") + return + _providers.append(provider) + _providers_snapshot = tuple(_providers) + _logger.debug(f"Registered cache provider: {provider.__class__.__name__}") + + +def unregister_cache_provider(provider: CacheProvider) -> None: + global _providers_snapshot + with _providers_lock: + try: + _providers.remove(provider) + _providers_snapshot = tuple(_providers) + _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}") + except ValueError: + _logger.warning(f"Provider {provider.__class__.__name__} was not registered") + + +def _get_cache_providers() -> Tuple[CacheProvider, ...]: + return _providers_snapshot + + +def _has_cache_providers() -> bool: + return bool(_providers_snapshot) + + +def _clear_cache_providers() -> None: + global _providers_snapshot + with _providers_lock: + _providers.clear() + _providers_snapshot = () + + +def _canonicalize(obj: Any) -> Any: + # Convert to canonical JSON-serializable form with deterministic ordering. + # Frozensets have non-deterministic iteration order between Python sessions. + # Raises ValueError for non-cacheable types (Unhashable, unknown) so that + # _serialize_cache_key returns None and external caching is skipped. + if isinstance(obj, frozenset): + return ("__frozenset__", sorted( + [_canonicalize(item) for item in obj], + key=lambda x: json.dumps(x, sort_keys=True) + )) + elif isinstance(obj, set): + return ("__set__", sorted( + [_canonicalize(item) for item in obj], + key=lambda x: json.dumps(x, sort_keys=True) + )) + elif isinstance(obj, tuple): + return ("__tuple__", [_canonicalize(item) for item in obj]) + elif isinstance(obj, list): + return [_canonicalize(item) for item in obj] + elif isinstance(obj, dict): + return {"__dict__": sorted( + [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()], + key=lambda x: json.dumps(x, sort_keys=True) + )} + elif isinstance(obj, (int, float, str, bool, type(None))): + return (type(obj).__name__, obj) + elif isinstance(obj, bytes): + return ("__bytes__", obj.hex()) + else: + raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}") + + +def _serialize_cache_key(cache_key: Any) -> Optional[str]: + # Returns deterministic SHA256 hex digest, or None on failure. + # Uses JSON (not pickle) because pickle is non-deterministic across sessions. + try: + canonical = _canonicalize(cache_key) + json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':')) + return hashlib.sha256(json_str.encode('utf-8')).hexdigest() + except Exception as e: + _logger.warning(f"Failed to serialize cache key: {e}") + return None + + +def _contains_self_unequal(obj: Any) -> bool: + # Local cache matches by ==. Values where not (x == x) (NaN, etc.) will + # never hit locally, but serialized form would match externally. Skip these. + try: + if not (obj == obj): + return True + except Exception: + return True + if isinstance(obj, (frozenset, tuple, list, set)): + return any(_contains_self_unequal(item) for item in obj) + if isinstance(obj, dict): + return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items()) + if hasattr(obj, 'value'): + return _contains_self_unequal(obj.value) + return False + + +def _estimate_value_size(value: CacheValue) -> int: + try: + import torch + except ImportError: + return 0 + + total = 0 + + def estimate(obj): + nonlocal total + if isinstance(obj, torch.Tensor): + total += obj.numel() * obj.element_size() + elif isinstance(obj, dict): + for v in obj.values(): + estimate(v) + elif isinstance(obj, (list, tuple)): + for item in obj: + estimate(item) + + for output in value.outputs: + estimate(output) + return total diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 326a279fc..750bddf2e 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,3 +1,4 @@ +import asyncio import bisect import gc import itertools @@ -154,6 +155,7 @@ class BasicCache: self.cache_key_set: CacheKeySet self.cache = {} self.subcaches = {} + self._pending_store_tasks: set = set() async def set_prompt(self, dynprompt, node_ids, is_changed_cache): self.dynprompt = dynprompt @@ -196,18 +198,134 @@ class BasicCache: def poll(self, **kwargs): pass - def _set_immediate(self, node_id, value): - assert self.initialized - cache_key = self.cache_key_set.get_data_key(node_id) - self.cache[cache_key] = value - - def _get_immediate(self, node_id): + def get_local(self, node_id): if not self.initialized: return None cache_key = self.cache_key_set.get_data_key(node_id) if cache_key in self.cache: return self.cache[cache_key] - else: + return None + + def set_local(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + async def _set_immediate(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + await self._notify_providers_store(node_id, cache_key, value) + + async def _get_immediate(self, node_id): + if not self.initialized: + return None + cache_key = self.cache_key_set.get_data_key(node_id) + + if cache_key in self.cache: + return self.cache[cache_key] + + external_result = await self._check_providers_lookup(node_id, cache_key) + if external_result is not None: + self.cache[cache_key] = external_result + return external_result + + return None + + async def _notify_providers_store(self, node_id, cache_key, value): + from comfy_execution.cache_provider import ( + _has_cache_providers, _get_cache_providers, + CacheValue, _contains_self_unequal, _logger + ) + + if not _has_cache_providers(): + return + if not self._is_external_cacheable_value(value): + return + if _contains_self_unequal(cache_key): + return + + context = self._build_context(node_id, cache_key) + if context is None: + return + cache_value = CacheValue(outputs=value.outputs, ui=value.ui) + + for provider in _get_cache_providers(): + try: + if provider.should_cache(context, cache_value): + task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value)) + self._pending_store_tasks.add(task) + task.add_done_callback(self._pending_store_tasks.discard) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}") + + @staticmethod + async def _safe_provider_store(provider, context, cache_value): + from comfy_execution.cache_provider import _logger + try: + await provider.on_store(context, cache_value) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}") + + async def _check_providers_lookup(self, node_id, cache_key): + from comfy_execution.cache_provider import ( + _has_cache_providers, _get_cache_providers, + CacheValue, _contains_self_unequal, _logger + ) + + if not _has_cache_providers(): + return None + if _contains_self_unequal(cache_key): + return None + + context = self._build_context(node_id, cache_key) + if context is None: + return None + + for provider in _get_cache_providers(): + try: + if not provider.should_cache(context): + continue + result = await provider.on_lookup(context) + if result is not None: + if not isinstance(result, CacheValue): + _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type") + continue + if not isinstance(result.outputs, (list, tuple)): + _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs") + continue + from execution import CacheEntry + return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs)) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}") + + return None + + def _is_external_cacheable_value(self, value): + return hasattr(value, 'outputs') and hasattr(value, 'ui') + + def _get_class_type(self, node_id): + if not self.initialized or not self.dynprompt: + return '' + try: + return self.dynprompt.get_node(node_id).get('class_type', '') + except Exception: + return '' + + def _build_context(self, node_id, cache_key): + from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger + try: + cache_key_hash = _serialize_cache_key(cache_key) + if cache_key_hash is None: + return None + return CacheContext( + node_id=node_id, + class_type=self._get_class_type(node_id), + cache_key_hash=cache_key_hash, + ) + except Exception as e: + _logger.warning(f"Failed to build cache context for node {node_id}: {e}") return None async def _ensure_subcache(self, node_id, children_ids): @@ -257,16 +375,27 @@ class HierarchicalCache(BasicCache): return None return cache - def get(self, node_id): + async def get(self, node_id): cache = self._get_cache_for(node_id) if cache is None: return None - return cache._get_immediate(node_id) + return await cache._get_immediate(node_id) - def set(self, node_id, value): + def get_local(self, node_id): + cache = self._get_cache_for(node_id) + if cache is None: + return None + return BasicCache.get_local(cache, node_id) + + async def set(self, node_id, value): cache = self._get_cache_for(node_id) assert cache is not None - cache._set_immediate(node_id, value) + await cache._set_immediate(node_id, value) + + def set_local(self, node_id, value): + cache = self._get_cache_for(node_id) + assert cache is not None + BasicCache.set_local(cache, node_id, value) async def ensure_subcache_for(self, node_id, children_ids): cache = self._get_cache_for(node_id) @@ -287,10 +416,16 @@ class NullCache: def poll(self, **kwargs): pass - def get(self, node_id): + async def get(self, node_id): return None - def set(self, node_id, value): + def get_local(self, node_id): + return None + + async def set(self, node_id, value): + pass + + def set_local(self, node_id, value): pass async def ensure_subcache_for(self, node_id, children_ids): @@ -322,18 +457,18 @@ class LRUCache(BasicCache): del self.children[key] self._clean_subcaches() - def get(self, node_id): + async def get(self, node_id): self._mark_used(node_id) - return self._get_immediate(node_id) + return await self._get_immediate(node_id) def _mark_used(self, node_id): cache_key = self.cache_key_set.get_data_key(node_id) if cache_key is not None: self.used_generation[cache_key] = self.generation - def set(self, node_id, value): + async def set(self, node_id, value): self._mark_used(node_id) - return self._set_immediate(node_id, value) + return await self._set_immediate(node_id, value) async def ensure_subcache_for(self, node_id, children_ids): # Just uses subcaches for tracking 'live' nodes @@ -373,13 +508,13 @@ class RAMPressureCache(LRUCache): def clean_unused(self): self._clean_subcaches() - def set(self, node_id, value): + async def set(self, node_id, value): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() - super().set(node_id, value) + await super().set(node_id, value) - def get(self, node_id): + async def get(self, node_id): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() - return super().get(node_id) + return await super().get(node_id) def poll(self, ram_headroom): def _ram_gb(): diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 9d170b16e..c47f3c79b 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort): self.execution_cache_listeners = {} def is_cached(self, node_id): - return self.output_cache.get(node_id) is not None + return self.output_cache.get_local(node_id) is not None def cache_link(self, from_node_id, to_node_id): if to_node_id not in self.execution_cache: self.execution_cache[to_node_id] = {} - self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) + self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id) if from_node_id not in self.execution_cache_listeners: self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id].add(to_node_id) @@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort): if value is None: return None #Write back to the main cache on touch. - self.output_cache.set(from_node_id, value) + self.output_cache.set_local(from_node_id, value) return value def cache_update(self, node_id, value): diff --git a/execution.py b/execution.py index a7791efed..a8e8fc59f 100644 --- a/execution.py +++ b/execution.py @@ -40,6 +40,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io +from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger class ExecutionResult(Enum): @@ -418,7 +419,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - cached = caches.outputs.get(unique_id) + cached = await caches.outputs.get(unique_id) if cached is not None: if server.client_id is not None: cached_ui = cached.ui or {} @@ -474,10 +475,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) - obj = caches.objects.get(unique_id) + obj = await caches.objects.get(unique_id) if obj is None: obj = class_def() - caches.objects.set(unique_id, obj) + await caches.objects.set(unique_id, obj) if issubclass(class_def, _ComfyNodeInternal): lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None @@ -588,7 +589,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data) execution_list.cache_update(unique_id, cache_entry) - caches.outputs.set(unique_id, cache_entry) + await caches.outputs.set(unique_id, cache_entry) except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") @@ -684,6 +685,19 @@ class PromptExecutor: } self.add_message("execution_error", mes, broadcast=False) + def _notify_prompt_lifecycle(self, event: str, prompt_id: str): + if not _has_cache_providers(): + return + + for provider in _get_cache_providers(): + try: + if event == "start": + provider.on_prompt_start(prompt_id) + elif event == "end": + provider.on_prompt_end(prompt_id) + except Exception as e: + _cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}") + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) @@ -700,66 +714,75 @@ class PromptExecutor: self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) - with torch.inference_mode(): - dynamic_prompt = DynamicPrompt(prompt) - reset_progress_state(prompt_id, dynamic_prompt) - add_progress_handler(WebUIProgressHandler(self.server)) - is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) - for cache in self.caches.all: - await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) - cache.clean_unused() + self._notify_prompt_lifecycle("start", prompt_id) - cached_nodes = [] - for node_id in prompt: - if self.caches.outputs.get(node_id) is not None: - cached_nodes.append(node_id) + try: + with torch.inference_mode(): + dynamic_prompt = DynamicPrompt(prompt) + reset_progress_state(prompt_id, dynamic_prompt) + add_progress_handler(WebUIProgressHandler(self.server)) + is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) + for cache in self.caches.all: + await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + cache.clean_unused() - comfy.model_management.cleanup_models_gc() - self.add_message("execution_cached", - { "nodes": cached_nodes, "prompt_id": prompt_id}, - broadcast=False) - pending_subgraph_results = {} - pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results - ui_node_outputs = {} - executed = set() - execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) - current_outputs = self.caches.outputs.all_node_ids() - for node_id in list(execute_outputs): - execution_list.add_node(node_id) + node_ids = list(prompt.keys()) + cache_results = await asyncio.gather( + *(self.caches.outputs.get(node_id) for node_id in node_ids) + ) + cached_nodes = [ + node_id for node_id, result in zip(node_ids, cache_results) + if result is not None + ] - while not execution_list.is_empty(): - node_id, error, ex = await execution_list.stage_node_execution() - if error is not None: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break + comfy.model_management.cleanup_models_gc() + self.add_message("execution_cached", + { "nodes": cached_nodes, "prompt_id": prompt_id}, + broadcast=False) + pending_subgraph_results = {} + pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + ui_node_outputs = {} + executed = set() + execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) + current_outputs = self.caches.outputs.all_node_ids() + for node_id in list(execute_outputs): + execution_list.add_node(node_id) - assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) - self.success = result != ExecutionResult.FAILURE - if result == ExecutionResult.FAILURE: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break - elif result == ExecutionResult.PENDING: - execution_list.unstage_node_execution() - else: # result == ExecutionResult.SUCCESS: - execution_list.complete_node_execution() - self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) - else: - # Only execute when the while-loop ends without break - self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + while not execution_list.is_empty(): + node_id, error, ex = await execution_list.stage_node_execution() + if error is not None: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break - ui_outputs = {} - meta_outputs = {} - for node_id, ui_info in ui_node_outputs.items(): - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] - self.history_result = { - "outputs": ui_outputs, - "meta": meta_outputs, - } - self.server.last_node_id = None - if comfy.model_management.DISABLE_SMART_MEMORY: - comfy.model_management.unload_all_models() + assert node_id is not None, "Node ID should not be None at this point" + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) + self.success = result != ExecutionResult.FAILURE + if result == ExecutionResult.FAILURE: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + elif result == ExecutionResult.PENDING: + execution_list.unstage_node_execution() + else: # result == ExecutionResult.SUCCESS: + execution_list.complete_node_execution() + self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) + else: + # Only execute when the while-loop ends without break + self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + + ui_outputs = {} + meta_outputs = {} + for node_id, ui_info in ui_node_outputs.items(): + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] + self.history_result = { + "outputs": ui_outputs, + "meta": meta_outputs, + } + self.server.last_node_id = None + if comfy.model_management.DISABLE_SMART_MEMORY: + comfy.model_management.unload_all_models() + finally: + self._notify_prompt_lifecycle("end", prompt_id) async def validate_inputs(prompt_id, prompt, item, validated): diff --git a/tests-unit/execution_test/test_cache_provider.py b/tests-unit/execution_test/test_cache_provider.py new file mode 100644 index 000000000..ac3814746 --- /dev/null +++ b/tests-unit/execution_test/test_cache_provider.py @@ -0,0 +1,403 @@ +"""Tests for external cache provider API.""" + +import importlib.util +import pytest +from typing import Optional + + +def _torch_available() -> bool: + """Check if PyTorch is available.""" + return importlib.util.find_spec("torch") is not None + + +from comfy_execution.cache_provider import ( + CacheProvider, + CacheContext, + CacheValue, + register_cache_provider, + unregister_cache_provider, + _get_cache_providers, + _has_cache_providers, + _clear_cache_providers, + _serialize_cache_key, + _contains_self_unequal, + _estimate_value_size, + _canonicalize, +) + + +class TestCanonicalize: + """Test _canonicalize function for deterministic ordering.""" + + def test_frozenset_ordering_is_deterministic(self): + """Frozensets should produce consistent canonical form regardless of iteration order.""" + # Create two frozensets with same content + fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)]) + fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)]) + + result1 = _canonicalize(fs1) + result2 = _canonicalize(fs2) + + assert result1 == result2 + + def test_nested_frozenset_ordering(self): + """Nested frozensets should also be deterministically ordered.""" + inner1 = frozenset([1, 2, 3]) + inner2 = frozenset([3, 2, 1]) + + fs1 = frozenset([("key", inner1)]) + fs2 = frozenset([("key", inner2)]) + + result1 = _canonicalize(fs1) + result2 = _canonicalize(fs2) + + assert result1 == result2 + + def test_dict_ordering(self): + """Dicts should be sorted by key.""" + d1 = {"z": 1, "a": 2, "m": 3} + d2 = {"a": 2, "m": 3, "z": 1} + + result1 = _canonicalize(d1) + result2 = _canonicalize(d2) + + assert result1 == result2 + + def test_tuple_preserved(self): + """Tuples should be marked and preserved.""" + t = (1, 2, 3) + result = _canonicalize(t) + + assert result[0] == "__tuple__" + + def test_list_preserved(self): + """Lists should be recursively canonicalized.""" + lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])] + result = _canonicalize(lst) + + # First element should be canonicalized dict + assert "__dict__" in result[0] + # Second element should be canonicalized frozenset + assert result[1][0] == "__frozenset__" + + def test_primitives_include_type(self): + """Primitive types should include type name for disambiguation.""" + assert _canonicalize(42) == ("int", 42) + assert _canonicalize(3.14) == ("float", 3.14) + assert _canonicalize("hello") == ("str", "hello") + assert _canonicalize(True) == ("bool", True) + assert _canonicalize(None) == ("NoneType", None) + + def test_int_and_str_distinguished(self): + """int 7 and str '7' must produce different canonical forms.""" + assert _canonicalize(7) != _canonicalize("7") + + def test_bytes_converted(self): + """Bytes should be converted to hex string.""" + b = b"\x00\xff" + result = _canonicalize(b) + + assert result[0] == "__bytes__" + assert result[1] == "00ff" + + def test_set_ordering(self): + """Sets should be sorted like frozensets.""" + s1 = {3, 1, 2} + s2 = {1, 2, 3} + + result1 = _canonicalize(s1) + result2 = _canonicalize(s2) + + assert result1 == result2 + assert result1[0] == "__set__" + + def test_unknown_type_raises(self): + """Unknown types should raise ValueError (fail-closed).""" + class CustomObj: + pass + with pytest.raises(ValueError): + _canonicalize(CustomObj()) + + def test_object_with_value_attr_raises(self): + """Objects with .value attribute (Unhashable-like) should raise ValueError.""" + class FakeUnhashable: + def __init__(self): + self.value = float('nan') + with pytest.raises(ValueError): + _canonicalize(FakeUnhashable()) + + +class TestSerializeCacheKey: + """Test _serialize_cache_key for deterministic hashing.""" + + def test_same_content_same_hash(self): + """Same content should produce same hash.""" + key1 = frozenset([("node_1", frozenset([("input", "value")]))]) + key2 = frozenset([("node_1", frozenset([("input", "value")]))]) + + hash1 = _serialize_cache_key(key1) + hash2 = _serialize_cache_key(key2) + + assert hash1 == hash2 + + def test_different_content_different_hash(self): + """Different content should produce different hash.""" + key1 = frozenset([("node_1", "value_a")]) + key2 = frozenset([("node_1", "value_b")]) + + hash1 = _serialize_cache_key(key1) + hash2 = _serialize_cache_key(key2) + + assert hash1 != hash2 + + def test_returns_hex_string(self): + """Should return hex string (SHA256 hex digest).""" + key = frozenset([("test", 123)]) + result = _serialize_cache_key(key) + + assert isinstance(result, str) + assert len(result) == 64 # SHA256 hex digest is 64 chars + + def test_complex_nested_structure(self): + """Complex nested structures should hash deterministically.""" + # Note: frozensets can only contain hashable types, so we use + # nested frozensets of tuples to represent dict-like structures + key = frozenset([ + ("node_1", frozenset([ + ("input_a", ("tuple", "value")), + ("input_b", frozenset([("nested", "dict")])), + ])), + ("node_2", frozenset([ + ("param", 42), + ])), + ]) + + # Hash twice to verify determinism + hash1 = _serialize_cache_key(key) + hash2 = _serialize_cache_key(key) + + assert hash1 == hash2 + + def test_dict_in_cache_key(self): + """Dicts passed directly to _serialize_cache_key should work.""" + key = {"node_1": {"input": "value"}, "node_2": 42} + + hash1 = _serialize_cache_key(key) + hash2 = _serialize_cache_key(key) + + assert hash1 == hash2 + assert isinstance(hash1, str) + assert len(hash1) == 64 + + def test_unknown_type_returns_none(self): + """Non-cacheable types should return None (fail-closed).""" + class CustomObj: + pass + assert _serialize_cache_key(CustomObj()) is None + + +class TestContainsSelfUnequal: + """Test _contains_self_unequal utility function.""" + + def test_nan_float_detected(self): + """NaN floats should be detected (not equal to itself).""" + assert _contains_self_unequal(float('nan')) is True + + def test_regular_float_not_detected(self): + """Regular floats are equal to themselves.""" + assert _contains_self_unequal(3.14) is False + assert _contains_self_unequal(0.0) is False + assert _contains_self_unequal(-1.5) is False + + def test_infinity_not_detected(self): + """Infinity is equal to itself.""" + assert _contains_self_unequal(float('inf')) is False + assert _contains_self_unequal(float('-inf')) is False + + def test_nan_in_list(self): + """NaN in list should be detected.""" + assert _contains_self_unequal([1, 2, float('nan'), 4]) is True + assert _contains_self_unequal([1, 2, 3, 4]) is False + + def test_nan_in_tuple(self): + """NaN in tuple should be detected.""" + assert _contains_self_unequal((1, float('nan'))) is True + assert _contains_self_unequal((1, 2, 3)) is False + + def test_nan_in_frozenset(self): + """NaN in frozenset should be detected.""" + assert _contains_self_unequal(frozenset([1, float('nan')])) is True + assert _contains_self_unequal(frozenset([1, 2, 3])) is False + + def test_nan_in_dict_value(self): + """NaN in dict value should be detected.""" + assert _contains_self_unequal({"key": float('nan')}) is True + assert _contains_self_unequal({"key": 42}) is False + + def test_nan_in_nested_structure(self): + """NaN in deeply nested structure should be detected.""" + nested = {"level1": [{"level2": (1, 2, float('nan'))}]} + assert _contains_self_unequal(nested) is True + + def test_non_numeric_types(self): + """Non-numeric types should not be self-unequal.""" + assert _contains_self_unequal("string") is False + assert _contains_self_unequal(None) is False + assert _contains_self_unequal(True) is False + + def test_object_with_nan_value_attribute(self): + """Objects wrapping NaN in .value should be detected.""" + class NanWrapper: + def __init__(self): + self.value = float('nan') + assert _contains_self_unequal(NanWrapper()) is True + + def test_custom_self_unequal_object(self): + """Custom objects where not (x == x) should be detected.""" + class NeverEqual: + def __eq__(self, other): + return False + assert _contains_self_unequal(NeverEqual()) is True + + +class TestEstimateValueSize: + """Test _estimate_value_size utility function.""" + + def test_empty_outputs(self): + """Empty outputs should have zero size.""" + value = CacheValue(outputs=[]) + assert _estimate_value_size(value) == 0 + + @pytest.mark.skipif( + not _torch_available(), + reason="PyTorch not available" + ) + def test_tensor_size_estimation(self): + """Tensor size should be estimated correctly.""" + import torch + + # 1000 float32 elements = 4000 bytes + tensor = torch.zeros(1000, dtype=torch.float32) + value = CacheValue(outputs=[[tensor]]) + + size = _estimate_value_size(value) + assert size == 4000 + + @pytest.mark.skipif( + not _torch_available(), + reason="PyTorch not available" + ) + def test_nested_tensor_in_dict(self): + """Tensors nested in dicts should be counted.""" + import torch + + tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes + value = CacheValue(outputs=[[{"samples": tensor}]]) + + size = _estimate_value_size(value) + assert size == 400 + + +class TestProviderRegistry: + """Test cache provider registration and retrieval.""" + + def setup_method(self): + """Clear providers before each test.""" + _clear_cache_providers() + + def teardown_method(self): + """Clear providers after each test.""" + _clear_cache_providers() + + def test_register_provider(self): + """Provider should be registered successfully.""" + provider = MockCacheProvider() + register_cache_provider(provider) + + assert _has_cache_providers() is True + providers = _get_cache_providers() + assert len(providers) == 1 + assert providers[0] is provider + + def test_unregister_provider(self): + """Provider should be unregistered successfully.""" + provider = MockCacheProvider() + register_cache_provider(provider) + unregister_cache_provider(provider) + + assert _has_cache_providers() is False + + def test_multiple_providers(self): + """Multiple providers can be registered.""" + provider1 = MockCacheProvider() + provider2 = MockCacheProvider() + + register_cache_provider(provider1) + register_cache_provider(provider2) + + providers = _get_cache_providers() + assert len(providers) == 2 + + def test_duplicate_registration_ignored(self): + """Registering same provider twice should be ignored.""" + provider = MockCacheProvider() + + register_cache_provider(provider) + register_cache_provider(provider) # Should be ignored + + providers = _get_cache_providers() + assert len(providers) == 1 + + def test_clear_providers(self): + """_clear_cache_providers should remove all providers.""" + provider1 = MockCacheProvider() + provider2 = MockCacheProvider() + + register_cache_provider(provider1) + register_cache_provider(provider2) + _clear_cache_providers() + + assert _has_cache_providers() is False + assert len(_get_cache_providers()) == 0 + + +class TestCacheContext: + """Test CacheContext dataclass.""" + + def test_context_creation(self): + """CacheContext should be created with all fields.""" + context = CacheContext( + node_id="node-456", + class_type="KSampler", + cache_key_hash="a" * 64, + ) + + assert context.node_id == "node-456" + assert context.class_type == "KSampler" + assert context.cache_key_hash == "a" * 64 + + +class TestCacheValue: + """Test CacheValue dataclass.""" + + def test_value_creation(self): + """CacheValue should be created with outputs.""" + outputs = [[{"samples": "tensor_data"}]] + value = CacheValue(outputs=outputs) + + assert value.outputs == outputs + + +class MockCacheProvider(CacheProvider): + """Mock cache provider for testing.""" + + def __init__(self): + self.lookups = [] + self.stores = [] + + async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + self.lookups.append(context) + return None + + async def on_store(self, context: CacheContext, value: CacheValue) -> None: + self.stores.append((context, value)) From d1d53c14be8442fca19aae978e944edad1935d46 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 12 Mar 2026 17:21:23 -0700 Subject: [PATCH 24/27] Revert "feat: Add CacheProvider API for external distributed caching (#12056)" (#12912) This reverts commit af7b4a921d7abab7c852d7b5febb654be6e57eba. --- comfy_api/latest/__init__.py | 35 -- comfy_api/latest/_caching.py | 42 -- comfy_execution/cache_provider.py | 138 ------ comfy_execution/caching.py | 177 +------- comfy_execution/graph.py | 6 +- execution.py | 141 +++--- .../execution_test/test_cache_provider.py | 403 ------------------ 7 files changed, 83 insertions(+), 859 deletions(-) delete mode 100644 comfy_api/latest/_caching.py delete mode 100644 comfy_execution/cache_provider.py delete mode 100644 tests-unit/execution_test/test_cache_provider.py diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 04973fea0..f2399422b 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -25,7 +25,6 @@ class ComfyAPI_latest(ComfyAPIBase): super().__init__() self.node_replacement = self.NodeReplacement() self.execution = self.Execution() - self.caching = self.Caching() class NodeReplacement(ProxiedSingleton): async def register(self, node_replace: io.NodeReplace) -> None: @@ -85,36 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase): image=to_display, ) - class Caching(ProxiedSingleton): - """ - External cache provider API for sharing cached node outputs - across ComfyUI instances. - - Example:: - - from comfy_api.latest import Caching - - class MyCacheProvider(Caching.CacheProvider): - async def on_lookup(self, context): - ... # check external storage - - async def on_store(self, context, value): - ... # store to external storage - - Caching.register_provider(MyCacheProvider()) - """ - from ._caching import CacheProvider, CacheContext, CacheValue - - async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None: - """Register an external cache provider. Providers are called in registration order.""" - from comfy_execution.cache_provider import register_cache_provider - register_cache_provider(provider) - - async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None: - """Unregister a previously registered cache provider.""" - from comfy_execution.cache_provider import unregister_cache_provider - unregister_cache_provider(provider) - class ComfyExtension(ABC): async def on_load(self) -> None: """ @@ -147,9 +116,6 @@ class Types: VOXEL = VOXEL File3D = File3D - -Caching = ComfyAPI_latest.Caching - ComfyAPI = ComfyAPI_latest # Create a synchronous version of the API @@ -169,7 +135,6 @@ __all__ = [ "Input", "InputImpl", "Types", - "Caching", "ComfyExtension", "io", "IO", diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py deleted file mode 100644 index 30c8848cd..000000000 --- a/comfy_api/latest/_caching.py +++ /dev/null @@ -1,42 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional -from dataclasses import dataclass - - -@dataclass -class CacheContext: - node_id: str - class_type: str - cache_key_hash: str # SHA256 hex digest - - -@dataclass -class CacheValue: - outputs: list - ui: dict = None - - -class CacheProvider(ABC): - """Abstract base class for external cache providers. - Exceptions from provider methods are caught by the caller and never break execution. - """ - - @abstractmethod - async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: - """Called on local cache miss. Return CacheValue if found, None otherwise.""" - pass - - @abstractmethod - async def on_store(self, context: CacheContext, value: CacheValue) -> None: - """Called after local store. Dispatched via asyncio.create_task.""" - pass - - def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool: - """Return False to skip external caching for this node. Default: True.""" - return True - - def on_prompt_start(self, prompt_id: str) -> None: - pass - - def on_prompt_end(self, prompt_id: str) -> None: - pass diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py deleted file mode 100644 index d455d08e8..000000000 --- a/comfy_execution/cache_provider.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import Any, Optional, Tuple, List -import hashlib -import json -import logging -import threading - -# Public types — source of truth is comfy_api.latest._caching -from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported) - -_logger = logging.getLogger(__name__) - - -_providers: List[CacheProvider] = [] -_providers_lock = threading.Lock() -_providers_snapshot: Tuple[CacheProvider, ...] = () - - -def register_cache_provider(provider: CacheProvider) -> None: - """Register an external cache provider. Providers are called in registration order.""" - global _providers_snapshot - with _providers_lock: - if provider in _providers: - _logger.warning(f"Provider {provider.__class__.__name__} already registered") - return - _providers.append(provider) - _providers_snapshot = tuple(_providers) - _logger.debug(f"Registered cache provider: {provider.__class__.__name__}") - - -def unregister_cache_provider(provider: CacheProvider) -> None: - global _providers_snapshot - with _providers_lock: - try: - _providers.remove(provider) - _providers_snapshot = tuple(_providers) - _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}") - except ValueError: - _logger.warning(f"Provider {provider.__class__.__name__} was not registered") - - -def _get_cache_providers() -> Tuple[CacheProvider, ...]: - return _providers_snapshot - - -def _has_cache_providers() -> bool: - return bool(_providers_snapshot) - - -def _clear_cache_providers() -> None: - global _providers_snapshot - with _providers_lock: - _providers.clear() - _providers_snapshot = () - - -def _canonicalize(obj: Any) -> Any: - # Convert to canonical JSON-serializable form with deterministic ordering. - # Frozensets have non-deterministic iteration order between Python sessions. - # Raises ValueError for non-cacheable types (Unhashable, unknown) so that - # _serialize_cache_key returns None and external caching is skipped. - if isinstance(obj, frozenset): - return ("__frozenset__", sorted( - [_canonicalize(item) for item in obj], - key=lambda x: json.dumps(x, sort_keys=True) - )) - elif isinstance(obj, set): - return ("__set__", sorted( - [_canonicalize(item) for item in obj], - key=lambda x: json.dumps(x, sort_keys=True) - )) - elif isinstance(obj, tuple): - return ("__tuple__", [_canonicalize(item) for item in obj]) - elif isinstance(obj, list): - return [_canonicalize(item) for item in obj] - elif isinstance(obj, dict): - return {"__dict__": sorted( - [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()], - key=lambda x: json.dumps(x, sort_keys=True) - )} - elif isinstance(obj, (int, float, str, bool, type(None))): - return (type(obj).__name__, obj) - elif isinstance(obj, bytes): - return ("__bytes__", obj.hex()) - else: - raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}") - - -def _serialize_cache_key(cache_key: Any) -> Optional[str]: - # Returns deterministic SHA256 hex digest, or None on failure. - # Uses JSON (not pickle) because pickle is non-deterministic across sessions. - try: - canonical = _canonicalize(cache_key) - json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':')) - return hashlib.sha256(json_str.encode('utf-8')).hexdigest() - except Exception as e: - _logger.warning(f"Failed to serialize cache key: {e}") - return None - - -def _contains_self_unequal(obj: Any) -> bool: - # Local cache matches by ==. Values where not (x == x) (NaN, etc.) will - # never hit locally, but serialized form would match externally. Skip these. - try: - if not (obj == obj): - return True - except Exception: - return True - if isinstance(obj, (frozenset, tuple, list, set)): - return any(_contains_self_unequal(item) for item in obj) - if isinstance(obj, dict): - return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items()) - if hasattr(obj, 'value'): - return _contains_self_unequal(obj.value) - return False - - -def _estimate_value_size(value: CacheValue) -> int: - try: - import torch - except ImportError: - return 0 - - total = 0 - - def estimate(obj): - nonlocal total - if isinstance(obj, torch.Tensor): - total += obj.numel() * obj.element_size() - elif isinstance(obj, dict): - for v in obj.values(): - estimate(v) - elif isinstance(obj, (list, tuple)): - for item in obj: - estimate(item) - - for output in value.outputs: - estimate(output) - return total diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 750bddf2e..326a279fc 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,4 +1,3 @@ -import asyncio import bisect import gc import itertools @@ -155,7 +154,6 @@ class BasicCache: self.cache_key_set: CacheKeySet self.cache = {} self.subcaches = {} - self._pending_store_tasks: set = set() async def set_prompt(self, dynprompt, node_ids, is_changed_cache): self.dynprompt = dynprompt @@ -198,134 +196,18 @@ class BasicCache: def poll(self, **kwargs): pass - def get_local(self, node_id): + def _set_immediate(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + def _get_immediate(self, node_id): if not self.initialized: return None cache_key = self.cache_key_set.get_data_key(node_id) if cache_key in self.cache: return self.cache[cache_key] - return None - - def set_local(self, node_id, value): - assert self.initialized - cache_key = self.cache_key_set.get_data_key(node_id) - self.cache[cache_key] = value - - async def _set_immediate(self, node_id, value): - assert self.initialized - cache_key = self.cache_key_set.get_data_key(node_id) - self.cache[cache_key] = value - - await self._notify_providers_store(node_id, cache_key, value) - - async def _get_immediate(self, node_id): - if not self.initialized: - return None - cache_key = self.cache_key_set.get_data_key(node_id) - - if cache_key in self.cache: - return self.cache[cache_key] - - external_result = await self._check_providers_lookup(node_id, cache_key) - if external_result is not None: - self.cache[cache_key] = external_result - return external_result - - return None - - async def _notify_providers_store(self, node_id, cache_key, value): - from comfy_execution.cache_provider import ( - _has_cache_providers, _get_cache_providers, - CacheValue, _contains_self_unequal, _logger - ) - - if not _has_cache_providers(): - return - if not self._is_external_cacheable_value(value): - return - if _contains_self_unequal(cache_key): - return - - context = self._build_context(node_id, cache_key) - if context is None: - return - cache_value = CacheValue(outputs=value.outputs, ui=value.ui) - - for provider in _get_cache_providers(): - try: - if provider.should_cache(context, cache_value): - task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value)) - self._pending_store_tasks.add(task) - task.add_done_callback(self._pending_store_tasks.discard) - except Exception as e: - _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}") - - @staticmethod - async def _safe_provider_store(provider, context, cache_value): - from comfy_execution.cache_provider import _logger - try: - await provider.on_store(context, cache_value) - except Exception as e: - _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}") - - async def _check_providers_lookup(self, node_id, cache_key): - from comfy_execution.cache_provider import ( - _has_cache_providers, _get_cache_providers, - CacheValue, _contains_self_unequal, _logger - ) - - if not _has_cache_providers(): - return None - if _contains_self_unequal(cache_key): - return None - - context = self._build_context(node_id, cache_key) - if context is None: - return None - - for provider in _get_cache_providers(): - try: - if not provider.should_cache(context): - continue - result = await provider.on_lookup(context) - if result is not None: - if not isinstance(result, CacheValue): - _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type") - continue - if not isinstance(result.outputs, (list, tuple)): - _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs") - continue - from execution import CacheEntry - return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs)) - except Exception as e: - _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}") - - return None - - def _is_external_cacheable_value(self, value): - return hasattr(value, 'outputs') and hasattr(value, 'ui') - - def _get_class_type(self, node_id): - if not self.initialized or not self.dynprompt: - return '' - try: - return self.dynprompt.get_node(node_id).get('class_type', '') - except Exception: - return '' - - def _build_context(self, node_id, cache_key): - from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger - try: - cache_key_hash = _serialize_cache_key(cache_key) - if cache_key_hash is None: - return None - return CacheContext( - node_id=node_id, - class_type=self._get_class_type(node_id), - cache_key_hash=cache_key_hash, - ) - except Exception as e: - _logger.warning(f"Failed to build cache context for node {node_id}: {e}") + else: return None async def _ensure_subcache(self, node_id, children_ids): @@ -375,27 +257,16 @@ class HierarchicalCache(BasicCache): return None return cache - async def get(self, node_id): + def get(self, node_id): cache = self._get_cache_for(node_id) if cache is None: return None - return await cache._get_immediate(node_id) + return cache._get_immediate(node_id) - def get_local(self, node_id): - cache = self._get_cache_for(node_id) - if cache is None: - return None - return BasicCache.get_local(cache, node_id) - - async def set(self, node_id, value): + def set(self, node_id, value): cache = self._get_cache_for(node_id) assert cache is not None - await cache._set_immediate(node_id, value) - - def set_local(self, node_id, value): - cache = self._get_cache_for(node_id) - assert cache is not None - BasicCache.set_local(cache, node_id, value) + cache._set_immediate(node_id, value) async def ensure_subcache_for(self, node_id, children_ids): cache = self._get_cache_for(node_id) @@ -416,16 +287,10 @@ class NullCache: def poll(self, **kwargs): pass - async def get(self, node_id): + def get(self, node_id): return None - def get_local(self, node_id): - return None - - async def set(self, node_id, value): - pass - - def set_local(self, node_id, value): + def set(self, node_id, value): pass async def ensure_subcache_for(self, node_id, children_ids): @@ -457,18 +322,18 @@ class LRUCache(BasicCache): del self.children[key] self._clean_subcaches() - async def get(self, node_id): + def get(self, node_id): self._mark_used(node_id) - return await self._get_immediate(node_id) + return self._get_immediate(node_id) def _mark_used(self, node_id): cache_key = self.cache_key_set.get_data_key(node_id) if cache_key is not None: self.used_generation[cache_key] = self.generation - async def set(self, node_id, value): + def set(self, node_id, value): self._mark_used(node_id) - return await self._set_immediate(node_id, value) + return self._set_immediate(node_id, value) async def ensure_subcache_for(self, node_id, children_ids): # Just uses subcaches for tracking 'live' nodes @@ -508,13 +373,13 @@ class RAMPressureCache(LRUCache): def clean_unused(self): self._clean_subcaches() - async def set(self, node_id, value): + def set(self, node_id, value): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() - await super().set(node_id, value) + super().set(node_id, value) - async def get(self, node_id): + def get(self, node_id): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() - return await super().get(node_id) + return super().get(node_id) def poll(self, ram_headroom): def _ram_gb(): diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index c47f3c79b..9d170b16e 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort): self.execution_cache_listeners = {} def is_cached(self, node_id): - return self.output_cache.get_local(node_id) is not None + return self.output_cache.get(node_id) is not None def cache_link(self, from_node_id, to_node_id): if to_node_id not in self.execution_cache: self.execution_cache[to_node_id] = {} - self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id) + self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) if from_node_id not in self.execution_cache_listeners: self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id].add(to_node_id) @@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort): if value is None: return None #Write back to the main cache on touch. - self.output_cache.set_local(from_node_id, value) + self.output_cache.set(from_node_id, value) return value def cache_update(self, node_id, value): diff --git a/execution.py b/execution.py index a8e8fc59f..a7791efed 100644 --- a/execution.py +++ b/execution.py @@ -40,7 +40,6 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io -from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger class ExecutionResult(Enum): @@ -419,7 +418,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - cached = await caches.outputs.get(unique_id) + cached = caches.outputs.get(unique_id) if cached is not None: if server.client_id is not None: cached_ui = cached.ui or {} @@ -475,10 +474,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) - obj = await caches.objects.get(unique_id) + obj = caches.objects.get(unique_id) if obj is None: obj = class_def() - await caches.objects.set(unique_id, obj) + caches.objects.set(unique_id, obj) if issubclass(class_def, _ComfyNodeInternal): lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None @@ -589,7 +588,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data) execution_list.cache_update(unique_id, cache_entry) - await caches.outputs.set(unique_id, cache_entry) + caches.outputs.set(unique_id, cache_entry) except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") @@ -685,19 +684,6 @@ class PromptExecutor: } self.add_message("execution_error", mes, broadcast=False) - def _notify_prompt_lifecycle(self, event: str, prompt_id: str): - if not _has_cache_providers(): - return - - for provider in _get_cache_providers(): - try: - if event == "start": - provider.on_prompt_start(prompt_id) - elif event == "end": - provider.on_prompt_end(prompt_id) - except Exception as e: - _cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}") - def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) @@ -714,75 +700,66 @@ class PromptExecutor: self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) - self._notify_prompt_lifecycle("start", prompt_id) + with torch.inference_mode(): + dynamic_prompt = DynamicPrompt(prompt) + reset_progress_state(prompt_id, dynamic_prompt) + add_progress_handler(WebUIProgressHandler(self.server)) + is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) + for cache in self.caches.all: + await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + cache.clean_unused() - try: - with torch.inference_mode(): - dynamic_prompt = DynamicPrompt(prompt) - reset_progress_state(prompt_id, dynamic_prompt) - add_progress_handler(WebUIProgressHandler(self.server)) - is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) - for cache in self.caches.all: - await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) - cache.clean_unused() + cached_nodes = [] + for node_id in prompt: + if self.caches.outputs.get(node_id) is not None: + cached_nodes.append(node_id) - node_ids = list(prompt.keys()) - cache_results = await asyncio.gather( - *(self.caches.outputs.get(node_id) for node_id in node_ids) - ) - cached_nodes = [ - node_id for node_id, result in zip(node_ids, cache_results) - if result is not None - ] + comfy.model_management.cleanup_models_gc() + self.add_message("execution_cached", + { "nodes": cached_nodes, "prompt_id": prompt_id}, + broadcast=False) + pending_subgraph_results = {} + pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + ui_node_outputs = {} + executed = set() + execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) + current_outputs = self.caches.outputs.all_node_ids() + for node_id in list(execute_outputs): + execution_list.add_node(node_id) - comfy.model_management.cleanup_models_gc() - self.add_message("execution_cached", - { "nodes": cached_nodes, "prompt_id": prompt_id}, - broadcast=False) - pending_subgraph_results = {} - pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results - ui_node_outputs = {} - executed = set() - execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) - current_outputs = self.caches.outputs.all_node_ids() - for node_id in list(execute_outputs): - execution_list.add_node(node_id) + while not execution_list.is_empty(): + node_id, error, ex = await execution_list.stage_node_execution() + if error is not None: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break - while not execution_list.is_empty(): - node_id, error, ex = await execution_list.stage_node_execution() - if error is not None: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break + assert node_id is not None, "Node ID should not be None at this point" + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) + self.success = result != ExecutionResult.FAILURE + if result == ExecutionResult.FAILURE: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + elif result == ExecutionResult.PENDING: + execution_list.unstage_node_execution() + else: # result == ExecutionResult.SUCCESS: + execution_list.complete_node_execution() + self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) + else: + # Only execute when the while-loop ends without break + self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) - assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) - self.success = result != ExecutionResult.FAILURE - if result == ExecutionResult.FAILURE: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break - elif result == ExecutionResult.PENDING: - execution_list.unstage_node_execution() - else: # result == ExecutionResult.SUCCESS: - execution_list.complete_node_execution() - self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) - else: - # Only execute when the while-loop ends without break - self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) - - ui_outputs = {} - meta_outputs = {} - for node_id, ui_info in ui_node_outputs.items(): - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] - self.history_result = { - "outputs": ui_outputs, - "meta": meta_outputs, - } - self.server.last_node_id = None - if comfy.model_management.DISABLE_SMART_MEMORY: - comfy.model_management.unload_all_models() - finally: - self._notify_prompt_lifecycle("end", prompt_id) + ui_outputs = {} + meta_outputs = {} + for node_id, ui_info in ui_node_outputs.items(): + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] + self.history_result = { + "outputs": ui_outputs, + "meta": meta_outputs, + } + self.server.last_node_id = None + if comfy.model_management.DISABLE_SMART_MEMORY: + comfy.model_management.unload_all_models() async def validate_inputs(prompt_id, prompt, item, validated): diff --git a/tests-unit/execution_test/test_cache_provider.py b/tests-unit/execution_test/test_cache_provider.py deleted file mode 100644 index ac3814746..000000000 --- a/tests-unit/execution_test/test_cache_provider.py +++ /dev/null @@ -1,403 +0,0 @@ -"""Tests for external cache provider API.""" - -import importlib.util -import pytest -from typing import Optional - - -def _torch_available() -> bool: - """Check if PyTorch is available.""" - return importlib.util.find_spec("torch") is not None - - -from comfy_execution.cache_provider import ( - CacheProvider, - CacheContext, - CacheValue, - register_cache_provider, - unregister_cache_provider, - _get_cache_providers, - _has_cache_providers, - _clear_cache_providers, - _serialize_cache_key, - _contains_self_unequal, - _estimate_value_size, - _canonicalize, -) - - -class TestCanonicalize: - """Test _canonicalize function for deterministic ordering.""" - - def test_frozenset_ordering_is_deterministic(self): - """Frozensets should produce consistent canonical form regardless of iteration order.""" - # Create two frozensets with same content - fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)]) - fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)]) - - result1 = _canonicalize(fs1) - result2 = _canonicalize(fs2) - - assert result1 == result2 - - def test_nested_frozenset_ordering(self): - """Nested frozensets should also be deterministically ordered.""" - inner1 = frozenset([1, 2, 3]) - inner2 = frozenset([3, 2, 1]) - - fs1 = frozenset([("key", inner1)]) - fs2 = frozenset([("key", inner2)]) - - result1 = _canonicalize(fs1) - result2 = _canonicalize(fs2) - - assert result1 == result2 - - def test_dict_ordering(self): - """Dicts should be sorted by key.""" - d1 = {"z": 1, "a": 2, "m": 3} - d2 = {"a": 2, "m": 3, "z": 1} - - result1 = _canonicalize(d1) - result2 = _canonicalize(d2) - - assert result1 == result2 - - def test_tuple_preserved(self): - """Tuples should be marked and preserved.""" - t = (1, 2, 3) - result = _canonicalize(t) - - assert result[0] == "__tuple__" - - def test_list_preserved(self): - """Lists should be recursively canonicalized.""" - lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])] - result = _canonicalize(lst) - - # First element should be canonicalized dict - assert "__dict__" in result[0] - # Second element should be canonicalized frozenset - assert result[1][0] == "__frozenset__" - - def test_primitives_include_type(self): - """Primitive types should include type name for disambiguation.""" - assert _canonicalize(42) == ("int", 42) - assert _canonicalize(3.14) == ("float", 3.14) - assert _canonicalize("hello") == ("str", "hello") - assert _canonicalize(True) == ("bool", True) - assert _canonicalize(None) == ("NoneType", None) - - def test_int_and_str_distinguished(self): - """int 7 and str '7' must produce different canonical forms.""" - assert _canonicalize(7) != _canonicalize("7") - - def test_bytes_converted(self): - """Bytes should be converted to hex string.""" - b = b"\x00\xff" - result = _canonicalize(b) - - assert result[0] == "__bytes__" - assert result[1] == "00ff" - - def test_set_ordering(self): - """Sets should be sorted like frozensets.""" - s1 = {3, 1, 2} - s2 = {1, 2, 3} - - result1 = _canonicalize(s1) - result2 = _canonicalize(s2) - - assert result1 == result2 - assert result1[0] == "__set__" - - def test_unknown_type_raises(self): - """Unknown types should raise ValueError (fail-closed).""" - class CustomObj: - pass - with pytest.raises(ValueError): - _canonicalize(CustomObj()) - - def test_object_with_value_attr_raises(self): - """Objects with .value attribute (Unhashable-like) should raise ValueError.""" - class FakeUnhashable: - def __init__(self): - self.value = float('nan') - with pytest.raises(ValueError): - _canonicalize(FakeUnhashable()) - - -class TestSerializeCacheKey: - """Test _serialize_cache_key for deterministic hashing.""" - - def test_same_content_same_hash(self): - """Same content should produce same hash.""" - key1 = frozenset([("node_1", frozenset([("input", "value")]))]) - key2 = frozenset([("node_1", frozenset([("input", "value")]))]) - - hash1 = _serialize_cache_key(key1) - hash2 = _serialize_cache_key(key2) - - assert hash1 == hash2 - - def test_different_content_different_hash(self): - """Different content should produce different hash.""" - key1 = frozenset([("node_1", "value_a")]) - key2 = frozenset([("node_1", "value_b")]) - - hash1 = _serialize_cache_key(key1) - hash2 = _serialize_cache_key(key2) - - assert hash1 != hash2 - - def test_returns_hex_string(self): - """Should return hex string (SHA256 hex digest).""" - key = frozenset([("test", 123)]) - result = _serialize_cache_key(key) - - assert isinstance(result, str) - assert len(result) == 64 # SHA256 hex digest is 64 chars - - def test_complex_nested_structure(self): - """Complex nested structures should hash deterministically.""" - # Note: frozensets can only contain hashable types, so we use - # nested frozensets of tuples to represent dict-like structures - key = frozenset([ - ("node_1", frozenset([ - ("input_a", ("tuple", "value")), - ("input_b", frozenset([("nested", "dict")])), - ])), - ("node_2", frozenset([ - ("param", 42), - ])), - ]) - - # Hash twice to verify determinism - hash1 = _serialize_cache_key(key) - hash2 = _serialize_cache_key(key) - - assert hash1 == hash2 - - def test_dict_in_cache_key(self): - """Dicts passed directly to _serialize_cache_key should work.""" - key = {"node_1": {"input": "value"}, "node_2": 42} - - hash1 = _serialize_cache_key(key) - hash2 = _serialize_cache_key(key) - - assert hash1 == hash2 - assert isinstance(hash1, str) - assert len(hash1) == 64 - - def test_unknown_type_returns_none(self): - """Non-cacheable types should return None (fail-closed).""" - class CustomObj: - pass - assert _serialize_cache_key(CustomObj()) is None - - -class TestContainsSelfUnequal: - """Test _contains_self_unequal utility function.""" - - def test_nan_float_detected(self): - """NaN floats should be detected (not equal to itself).""" - assert _contains_self_unequal(float('nan')) is True - - def test_regular_float_not_detected(self): - """Regular floats are equal to themselves.""" - assert _contains_self_unequal(3.14) is False - assert _contains_self_unequal(0.0) is False - assert _contains_self_unequal(-1.5) is False - - def test_infinity_not_detected(self): - """Infinity is equal to itself.""" - assert _contains_self_unequal(float('inf')) is False - assert _contains_self_unequal(float('-inf')) is False - - def test_nan_in_list(self): - """NaN in list should be detected.""" - assert _contains_self_unequal([1, 2, float('nan'), 4]) is True - assert _contains_self_unequal([1, 2, 3, 4]) is False - - def test_nan_in_tuple(self): - """NaN in tuple should be detected.""" - assert _contains_self_unequal((1, float('nan'))) is True - assert _contains_self_unequal((1, 2, 3)) is False - - def test_nan_in_frozenset(self): - """NaN in frozenset should be detected.""" - assert _contains_self_unequal(frozenset([1, float('nan')])) is True - assert _contains_self_unequal(frozenset([1, 2, 3])) is False - - def test_nan_in_dict_value(self): - """NaN in dict value should be detected.""" - assert _contains_self_unequal({"key": float('nan')}) is True - assert _contains_self_unequal({"key": 42}) is False - - def test_nan_in_nested_structure(self): - """NaN in deeply nested structure should be detected.""" - nested = {"level1": [{"level2": (1, 2, float('nan'))}]} - assert _contains_self_unequal(nested) is True - - def test_non_numeric_types(self): - """Non-numeric types should not be self-unequal.""" - assert _contains_self_unequal("string") is False - assert _contains_self_unequal(None) is False - assert _contains_self_unequal(True) is False - - def test_object_with_nan_value_attribute(self): - """Objects wrapping NaN in .value should be detected.""" - class NanWrapper: - def __init__(self): - self.value = float('nan') - assert _contains_self_unequal(NanWrapper()) is True - - def test_custom_self_unequal_object(self): - """Custom objects where not (x == x) should be detected.""" - class NeverEqual: - def __eq__(self, other): - return False - assert _contains_self_unequal(NeverEqual()) is True - - -class TestEstimateValueSize: - """Test _estimate_value_size utility function.""" - - def test_empty_outputs(self): - """Empty outputs should have zero size.""" - value = CacheValue(outputs=[]) - assert _estimate_value_size(value) == 0 - - @pytest.mark.skipif( - not _torch_available(), - reason="PyTorch not available" - ) - def test_tensor_size_estimation(self): - """Tensor size should be estimated correctly.""" - import torch - - # 1000 float32 elements = 4000 bytes - tensor = torch.zeros(1000, dtype=torch.float32) - value = CacheValue(outputs=[[tensor]]) - - size = _estimate_value_size(value) - assert size == 4000 - - @pytest.mark.skipif( - not _torch_available(), - reason="PyTorch not available" - ) - def test_nested_tensor_in_dict(self): - """Tensors nested in dicts should be counted.""" - import torch - - tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes - value = CacheValue(outputs=[[{"samples": tensor}]]) - - size = _estimate_value_size(value) - assert size == 400 - - -class TestProviderRegistry: - """Test cache provider registration and retrieval.""" - - def setup_method(self): - """Clear providers before each test.""" - _clear_cache_providers() - - def teardown_method(self): - """Clear providers after each test.""" - _clear_cache_providers() - - def test_register_provider(self): - """Provider should be registered successfully.""" - provider = MockCacheProvider() - register_cache_provider(provider) - - assert _has_cache_providers() is True - providers = _get_cache_providers() - assert len(providers) == 1 - assert providers[0] is provider - - def test_unregister_provider(self): - """Provider should be unregistered successfully.""" - provider = MockCacheProvider() - register_cache_provider(provider) - unregister_cache_provider(provider) - - assert _has_cache_providers() is False - - def test_multiple_providers(self): - """Multiple providers can be registered.""" - provider1 = MockCacheProvider() - provider2 = MockCacheProvider() - - register_cache_provider(provider1) - register_cache_provider(provider2) - - providers = _get_cache_providers() - assert len(providers) == 2 - - def test_duplicate_registration_ignored(self): - """Registering same provider twice should be ignored.""" - provider = MockCacheProvider() - - register_cache_provider(provider) - register_cache_provider(provider) # Should be ignored - - providers = _get_cache_providers() - assert len(providers) == 1 - - def test_clear_providers(self): - """_clear_cache_providers should remove all providers.""" - provider1 = MockCacheProvider() - provider2 = MockCacheProvider() - - register_cache_provider(provider1) - register_cache_provider(provider2) - _clear_cache_providers() - - assert _has_cache_providers() is False - assert len(_get_cache_providers()) == 0 - - -class TestCacheContext: - """Test CacheContext dataclass.""" - - def test_context_creation(self): - """CacheContext should be created with all fields.""" - context = CacheContext( - node_id="node-456", - class_type="KSampler", - cache_key_hash="a" * 64, - ) - - assert context.node_id == "node-456" - assert context.class_type == "KSampler" - assert context.cache_key_hash == "a" * 64 - - -class TestCacheValue: - """Test CacheValue dataclass.""" - - def test_value_creation(self): - """CacheValue should be created with outputs.""" - outputs = [[{"samples": "tensor_data"}]] - value = CacheValue(outputs=outputs) - - assert value.outputs == outputs - - -class MockCacheProvider(CacheProvider): - """Mock cache provider for testing.""" - - def __init__(self): - self.lookups = [] - self.stores = [] - - async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: - self.lookups.append(context) - return None - - async def on_store(self, context: CacheContext, value: CacheValue) -> None: - self.stores.append((context, value)) From 5df1427124f6ceb70166326ee257d52076adea37 Mon Sep 17 00:00:00 2001 From: PxTicks Date: Fri, 13 Mar 2026 00:44:15 +0000 Subject: [PATCH 25/27] Fix audio extraction and truncation bugs (#12652) Bug report in #12651 - to_skip fix: Prevents negative array slicing when the start offset is negative. - __duration check: Prevents the extraction loop from breaking after a single audio chunk when the requested duration is 0 (which is a sentinel for unlimited). --- comfy_api/latest/_input_impl/video_types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 58a37c9e8..1b4993aa7 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -272,7 +272,7 @@ class VideoFromFile(VideoInput): has_first_frame = False for frame in frames: offset_seconds = start_time - frame.pts * audio_stream.time_base - to_skip = int(offset_seconds * audio_stream.sample_rate) + to_skip = max(0, int(offset_seconds * audio_stream.sample_rate)) if to_skip < frame.samples: has_first_frame = True break @@ -280,7 +280,7 @@ class VideoFromFile(VideoInput): audio_frames.append(frame.to_ndarray()[..., to_skip:]) for frame in frames: - if frame.time > start_time + self.__duration: + if self.__duration and frame.time > start_time + self.__duration: break audio_frames.append(frame.to_ndarray()) # shape: (channels, samples) if len(audio_frames) > 0: From 63d1bbdb407c69370d407ce5ced6ca3f917528a8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 12 Mar 2026 20:41:48 -0400 Subject: [PATCH 26/27] ComfyUI v0.17.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 2723d02e7..701f4d66a 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.16.4" +__version__ = "0.17.0" diff --git a/pyproject.toml b/pyproject.toml index 753b219b3..e2ca79be7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.16.4" +version = "0.17.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 4a8cf359fe596fc4c25a0d335d303e42c3f8605d Mon Sep 17 00:00:00 2001 From: Deep Mehta <42841935+deepme987@users.noreply.github.com> Date: Thu, 12 Mar 2026 21:17:50 -0700 Subject: [PATCH 27/27] Revert "Revert "feat: Add CacheProvider API for external distributed caching"" (#12915) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Revert "Revert "feat: Add CacheProvider API for external distributed caching …" This reverts commit d1d53c14be8442fca19aae978e944edad1935d46. * fix: gate provider lookups to outputs cache and fix UI coercion - Add `enable_providers` flag to BasicCache so only the outputs cache triggers external provider lookups/stores. The objects cache stores node class instances, not CacheEntry values, so provider calls were wasted round-trips that always missed. - Remove `or {}` coercion on `result.ui` — an empty dict passes the `is not None` gate in execution.py and causes KeyError when the history builder indexes `["output"]` and `["meta"]`. Preserving `None` correctly skips the ui_node_outputs addition. --- comfy_api/latest/__init__.py | 35 ++ comfy_api/latest/_caching.py | 42 ++ comfy_execution/cache_provider.py | 138 ++++++ comfy_execution/caching.py | 196 +++++++-- comfy_execution/graph.py | 6 +- execution.py | 147 ++++--- .../execution_test/test_cache_provider.py | 403 ++++++++++++++++++ 7 files changed, 874 insertions(+), 93 deletions(-) create mode 100644 comfy_api/latest/_caching.py create mode 100644 comfy_execution/cache_provider.py create mode 100644 tests-unit/execution_test/test_cache_provider.py diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index f2399422b..04973fea0 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -25,6 +25,7 @@ class ComfyAPI_latest(ComfyAPIBase): super().__init__() self.node_replacement = self.NodeReplacement() self.execution = self.Execution() + self.caching = self.Caching() class NodeReplacement(ProxiedSingleton): async def register(self, node_replace: io.NodeReplace) -> None: @@ -84,6 +85,36 @@ class ComfyAPI_latest(ComfyAPIBase): image=to_display, ) + class Caching(ProxiedSingleton): + """ + External cache provider API for sharing cached node outputs + across ComfyUI instances. + + Example:: + + from comfy_api.latest import Caching + + class MyCacheProvider(Caching.CacheProvider): + async def on_lookup(self, context): + ... # check external storage + + async def on_store(self, context, value): + ... # store to external storage + + Caching.register_provider(MyCacheProvider()) + """ + from ._caching import CacheProvider, CacheContext, CacheValue + + async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None: + """Register an external cache provider. Providers are called in registration order.""" + from comfy_execution.cache_provider import register_cache_provider + register_cache_provider(provider) + + async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None: + """Unregister a previously registered cache provider.""" + from comfy_execution.cache_provider import unregister_cache_provider + unregister_cache_provider(provider) + class ComfyExtension(ABC): async def on_load(self) -> None: """ @@ -116,6 +147,9 @@ class Types: VOXEL = VOXEL File3D = File3D + +Caching = ComfyAPI_latest.Caching + ComfyAPI = ComfyAPI_latest # Create a synchronous version of the API @@ -135,6 +169,7 @@ __all__ = [ "Input", "InputImpl", "Types", + "Caching", "ComfyExtension", "io", "IO", diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py new file mode 100644 index 000000000..30c8848cd --- /dev/null +++ b/comfy_api/latest/_caching.py @@ -0,0 +1,42 @@ +from abc import ABC, abstractmethod +from typing import Optional +from dataclasses import dataclass + + +@dataclass +class CacheContext: + node_id: str + class_type: str + cache_key_hash: str # SHA256 hex digest + + +@dataclass +class CacheValue: + outputs: list + ui: dict = None + + +class CacheProvider(ABC): + """Abstract base class for external cache providers. + Exceptions from provider methods are caught by the caller and never break execution. + """ + + @abstractmethod + async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + """Called on local cache miss. Return CacheValue if found, None otherwise.""" + pass + + @abstractmethod + async def on_store(self, context: CacheContext, value: CacheValue) -> None: + """Called after local store. Dispatched via asyncio.create_task.""" + pass + + def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool: + """Return False to skip external caching for this node. Default: True.""" + return True + + def on_prompt_start(self, prompt_id: str) -> None: + pass + + def on_prompt_end(self, prompt_id: str) -> None: + pass diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py new file mode 100644 index 000000000..d455d08e8 --- /dev/null +++ b/comfy_execution/cache_provider.py @@ -0,0 +1,138 @@ +from typing import Any, Optional, Tuple, List +import hashlib +import json +import logging +import threading + +# Public types — source of truth is comfy_api.latest._caching +from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported) + +_logger = logging.getLogger(__name__) + + +_providers: List[CacheProvider] = [] +_providers_lock = threading.Lock() +_providers_snapshot: Tuple[CacheProvider, ...] = () + + +def register_cache_provider(provider: CacheProvider) -> None: + """Register an external cache provider. Providers are called in registration order.""" + global _providers_snapshot + with _providers_lock: + if provider in _providers: + _logger.warning(f"Provider {provider.__class__.__name__} already registered") + return + _providers.append(provider) + _providers_snapshot = tuple(_providers) + _logger.debug(f"Registered cache provider: {provider.__class__.__name__}") + + +def unregister_cache_provider(provider: CacheProvider) -> None: + global _providers_snapshot + with _providers_lock: + try: + _providers.remove(provider) + _providers_snapshot = tuple(_providers) + _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}") + except ValueError: + _logger.warning(f"Provider {provider.__class__.__name__} was not registered") + + +def _get_cache_providers() -> Tuple[CacheProvider, ...]: + return _providers_snapshot + + +def _has_cache_providers() -> bool: + return bool(_providers_snapshot) + + +def _clear_cache_providers() -> None: + global _providers_snapshot + with _providers_lock: + _providers.clear() + _providers_snapshot = () + + +def _canonicalize(obj: Any) -> Any: + # Convert to canonical JSON-serializable form with deterministic ordering. + # Frozensets have non-deterministic iteration order between Python sessions. + # Raises ValueError for non-cacheable types (Unhashable, unknown) so that + # _serialize_cache_key returns None and external caching is skipped. + if isinstance(obj, frozenset): + return ("__frozenset__", sorted( + [_canonicalize(item) for item in obj], + key=lambda x: json.dumps(x, sort_keys=True) + )) + elif isinstance(obj, set): + return ("__set__", sorted( + [_canonicalize(item) for item in obj], + key=lambda x: json.dumps(x, sort_keys=True) + )) + elif isinstance(obj, tuple): + return ("__tuple__", [_canonicalize(item) for item in obj]) + elif isinstance(obj, list): + return [_canonicalize(item) for item in obj] + elif isinstance(obj, dict): + return {"__dict__": sorted( + [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()], + key=lambda x: json.dumps(x, sort_keys=True) + )} + elif isinstance(obj, (int, float, str, bool, type(None))): + return (type(obj).__name__, obj) + elif isinstance(obj, bytes): + return ("__bytes__", obj.hex()) + else: + raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}") + + +def _serialize_cache_key(cache_key: Any) -> Optional[str]: + # Returns deterministic SHA256 hex digest, or None on failure. + # Uses JSON (not pickle) because pickle is non-deterministic across sessions. + try: + canonical = _canonicalize(cache_key) + json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':')) + return hashlib.sha256(json_str.encode('utf-8')).hexdigest() + except Exception as e: + _logger.warning(f"Failed to serialize cache key: {e}") + return None + + +def _contains_self_unequal(obj: Any) -> bool: + # Local cache matches by ==. Values where not (x == x) (NaN, etc.) will + # never hit locally, but serialized form would match externally. Skip these. + try: + if not (obj == obj): + return True + except Exception: + return True + if isinstance(obj, (frozenset, tuple, list, set)): + return any(_contains_self_unequal(item) for item in obj) + if isinstance(obj, dict): + return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items()) + if hasattr(obj, 'value'): + return _contains_self_unequal(obj.value) + return False + + +def _estimate_value_size(value: CacheValue) -> int: + try: + import torch + except ImportError: + return 0 + + total = 0 + + def estimate(obj): + nonlocal total + if isinstance(obj, torch.Tensor): + total += obj.numel() * obj.element_size() + elif isinstance(obj, dict): + for v in obj.values(): + estimate(v) + elif isinstance(obj, (list, tuple)): + for item in obj: + estimate(item) + + for output in value.outputs: + estimate(output) + return total diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 326a279fc..78212bde3 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,3 +1,4 @@ +import asyncio import bisect import gc import itertools @@ -147,13 +148,15 @@ class CacheKeySetInputSignature(CacheKeySet): self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) class BasicCache: - def __init__(self, key_class): + def __init__(self, key_class, enable_providers=False): self.key_class = key_class self.initialized = False + self.enable_providers = enable_providers self.dynprompt: DynamicPrompt self.cache_key_set: CacheKeySet self.cache = {} self.subcaches = {} + self._pending_store_tasks: set = set() async def set_prompt(self, dynprompt, node_ids, is_changed_cache): self.dynprompt = dynprompt @@ -196,18 +199,138 @@ class BasicCache: def poll(self, **kwargs): pass - def _set_immediate(self, node_id, value): - assert self.initialized - cache_key = self.cache_key_set.get_data_key(node_id) - self.cache[cache_key] = value - - def _get_immediate(self, node_id): + def get_local(self, node_id): if not self.initialized: return None cache_key = self.cache_key_set.get_data_key(node_id) if cache_key in self.cache: return self.cache[cache_key] - else: + return None + + def set_local(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + async def _set_immediate(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + await self._notify_providers_store(node_id, cache_key, value) + + async def _get_immediate(self, node_id): + if not self.initialized: + return None + cache_key = self.cache_key_set.get_data_key(node_id) + + if cache_key in self.cache: + return self.cache[cache_key] + + external_result = await self._check_providers_lookup(node_id, cache_key) + if external_result is not None: + self.cache[cache_key] = external_result + return external_result + + return None + + async def _notify_providers_store(self, node_id, cache_key, value): + from comfy_execution.cache_provider import ( + _has_cache_providers, _get_cache_providers, + CacheValue, _contains_self_unequal, _logger + ) + + if not self.enable_providers: + return + if not _has_cache_providers(): + return + if not self._is_external_cacheable_value(value): + return + if _contains_self_unequal(cache_key): + return + + context = self._build_context(node_id, cache_key) + if context is None: + return + cache_value = CacheValue(outputs=value.outputs, ui=value.ui) + + for provider in _get_cache_providers(): + try: + if provider.should_cache(context, cache_value): + task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value)) + self._pending_store_tasks.add(task) + task.add_done_callback(self._pending_store_tasks.discard) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}") + + @staticmethod + async def _safe_provider_store(provider, context, cache_value): + from comfy_execution.cache_provider import _logger + try: + await provider.on_store(context, cache_value) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}") + + async def _check_providers_lookup(self, node_id, cache_key): + from comfy_execution.cache_provider import ( + _has_cache_providers, _get_cache_providers, + CacheValue, _contains_self_unequal, _logger + ) + + if not self.enable_providers: + return None + if not _has_cache_providers(): + return None + if _contains_self_unequal(cache_key): + return None + + context = self._build_context(node_id, cache_key) + if context is None: + return None + + for provider in _get_cache_providers(): + try: + if not provider.should_cache(context): + continue + result = await provider.on_lookup(context) + if result is not None: + if not isinstance(result, CacheValue): + _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type") + continue + if not isinstance(result.outputs, (list, tuple)): + _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs") + continue + from execution import CacheEntry + return CacheEntry(ui=result.ui, outputs=list(result.outputs)) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}") + + return None + + def _is_external_cacheable_value(self, value): + return hasattr(value, 'outputs') and hasattr(value, 'ui') + + def _get_class_type(self, node_id): + if not self.initialized or not self.dynprompt: + return '' + try: + return self.dynprompt.get_node(node_id).get('class_type', '') + except Exception: + return '' + + def _build_context(self, node_id, cache_key): + from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger + try: + cache_key_hash = _serialize_cache_key(cache_key) + if cache_key_hash is None: + return None + return CacheContext( + node_id=node_id, + class_type=self._get_class_type(node_id), + cache_key_hash=cache_key_hash, + ) + except Exception as e: + _logger.warning(f"Failed to build cache context for node {node_id}: {e}") return None async def _ensure_subcache(self, node_id, children_ids): @@ -236,8 +359,8 @@ class BasicCache: return result class HierarchicalCache(BasicCache): - def __init__(self, key_class): - super().__init__(key_class) + def __init__(self, key_class, enable_providers=False): + super().__init__(key_class, enable_providers=enable_providers) def _get_cache_for(self, node_id): assert self.dynprompt is not None @@ -257,16 +380,27 @@ class HierarchicalCache(BasicCache): return None return cache - def get(self, node_id): + async def get(self, node_id): cache = self._get_cache_for(node_id) if cache is None: return None - return cache._get_immediate(node_id) + return await cache._get_immediate(node_id) - def set(self, node_id, value): + def get_local(self, node_id): + cache = self._get_cache_for(node_id) + if cache is None: + return None + return BasicCache.get_local(cache, node_id) + + async def set(self, node_id, value): cache = self._get_cache_for(node_id) assert cache is not None - cache._set_immediate(node_id, value) + await cache._set_immediate(node_id, value) + + def set_local(self, node_id, value): + cache = self._get_cache_for(node_id) + assert cache is not None + BasicCache.set_local(cache, node_id, value) async def ensure_subcache_for(self, node_id, children_ids): cache = self._get_cache_for(node_id) @@ -287,18 +421,24 @@ class NullCache: def poll(self, **kwargs): pass - def get(self, node_id): + async def get(self, node_id): return None - def set(self, node_id, value): + def get_local(self, node_id): + return None + + async def set(self, node_id, value): + pass + + def set_local(self, node_id, value): pass async def ensure_subcache_for(self, node_id, children_ids): return self class LRUCache(BasicCache): - def __init__(self, key_class, max_size=100): - super().__init__(key_class) + def __init__(self, key_class, max_size=100, enable_providers=False): + super().__init__(key_class, enable_providers=enable_providers) self.max_size = max_size self.min_generation = 0 self.generation = 0 @@ -322,18 +462,18 @@ class LRUCache(BasicCache): del self.children[key] self._clean_subcaches() - def get(self, node_id): + async def get(self, node_id): self._mark_used(node_id) - return self._get_immediate(node_id) + return await self._get_immediate(node_id) def _mark_used(self, node_id): cache_key = self.cache_key_set.get_data_key(node_id) if cache_key is not None: self.used_generation[cache_key] = self.generation - def set(self, node_id, value): + async def set(self, node_id, value): self._mark_used(node_id) - return self._set_immediate(node_id, value) + return await self._set_immediate(node_id, value) async def ensure_subcache_for(self, node_id, children_ids): # Just uses subcaches for tracking 'live' nodes @@ -366,20 +506,20 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3 class RAMPressureCache(LRUCache): - def __init__(self, key_class): - super().__init__(key_class, 0) + def __init__(self, key_class, enable_providers=False): + super().__init__(key_class, 0, enable_providers=enable_providers) self.timestamps = {} def clean_unused(self): self._clean_subcaches() - def set(self, node_id, value): + async def set(self, node_id, value): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() - super().set(node_id, value) + await super().set(node_id, value) - def get(self, node_id): + async def get(self, node_id): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() - return super().get(node_id) + return await super().get(node_id) def poll(self, ram_headroom): def _ram_gb(): diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 9d170b16e..c47f3c79b 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort): self.execution_cache_listeners = {} def is_cached(self, node_id): - return self.output_cache.get(node_id) is not None + return self.output_cache.get_local(node_id) is not None def cache_link(self, from_node_id, to_node_id): if to_node_id not in self.execution_cache: self.execution_cache[to_node_id] = {} - self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) + self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id) if from_node_id not in self.execution_cache_listeners: self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id].add(to_node_id) @@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort): if value is None: return None #Write back to the main cache on touch. - self.output_cache.set(from_node_id, value) + self.output_cache.set_local(from_node_id, value) return value def cache_update(self, node_id, value): diff --git a/execution.py b/execution.py index a7791efed..1a6c3429c 100644 --- a/execution.py +++ b/execution.py @@ -40,6 +40,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io +from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger class ExecutionResult(Enum): @@ -126,15 +127,15 @@ class CacheSet: # Performs like the old cache -- dump data ASAP def init_classic_cache(self): - self.outputs = HierarchicalCache(CacheKeySetInputSignature) + self.outputs = HierarchicalCache(CacheKeySetInputSignature, enable_providers=True) self.objects = HierarchicalCache(CacheKeySetID) def init_lru_cache(self, cache_size): - self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size, enable_providers=True) self.objects = HierarchicalCache(CacheKeySetID) def init_ram_cache(self, min_headroom): - self.outputs = RAMPressureCache(CacheKeySetInputSignature) + self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True) self.objects = HierarchicalCache(CacheKeySetID) def init_null_cache(self): @@ -418,7 +419,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - cached = caches.outputs.get(unique_id) + cached = await caches.outputs.get(unique_id) if cached is not None: if server.client_id is not None: cached_ui = cached.ui or {} @@ -474,10 +475,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) - obj = caches.objects.get(unique_id) + obj = await caches.objects.get(unique_id) if obj is None: obj = class_def() - caches.objects.set(unique_id, obj) + await caches.objects.set(unique_id, obj) if issubclass(class_def, _ComfyNodeInternal): lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None @@ -588,7 +589,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data) execution_list.cache_update(unique_id, cache_entry) - caches.outputs.set(unique_id, cache_entry) + await caches.outputs.set(unique_id, cache_entry) except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") @@ -684,6 +685,19 @@ class PromptExecutor: } self.add_message("execution_error", mes, broadcast=False) + def _notify_prompt_lifecycle(self, event: str, prompt_id: str): + if not _has_cache_providers(): + return + + for provider in _get_cache_providers(): + try: + if event == "start": + provider.on_prompt_start(prompt_id) + elif event == "end": + provider.on_prompt_end(prompt_id) + except Exception as e: + _cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}") + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) @@ -700,66 +714,75 @@ class PromptExecutor: self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) - with torch.inference_mode(): - dynamic_prompt = DynamicPrompt(prompt) - reset_progress_state(prompt_id, dynamic_prompt) - add_progress_handler(WebUIProgressHandler(self.server)) - is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) - for cache in self.caches.all: - await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) - cache.clean_unused() + self._notify_prompt_lifecycle("start", prompt_id) - cached_nodes = [] - for node_id in prompt: - if self.caches.outputs.get(node_id) is not None: - cached_nodes.append(node_id) + try: + with torch.inference_mode(): + dynamic_prompt = DynamicPrompt(prompt) + reset_progress_state(prompt_id, dynamic_prompt) + add_progress_handler(WebUIProgressHandler(self.server)) + is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) + for cache in self.caches.all: + await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + cache.clean_unused() - comfy.model_management.cleanup_models_gc() - self.add_message("execution_cached", - { "nodes": cached_nodes, "prompt_id": prompt_id}, - broadcast=False) - pending_subgraph_results = {} - pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results - ui_node_outputs = {} - executed = set() - execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) - current_outputs = self.caches.outputs.all_node_ids() - for node_id in list(execute_outputs): - execution_list.add_node(node_id) + node_ids = list(prompt.keys()) + cache_results = await asyncio.gather( + *(self.caches.outputs.get(node_id) for node_id in node_ids) + ) + cached_nodes = [ + node_id for node_id, result in zip(node_ids, cache_results) + if result is not None + ] - while not execution_list.is_empty(): - node_id, error, ex = await execution_list.stage_node_execution() - if error is not None: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break + comfy.model_management.cleanup_models_gc() + self.add_message("execution_cached", + { "nodes": cached_nodes, "prompt_id": prompt_id}, + broadcast=False) + pending_subgraph_results = {} + pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + ui_node_outputs = {} + executed = set() + execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) + current_outputs = self.caches.outputs.all_node_ids() + for node_id in list(execute_outputs): + execution_list.add_node(node_id) - assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) - self.success = result != ExecutionResult.FAILURE - if result == ExecutionResult.FAILURE: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break - elif result == ExecutionResult.PENDING: - execution_list.unstage_node_execution() - else: # result == ExecutionResult.SUCCESS: - execution_list.complete_node_execution() - self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) - else: - # Only execute when the while-loop ends without break - self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + while not execution_list.is_empty(): + node_id, error, ex = await execution_list.stage_node_execution() + if error is not None: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break - ui_outputs = {} - meta_outputs = {} - for node_id, ui_info in ui_node_outputs.items(): - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] - self.history_result = { - "outputs": ui_outputs, - "meta": meta_outputs, - } - self.server.last_node_id = None - if comfy.model_management.DISABLE_SMART_MEMORY: - comfy.model_management.unload_all_models() + assert node_id is not None, "Node ID should not be None at this point" + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) + self.success = result != ExecutionResult.FAILURE + if result == ExecutionResult.FAILURE: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + elif result == ExecutionResult.PENDING: + execution_list.unstage_node_execution() + else: # result == ExecutionResult.SUCCESS: + execution_list.complete_node_execution() + self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) + else: + # Only execute when the while-loop ends without break + self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + + ui_outputs = {} + meta_outputs = {} + for node_id, ui_info in ui_node_outputs.items(): + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] + self.history_result = { + "outputs": ui_outputs, + "meta": meta_outputs, + } + self.server.last_node_id = None + if comfy.model_management.DISABLE_SMART_MEMORY: + comfy.model_management.unload_all_models() + finally: + self._notify_prompt_lifecycle("end", prompt_id) async def validate_inputs(prompt_id, prompt, item, validated): diff --git a/tests-unit/execution_test/test_cache_provider.py b/tests-unit/execution_test/test_cache_provider.py new file mode 100644 index 000000000..ac3814746 --- /dev/null +++ b/tests-unit/execution_test/test_cache_provider.py @@ -0,0 +1,403 @@ +"""Tests for external cache provider API.""" + +import importlib.util +import pytest +from typing import Optional + + +def _torch_available() -> bool: + """Check if PyTorch is available.""" + return importlib.util.find_spec("torch") is not None + + +from comfy_execution.cache_provider import ( + CacheProvider, + CacheContext, + CacheValue, + register_cache_provider, + unregister_cache_provider, + _get_cache_providers, + _has_cache_providers, + _clear_cache_providers, + _serialize_cache_key, + _contains_self_unequal, + _estimate_value_size, + _canonicalize, +) + + +class TestCanonicalize: + """Test _canonicalize function for deterministic ordering.""" + + def test_frozenset_ordering_is_deterministic(self): + """Frozensets should produce consistent canonical form regardless of iteration order.""" + # Create two frozensets with same content + fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)]) + fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)]) + + result1 = _canonicalize(fs1) + result2 = _canonicalize(fs2) + + assert result1 == result2 + + def test_nested_frozenset_ordering(self): + """Nested frozensets should also be deterministically ordered.""" + inner1 = frozenset([1, 2, 3]) + inner2 = frozenset([3, 2, 1]) + + fs1 = frozenset([("key", inner1)]) + fs2 = frozenset([("key", inner2)]) + + result1 = _canonicalize(fs1) + result2 = _canonicalize(fs2) + + assert result1 == result2 + + def test_dict_ordering(self): + """Dicts should be sorted by key.""" + d1 = {"z": 1, "a": 2, "m": 3} + d2 = {"a": 2, "m": 3, "z": 1} + + result1 = _canonicalize(d1) + result2 = _canonicalize(d2) + + assert result1 == result2 + + def test_tuple_preserved(self): + """Tuples should be marked and preserved.""" + t = (1, 2, 3) + result = _canonicalize(t) + + assert result[0] == "__tuple__" + + def test_list_preserved(self): + """Lists should be recursively canonicalized.""" + lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])] + result = _canonicalize(lst) + + # First element should be canonicalized dict + assert "__dict__" in result[0] + # Second element should be canonicalized frozenset + assert result[1][0] == "__frozenset__" + + def test_primitives_include_type(self): + """Primitive types should include type name for disambiguation.""" + assert _canonicalize(42) == ("int", 42) + assert _canonicalize(3.14) == ("float", 3.14) + assert _canonicalize("hello") == ("str", "hello") + assert _canonicalize(True) == ("bool", True) + assert _canonicalize(None) == ("NoneType", None) + + def test_int_and_str_distinguished(self): + """int 7 and str '7' must produce different canonical forms.""" + assert _canonicalize(7) != _canonicalize("7") + + def test_bytes_converted(self): + """Bytes should be converted to hex string.""" + b = b"\x00\xff" + result = _canonicalize(b) + + assert result[0] == "__bytes__" + assert result[1] == "00ff" + + def test_set_ordering(self): + """Sets should be sorted like frozensets.""" + s1 = {3, 1, 2} + s2 = {1, 2, 3} + + result1 = _canonicalize(s1) + result2 = _canonicalize(s2) + + assert result1 == result2 + assert result1[0] == "__set__" + + def test_unknown_type_raises(self): + """Unknown types should raise ValueError (fail-closed).""" + class CustomObj: + pass + with pytest.raises(ValueError): + _canonicalize(CustomObj()) + + def test_object_with_value_attr_raises(self): + """Objects with .value attribute (Unhashable-like) should raise ValueError.""" + class FakeUnhashable: + def __init__(self): + self.value = float('nan') + with pytest.raises(ValueError): + _canonicalize(FakeUnhashable()) + + +class TestSerializeCacheKey: + """Test _serialize_cache_key for deterministic hashing.""" + + def test_same_content_same_hash(self): + """Same content should produce same hash.""" + key1 = frozenset([("node_1", frozenset([("input", "value")]))]) + key2 = frozenset([("node_1", frozenset([("input", "value")]))]) + + hash1 = _serialize_cache_key(key1) + hash2 = _serialize_cache_key(key2) + + assert hash1 == hash2 + + def test_different_content_different_hash(self): + """Different content should produce different hash.""" + key1 = frozenset([("node_1", "value_a")]) + key2 = frozenset([("node_1", "value_b")]) + + hash1 = _serialize_cache_key(key1) + hash2 = _serialize_cache_key(key2) + + assert hash1 != hash2 + + def test_returns_hex_string(self): + """Should return hex string (SHA256 hex digest).""" + key = frozenset([("test", 123)]) + result = _serialize_cache_key(key) + + assert isinstance(result, str) + assert len(result) == 64 # SHA256 hex digest is 64 chars + + def test_complex_nested_structure(self): + """Complex nested structures should hash deterministically.""" + # Note: frozensets can only contain hashable types, so we use + # nested frozensets of tuples to represent dict-like structures + key = frozenset([ + ("node_1", frozenset([ + ("input_a", ("tuple", "value")), + ("input_b", frozenset([("nested", "dict")])), + ])), + ("node_2", frozenset([ + ("param", 42), + ])), + ]) + + # Hash twice to verify determinism + hash1 = _serialize_cache_key(key) + hash2 = _serialize_cache_key(key) + + assert hash1 == hash2 + + def test_dict_in_cache_key(self): + """Dicts passed directly to _serialize_cache_key should work.""" + key = {"node_1": {"input": "value"}, "node_2": 42} + + hash1 = _serialize_cache_key(key) + hash2 = _serialize_cache_key(key) + + assert hash1 == hash2 + assert isinstance(hash1, str) + assert len(hash1) == 64 + + def test_unknown_type_returns_none(self): + """Non-cacheable types should return None (fail-closed).""" + class CustomObj: + pass + assert _serialize_cache_key(CustomObj()) is None + + +class TestContainsSelfUnequal: + """Test _contains_self_unequal utility function.""" + + def test_nan_float_detected(self): + """NaN floats should be detected (not equal to itself).""" + assert _contains_self_unequal(float('nan')) is True + + def test_regular_float_not_detected(self): + """Regular floats are equal to themselves.""" + assert _contains_self_unequal(3.14) is False + assert _contains_self_unequal(0.0) is False + assert _contains_self_unequal(-1.5) is False + + def test_infinity_not_detected(self): + """Infinity is equal to itself.""" + assert _contains_self_unequal(float('inf')) is False + assert _contains_self_unequal(float('-inf')) is False + + def test_nan_in_list(self): + """NaN in list should be detected.""" + assert _contains_self_unequal([1, 2, float('nan'), 4]) is True + assert _contains_self_unequal([1, 2, 3, 4]) is False + + def test_nan_in_tuple(self): + """NaN in tuple should be detected.""" + assert _contains_self_unequal((1, float('nan'))) is True + assert _contains_self_unequal((1, 2, 3)) is False + + def test_nan_in_frozenset(self): + """NaN in frozenset should be detected.""" + assert _contains_self_unequal(frozenset([1, float('nan')])) is True + assert _contains_self_unequal(frozenset([1, 2, 3])) is False + + def test_nan_in_dict_value(self): + """NaN in dict value should be detected.""" + assert _contains_self_unequal({"key": float('nan')}) is True + assert _contains_self_unequal({"key": 42}) is False + + def test_nan_in_nested_structure(self): + """NaN in deeply nested structure should be detected.""" + nested = {"level1": [{"level2": (1, 2, float('nan'))}]} + assert _contains_self_unequal(nested) is True + + def test_non_numeric_types(self): + """Non-numeric types should not be self-unequal.""" + assert _contains_self_unequal("string") is False + assert _contains_self_unequal(None) is False + assert _contains_self_unequal(True) is False + + def test_object_with_nan_value_attribute(self): + """Objects wrapping NaN in .value should be detected.""" + class NanWrapper: + def __init__(self): + self.value = float('nan') + assert _contains_self_unequal(NanWrapper()) is True + + def test_custom_self_unequal_object(self): + """Custom objects where not (x == x) should be detected.""" + class NeverEqual: + def __eq__(self, other): + return False + assert _contains_self_unequal(NeverEqual()) is True + + +class TestEstimateValueSize: + """Test _estimate_value_size utility function.""" + + def test_empty_outputs(self): + """Empty outputs should have zero size.""" + value = CacheValue(outputs=[]) + assert _estimate_value_size(value) == 0 + + @pytest.mark.skipif( + not _torch_available(), + reason="PyTorch not available" + ) + def test_tensor_size_estimation(self): + """Tensor size should be estimated correctly.""" + import torch + + # 1000 float32 elements = 4000 bytes + tensor = torch.zeros(1000, dtype=torch.float32) + value = CacheValue(outputs=[[tensor]]) + + size = _estimate_value_size(value) + assert size == 4000 + + @pytest.mark.skipif( + not _torch_available(), + reason="PyTorch not available" + ) + def test_nested_tensor_in_dict(self): + """Tensors nested in dicts should be counted.""" + import torch + + tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes + value = CacheValue(outputs=[[{"samples": tensor}]]) + + size = _estimate_value_size(value) + assert size == 400 + + +class TestProviderRegistry: + """Test cache provider registration and retrieval.""" + + def setup_method(self): + """Clear providers before each test.""" + _clear_cache_providers() + + def teardown_method(self): + """Clear providers after each test.""" + _clear_cache_providers() + + def test_register_provider(self): + """Provider should be registered successfully.""" + provider = MockCacheProvider() + register_cache_provider(provider) + + assert _has_cache_providers() is True + providers = _get_cache_providers() + assert len(providers) == 1 + assert providers[0] is provider + + def test_unregister_provider(self): + """Provider should be unregistered successfully.""" + provider = MockCacheProvider() + register_cache_provider(provider) + unregister_cache_provider(provider) + + assert _has_cache_providers() is False + + def test_multiple_providers(self): + """Multiple providers can be registered.""" + provider1 = MockCacheProvider() + provider2 = MockCacheProvider() + + register_cache_provider(provider1) + register_cache_provider(provider2) + + providers = _get_cache_providers() + assert len(providers) == 2 + + def test_duplicate_registration_ignored(self): + """Registering same provider twice should be ignored.""" + provider = MockCacheProvider() + + register_cache_provider(provider) + register_cache_provider(provider) # Should be ignored + + providers = _get_cache_providers() + assert len(providers) == 1 + + def test_clear_providers(self): + """_clear_cache_providers should remove all providers.""" + provider1 = MockCacheProvider() + provider2 = MockCacheProvider() + + register_cache_provider(provider1) + register_cache_provider(provider2) + _clear_cache_providers() + + assert _has_cache_providers() is False + assert len(_get_cache_providers()) == 0 + + +class TestCacheContext: + """Test CacheContext dataclass.""" + + def test_context_creation(self): + """CacheContext should be created with all fields.""" + context = CacheContext( + node_id="node-456", + class_type="KSampler", + cache_key_hash="a" * 64, + ) + + assert context.node_id == "node-456" + assert context.class_type == "KSampler" + assert context.cache_key_hash == "a" * 64 + + +class TestCacheValue: + """Test CacheValue dataclass.""" + + def test_value_creation(self): + """CacheValue should be created with outputs.""" + outputs = [[{"samples": "tensor_data"}]] + value = CacheValue(outputs=outputs) + + assert value.outputs == outputs + + +class MockCacheProvider(CacheProvider): + """Mock cache provider for testing.""" + + def __init__(self): + self.lookups = [] + self.stores = [] + + async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + self.lookups.append(context) + return None + + async def on_store(self, context: CacheContext, value: CacheValue) -> None: + self.stores.append((context, value))