diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py
index 92b1acbd5..57126fa4a 100644
--- a/comfy/comfy_types/node_typing.py
+++ b/comfy/comfy_types/node_typing.py
@@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict):
"""COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
- gradient_stops: NotRequired[list[list[float]]]
- """Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
+ gradient_stops: NotRequired[list[dict]]
+ """Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}."""
class HiddenInputTypeDict(TypedDict):
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 8b3f500d7..e28d704b4 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
return tensor * m_mult
else:
for d in modulation_dims:
- tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
+ tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1]
if m_add is not None:
- tensor[:, d[0]:d[1]] += m_add[:, d[2]]
+ tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1]
return tensor
@@ -223,12 +223,19 @@ class DoubleStreamBlock(nn.Module):
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
+
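+        # joint q/k/v hold the txt tokens first, then the img tokens; img_slice records [start, end) of the img tokens for attn patches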
+ extra_options["img_slice"] = [txt.shape[1], q.shape[2]]
+ if "attn1_patch" in transformer_patches:
+ patch = transformer_patches["attn1_patch"]
+ for p in patch:
+ out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
+ q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
+
# run actual attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if "attn1_output_patch" in transformer_patches:
- extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
@@ -321,6 +328,12 @@ class SingleStreamBlock(nn.Module):
del qkv
q, k = self.norm(q, k, v)
+ if "attn1_patch" in transformer_patches:
+ patch = transformer_patches["attn1_patch"]
+ for p in patch:
+ out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
+ q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
+
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index 5e764bb46..824daf5e6 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -31,6 +31,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
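+    # if the token axis was truncated upstream (e.g. reference tokens dropped by a KV-cache patch), trim freqs_cis to the same length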
+ if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]:
+ freqs_cis = freqs_cis[:, :, :x_.shape[2]]
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index ef4dcf7c5..8e7912e6d 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -44,6 +44,22 @@ class FluxParams:
txt_norm: bool = False
+def invert_slices(slices, length):
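+    """Return the ranges of [0, length) not covered by slices, e.g. invert_slices([(2, 4)], 6) -> [(0, 2), (4, 6)]."""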
+ sorted_slices = sorted(slices)
+ result = []
+ current = 0
+
+ for start, end in sorted_slices:
+ if current < start:
+ result.append((current, start))
+ current = max(current, end)
+
+ if current < length:
+ result.append((current, length))
+
+ return result
+
+
class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
@@ -138,6 +154,7 @@ class Flux(nn.Module):
y: Tensor,
guidance: Tensor = None,
control = None,
+ timestep_zero_index=None,
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
@@ -164,13 +181,9 @@ class Flux(nn.Module):
txt = self.txt_norm(txt)
txt = self.txt_in(txt)
- vec_orig = vec
- if self.params.global_modulation:
- vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
-
if "post_input" in patches:
for p in patches["post_input"]:
- out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
+ out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
img = out["img"]
txt = out["txt"]
img_ids = out["img_ids"]
@@ -182,6 +195,24 @@ class Flux(nn.Module):
else:
pe = None
+ vec_orig = vec
+ txt_vec = vec
+ extra_kwargs = {}
+ if timestep_zero_index is not None:
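+            # the timestep batch was doubled with a zero copy for reference tokens, so vec carries two modulation vectors per sample:
+            # index 0 modulates the regular token ranges, index 1 the reference-image ranges listed in timestep_zero_index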
+ modulation_dims = []
+ batch = vec.shape[0] // 2
+ vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
+ invert = invert_slices(timestep_zero_index, img.shape[1])
+ for s in invert:
+ modulation_dims.append((s[0], s[1], 0))
+ for s in timestep_zero_index:
+ modulation_dims.append((s[0], s[1], 1))
+ extra_kwargs["modulation_dims_img"] = modulation_dims
+ txt_vec = vec[:batch]
+
+ if self.params.global_modulation:
+ vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
+
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
@@ -195,7 +226,8 @@ class Flux(nn.Module):
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
- transformer_options=args.get("transformer_options"))
+ transformer_options=args.get("transformer_options"),
+ **extra_kwargs)
return out
out = blocks_replace[("double_block", i)]({"img": img,
@@ -213,7 +245,8 @@ class Flux(nn.Module):
vec=vec,
pe=pe,
attn_mask=attn_mask,
- transformer_options=transformer_options)
+ transformer_options=transformer_options,
+ **extra_kwargs)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -230,6 +263,12 @@ class Flux(nn.Module):
if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig)
+ extra_kwargs = {}
+ if timestep_zero_index is not None:
+ modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
+ extra_kwargs["modulation_dims"] = modulation_dims_combined
+
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
@@ -242,7 +281,8 @@ class Flux(nn.Module):
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"),
- transformer_options=args.get("transformer_options"))
+ transformer_options=args.get("transformer_options"),
+ **extra_kwargs)
return out
out = blocks_replace[("single_block", i)]({"img": img,
@@ -253,7 +293,7 @@ class Flux(nn.Module):
{"original_block": block_wrap})
img = out["img"]
else:
- img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
+ img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -264,7 +304,11 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...]
- img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
+ extra_kwargs = {}
+ if timestep_zero_index is not None:
+ extra_kwargs["modulation_dims"] = modulation_dims
+
+ img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
return img
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
@@ -312,13 +356,16 @@ class Flux(nn.Module):
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x, transformer_options=transformer_options)
img_tokens = img.shape[1]
+ timestep_zero_index = None
if ref_latents is not None:
+ ref_num_tokens = []
h = 0
w = 0
index = 0
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
+ timestep_zero = ref_latents_method == "index_timestep_zero"
for ref in ref_latents:
- if ref_latents_method == "index":
+ if ref_latents_method in ("index", "index_timestep_zero"):
index += self.params.ref_index_scale
h_offset = 0
w_offset = 0
@@ -342,6 +389,13 @@ class Flux(nn.Module):
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+ ref_num_tokens.append(kontext.shape[1])
+ if timestep_zero:
+ if index > 0:
+ timestep = torch.cat([timestep, timestep * 0], dim=0)
+ timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
+ transformer_options = transformer_options.copy()
+ transformer_options["reference_image_num_tokens"] = ref_num_tokens
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
@@ -349,6 +403,6 @@ class Flux(nn.Module):
for i in self.params.txt_ids_dims:
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
- out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
+ out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 10d051325..b193fe5e8 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
- except model_management.OOM_EXCEPTION as e:
+ except Exception as e:
+ model_management.raise_non_oom(e)
if first_op_done == False:
model_management.soft_empty_cache(True)
if cleared_cache == False:
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 805592aa5..fcbaa074f 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -258,7 +258,8 @@ def slice_attention(q, k, v):
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
- except model_management.OOM_EXCEPTION as e:
+ except Exception as e:
+ model_management.raise_non_oom(e)
model_management.soft_empty_cache(True)
steps *= 2
if steps > 128:
@@ -314,7 +315,8 @@ def pytorch_attention(q, k, v):
try:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape)
- except model_management.OOM_EXCEPTION:
+ except Exception as e:
+ model_management.raise_non_oom(e)
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
oom_fallback = True
if oom_fallback:
diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py
index fab145f1c..f982afc2b 100644
--- a/comfy/ldm/modules/sub_quadratic_attention.py
+++ b/comfy/ldm/modules/sub_quadratic_attention.py
@@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking(
try:
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
- except model_management.OOM_EXCEPTION:
+ except Exception as e:
+ model_management.raise_non_oom(e)
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
torch.exp(attn_scores, out=attn_scores)
diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index 6eb744286..0862f72f7 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -149,6 +149,9 @@ class Attention(nn.Module):
seq_img = hidden_states.shape[1]
seq_txt = encoder_hidden_states.shape[1]
+ transformer_patches = transformer_options.get("patches", {})
+ extra_options = transformer_options.copy()
+
# Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
@@ -167,15 +170,22 @@ class Attention(nn.Module):
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)
- joint_query = apply_rope1(joint_query, image_rotary_emb)
- joint_key = apply_rope1(joint_key, image_rotary_emb)
-
if encoder_hidden_states_mask is not None:
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
else:
attn_mask = None
+ extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]]
+ if "attn1_patch" in transformer_patches:
+ patch = transformer_patches["attn1_patch"]
+ for p in patch:
+ out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options)
+ joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask)
+
+ joint_query = apply_rope1(joint_query, image_rotary_emb)
+ joint_key = apply_rope1(joint_key, image_rotary_emb)
+
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attn_mask, transformer_options=transformer_options,
skip_reshape=True)
@@ -444,6 +454,7 @@ class QwenImageTransformer2DModel(nn.Module):
timestep_zero_index = None
if ref_latents is not None:
+ ref_num_tokens = []
h = 0
w = 0
index = 0
@@ -474,16 +485,16 @@ class QwenImageTransformer2DModel(nn.Module):
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+ ref_num_tokens.append(kontext.shape[1])
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = num_embeds
+ transformer_options = transformer_options.copy()
+ transformer_options["reference_image_num_tokens"] = ref_num_tokens
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
- ids = torch.cat((txt_ids, img_ids), dim=1)
- image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
- del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
@@ -495,6 +506,18 @@ class QwenImageTransformer2DModel(nn.Module):
patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {})
+ if "post_input" in patches:
+ for p in patches["post_input"]:
+ out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
+ hidden_states = out["img"]
+ encoder_hidden_states = out["txt"]
+ img_ids = out["img_ids"]
+ txt_ids = out["txt_ids"]
+
+ ids = torch.cat((txt_ids, img_ids), dim=1)
+ image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
+ del ids, txt_ids, img_ids
+
transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.transformer_blocks):
diff --git a/comfy/lora.py b/comfy/lora.py
index f36ddb046..63ee85323 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}):
for k in sdk:
if k.endswith(".weight"):
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
+ tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
+ if tp > 0 and not k.startswith("clip_"):
+ key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 6eace4628..35a6822e3 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -1,4 +1,5 @@
import json
+import comfy.memory_management
import comfy.supported_models
import comfy.supported_models_base
import comfy.utils
@@ -1118,8 +1119,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
new[:old_weight.shape[0]] = old_weight
old_weight = new
+ if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled:
+ old_weight = old_weight.clone()
+
w = old_weight.narrow(offset[0], offset[1], offset[2])
else:
+ if comfy.memory_management.aimdo_enabled:
+ weight = weight.clone()
old_weight = weight
w = weight
w[:] = fun(weight)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 07bc8ad67..81c89b180 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -270,6 +270,23 @@ try:
except:
OOM_EXCEPTION = Exception
+try:
+ ACCELERATOR_ERROR = torch.AcceleratorError
+except AttributeError:
+ ACCELERATOR_ERROR = RuntimeError
+
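+# CUDA reports out-of-memory as error code 2 (cudaErrorMemoryAllocation); newer torch may expose it via AcceleratorError.error_code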
+def is_oom(e):
+ if isinstance(e, OOM_EXCEPTION):
+ return True
+ if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
+ discard_cuda_async_error()
+ return True
+ return False
+
+def raise_non_oom(e):
+ if not is_oom(e):
+ raise e
+
XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True
if args.disable_xformers:
@@ -1263,7 +1280,7 @@ def discard_cuda_async_error():
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
synchronize()
- except torch.AcceleratorError:
+ except RuntimeError:
#Dump it! We already know about it from the synchronous return
pass
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 745384271..bc3a8f446 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -599,6 +599,27 @@ class ModelPatcher:
return models
+ def model_patches_call_function(self, function_name="cleanup", arguments={}):
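+        """Call function_name(**arguments) on every transformer patch, patch replacement and model function wrapper that defines it."""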
+ to = self.model_options["transformer_options"]
+ if "patches" in to:
+ patches = to["patches"]
+ for name in patches:
+ patch_list = patches[name]
+ for i in range(len(patch_list)):
+ if hasattr(patch_list[i], function_name):
+ getattr(patch_list[i], function_name)(**arguments)
+ if "patches_replace" in to:
+ patches = to["patches_replace"]
+ for name in patches:
+ patch_list = patches[name]
+ for k in patch_list:
+ if hasattr(patch_list[k], function_name):
+ getattr(patch_list[k], function_name)(**arguments)
+ if "model_function_wrapper" in self.model_options:
+ wrap_func = self.model_options["model_function_wrapper"]
+ if hasattr(wrap_func, function_name):
+ getattr(wrap_func, function_name)(**arguments)
+
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
@@ -1062,6 +1083,7 @@ class ModelPatcher:
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
def cleanup(self):
+ self.model_patches_call_function(function_name="cleanup")
self.clean_hooks()
if hasattr(self.model, "current_patcher"):
self.model.current_patcher = None
diff --git a/comfy/sd.py b/comfy/sd.py
index 888ef1e77..adcd67767 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -954,7 +954,8 @@ class VAE:
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
- except model_management.OOM_EXCEPTION:
+ except Exception as e:
+ model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
@@ -1029,7 +1030,8 @@ class VAE:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples[x:x + batch_number] = out
- except model_management.OOM_EXCEPTION:
+ except Exception as e:
+ model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index f2399422b..04973fea0 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -25,6 +25,7 @@ class ComfyAPI_latest(ComfyAPIBase):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
+ self.caching = self.Caching()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
@@ -84,6 +85,36 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
+ class Caching(ProxiedSingleton):
+ """
+ External cache provider API for sharing cached node outputs
+ across ComfyUI instances.
+
+ Example::
+
+ from comfy_api.latest import Caching
+
+ class MyCacheProvider(Caching.CacheProvider):
+ async def on_lookup(self, context):
+ ... # check external storage
+
+ async def on_store(self, context, value):
+ ... # store to external storage
+
+ Caching.register_provider(MyCacheProvider())
+ """
+ from ._caching import CacheProvider, CacheContext, CacheValue
+
+ async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
+ """Register an external cache provider. Providers are called in registration order."""
+ from comfy_execution.cache_provider import register_cache_provider
+ register_cache_provider(provider)
+
+ async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
+ """Unregister a previously registered cache provider."""
+ from comfy_execution.cache_provider import unregister_cache_provider
+ unregister_cache_provider(provider)
+
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
@@ -116,6 +147,9 @@ class Types:
VOXEL = VOXEL
File3D = File3D
+
+Caching = ComfyAPI_latest.Caching
+
ComfyAPI = ComfyAPI_latest
# Create a synchronous version of the API
@@ -135,6 +169,7 @@ __all__ = [
"Input",
"InputImpl",
"Types",
+ "Caching",
"ComfyExtension",
"io",
"IO",
diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py
new file mode 100644
index 000000000..30c8848cd
--- /dev/null
+++ b/comfy_api/latest/_caching.py
@@ -0,0 +1,42 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+from dataclasses import dataclass
+
+
+@dataclass
+class CacheContext:
+ node_id: str
+ class_type: str
+ cache_key_hash: str # SHA256 hex digest
+
+
+@dataclass
+class CacheValue:
+ outputs: list
+ ui: dict = None
+
+
+class CacheProvider(ABC):
+ """Abstract base class for external cache providers.
+ Exceptions from provider methods are caught by the caller and never break execution.
+ """
+
+ @abstractmethod
+ async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
+ """Called on local cache miss. Return CacheValue if found, None otherwise."""
+ pass
+
+ @abstractmethod
+ async def on_store(self, context: CacheContext, value: CacheValue) -> None:
+ """Called after local store. Dispatched via asyncio.create_task."""
+ pass
+
+ def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
+ """Return False to skip external caching for this node. Default: True."""
+ return True
+
+ def on_prompt_start(self, prompt_id: str) -> None:
+ pass
+
+ def on_prompt_end(self, prompt_id: str) -> None:
+ pass
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index 58a37c9e8..1b4993aa7 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -272,7 +272,7 @@ class VideoFromFile(VideoInput):
has_first_frame = False
for frame in frames:
offset_seconds = start_time - frame.pts * audio_stream.time_base
- to_skip = int(offset_seconds * audio_stream.sample_rate)
+ to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
if to_skip < frame.samples:
has_first_frame = True
break
@@ -280,7 +280,7 @@ class VideoFromFile(VideoInput):
audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames:
- if frame.time > start_time + self.__duration:
+ if self.__duration and frame.time > start_time + self.__duration:
break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 050031dc0..7ca8f4e0c 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -297,7 +297,7 @@ class Float(ComfyTypeIO):
'''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
- display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
+ display_mode: NumberDisplay=None, gradient_stops: list[dict]=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min
diff --git a/comfy_api_nodes/apis/reve.py b/comfy_api_nodes/apis/reve.py
new file mode 100644
index 000000000..c6b5a69d8
--- /dev/null
+++ b/comfy_api_nodes/apis/reve.py
@@ -0,0 +1,68 @@
+from pydantic import BaseModel, Field
+
+
+class RevePostprocessingOperation(BaseModel):
+ process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
+ upscale_factor: int | None = Field(
+ None,
+ description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
+ ge=2,
+ le=4,
+ )
+
+
+class ReveImageCreateRequest(BaseModel):
+ prompt: str = Field(...)
+ aspect_ratio: str | None = Field(...)
+ version: str = Field(...)
+ test_time_scaling: int = Field(
+ ...,
+ description="If included, the model will spend more effort making better images. Values between 1 and 15.",
+ ge=1,
+ le=15,
+ )
+ postprocessing: list[RevePostprocessingOperation] | None = Field(
+ None, description="Optional postprocessing operations to apply after generation."
+ )
+
+
+class ReveImageEditRequest(BaseModel):
+ edit_instruction: str = Field(...)
+ reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
+ aspect_ratio: str | None = Field(...)
+ version: str = Field(...)
+ test_time_scaling: int | None = Field(
+ ...,
+ description="If included, the model will spend more effort making better images. Values between 1 and 15.",
+ ge=1,
+ le=15,
+ )
+ postprocessing: list[RevePostprocessingOperation] | None = Field(
+ None, description="Optional postprocessing operations to apply after generation."
+ )
+
+
+class ReveImageRemixRequest(BaseModel):
+ prompt: str = Field(...)
+ reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
+ aspect_ratio: str | None = Field(...)
+ version: str = Field(...)
+ test_time_scaling: int | None = Field(
+ ...,
+ description="If included, the model will spend more effort making better images. Values between 1 and 15.",
+ ge=1,
+ le=15,
+ )
+ postprocessing: list[RevePostprocessingOperation] | None = Field(
+ None, description="Optional postprocessing operations to apply after generation."
+ )
+
+
+class ReveImageResponse(BaseModel):
+ image: str | None = Field(None, description="The base64 encoded image data.")
+ request_id: str | None = Field(None, description="A unique id for the request.")
+ credits_used: float | None = Field(None, description="The number of credits used for this request.")
+ version: str | None = Field(None, description="The specific model version used.")
+ content_violation: bool | None = Field(
+ None, description="Indicates whether the generated image violates the content policy."
+ )
diff --git a/comfy_api_nodes/nodes_reve.py b/comfy_api_nodes/nodes_reve.py
new file mode 100644
index 000000000..608d9f058
--- /dev/null
+++ b/comfy_api_nodes/nodes_reve.py
@@ -0,0 +1,395 @@
+from io import BytesIO
+
+from typing_extensions import override
+
+from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api_nodes.apis.reve import (
+ ReveImageCreateRequest,
+ ReveImageEditRequest,
+ ReveImageRemixRequest,
+ RevePostprocessingOperation,
+)
+from comfy_api_nodes.util import (
+ ApiEndpoint,
+ bytesio_to_image_tensor,
+ sync_op_raw,
+ tensor_to_base64_string,
+ validate_string,
+)
+
+
+def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
+ ops = []
+ if upscale["upscale"] == "enabled":
+ ops.append(
+ RevePostprocessingOperation(
+ process="upscale",
+ upscale_factor=upscale["upscale_factor"],
+ )
+ )
+ if remove_background:
+ ops.append(RevePostprocessingOperation(process="remove_background"))
+ return ops or None
+
+
+def _postprocessing_inputs():
+ return [
+ IO.DynamicCombo.Input(
+ "upscale",
+ options=[
+ IO.DynamicCombo.Option("disabled", []),
+ IO.DynamicCombo.Option(
+ "enabled",
+ [
+ IO.Int.Input(
+ "upscale_factor",
+ default=2,
+ min=2,
+ max=4,
+ step=1,
+ tooltip="Upscale factor (2x, 3x, or 4x).",
+ ),
+ ],
+ ),
+ ],
+ tooltip="Upscale the generated image. May add additional cost.",
+ ),
+ IO.Boolean.Input(
+ "remove_background",
+ default=False,
+ tooltip="Remove the background from the generated image. May add additional cost.",
+ ),
+ ]
+
+
+def _reve_price_extractor(headers: dict) -> float | None:
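+    # convert the credits reported by the API into USD (assumed rate of 524.48 credits per dollar)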
+ credits_used = headers.get("x-reve-credits-used")
+ if credits_used is not None:
+ return float(credits_used) / 524.48
+ return None
+
+
+def _reve_response_header_validator(headers: dict) -> None:
+ error_code = headers.get("x-reve-error-code")
+ if error_code:
+ raise ValueError(f"Reve API error: {error_code}")
+ if headers.get("x-reve-content-violation", "").lower() == "true":
+ raise ValueError("The generated image was flagged for content policy violation.")
+
+
+def _model_inputs(versions: list[str], aspect_ratios: list[str]):
+ return [
+ IO.DynamicCombo.Option(
+ version,
+ [
+ IO.Combo.Input(
+ "aspect_ratio",
+ options=aspect_ratios,
+ tooltip="Aspect ratio of the output image.",
+ ),
+ IO.Int.Input(
+ "test_time_scaling",
+ default=1,
+ min=1,
+ max=5,
+ step=1,
+ tooltip="Higher values produce better images but cost more credits.",
+ advanced=True,
+ ),
+ ],
+ )
+ for version in versions
+ ]
+
+
+class ReveImageCreateNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ReveImageCreateNode",
+ display_name="Reve Image Create",
+ category="api node/image/Reve",
+ description="Generate images from text descriptions using Reve.",
+ inputs=[
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ default="",
+ tooltip="Text description of the desired image. Maximum 2560 characters.",
+ ),
+ IO.DynamicCombo.Input(
+ "model",
+ options=_model_inputs(
+ ["reve-create@20250915"],
+ aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
+ ),
+ tooltip="Model version to use for generation.",
+ ),
+ *_postprocessing_inputs(),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[IO.Image.Output()],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ prompt: str,
+ model: dict,
+ upscale: dict,
+ remove_background: bool,
+ seed: int,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, min_length=1, max_length=2560)
+ response = await sync_op_raw(
+ cls,
+ ApiEndpoint(
+ path="/proxy/reve/v1/image/create",
+ method="POST",
+ headers={"Accept": "image/webp"},
+ ),
+ as_binary=True,
+ price_extractor=_reve_price_extractor,
+ response_header_validator=_reve_response_header_validator,
+ data=ReveImageCreateRequest(
+ prompt=prompt,
+ aspect_ratio=model["aspect_ratio"],
+ version=model["model"],
+ test_time_scaling=model["test_time_scaling"],
+ postprocessing=_build_postprocessing(upscale, remove_background),
+ ),
+ )
+ return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
+
+
+class ReveImageEditNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ReveImageEditNode",
+ display_name="Reve Image Edit",
+ category="api node/image/Reve",
+ description="Edit images using natural language instructions with Reve.",
+ inputs=[
+ IO.Image.Input("image", tooltip="The image to edit."),
+ IO.String.Input(
+ "edit_instruction",
+ multiline=True,
+ default="",
+ tooltip="Text description of how to edit the image. Maximum 2560 characters.",
+ ),
+ IO.DynamicCombo.Input(
+ "model",
+ options=_model_inputs(
+ ["reve-edit@20250915", "reve-edit-fast@20251030"],
+ aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
+ ),
+ tooltip="Model version to use for editing.",
+ ),
+ *_postprocessing_inputs(),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[IO.Image.Output()],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(
+ widgets=["model"],
+ ),
+ expr="""
+ (
+ $isFast := $contains(widgets.model, "fast");
+ $base := $isFast ? 0.01001 : 0.0572;
+ {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ image: Input.Image,
+ edit_instruction: str,
+ model: dict,
+ upscale: dict,
+ remove_background: bool,
+ seed: int,
+ ) -> IO.NodeOutput:
+ validate_string(edit_instruction, min_length=1, max_length=2560)
+ tts = model["test_time_scaling"]
+ ar = model["aspect_ratio"]
+ response = await sync_op_raw(
+ cls,
+ ApiEndpoint(
+ path="/proxy/reve/v1/image/edit",
+ method="POST",
+ headers={"Accept": "image/webp"},
+ ),
+ as_binary=True,
+ price_extractor=_reve_price_extractor,
+ response_header_validator=_reve_response_header_validator,
+ data=ReveImageEditRequest(
+ edit_instruction=edit_instruction,
+ reference_image=tensor_to_base64_string(image),
+ aspect_ratio=ar if ar != "auto" else None,
+ version=model["model"],
+ test_time_scaling=tts if tts and tts > 1 else None,
+ postprocessing=_build_postprocessing(upscale, remove_background),
+ ),
+ )
+ return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
+
+
+class ReveImageRemixNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ReveImageRemixNode",
+ display_name="Reve Image Remix",
+ category="api node/image/Reve",
+ description="Combine reference images with text prompts to create new images using Reve.",
+ inputs=[
+ IO.Autogrow.Input(
+ "reference_images",
+ template=IO.Autogrow.TemplatePrefix(
+ IO.Image.Input("image"),
+ prefix="image_",
+ min=1,
+ max=6,
+ ),
+ ),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ default="",
+ tooltip="Text description of the desired image. "
+ "May include XML img tags to reference specific images by index, "
+                    "e.g. <img>0</img>, <img>1</img>, etc.",
+ ),
+ IO.DynamicCombo.Input(
+ "model",
+ options=_model_inputs(
+ ["reve-remix@20250915", "reve-remix-fast@20251030"],
+ aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
+ ),
+ tooltip="Model version to use for remixing.",
+ ),
+ *_postprocessing_inputs(),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[IO.Image.Output()],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(
+ widgets=["model"],
+ ),
+ expr="""
+ (
+ $isFast := $contains(widgets.model, "fast");
+ $base := $isFast ? 0.01001 : 0.0572;
+ {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ reference_images: IO.Autogrow.Type,
+ prompt: str,
+ model: dict,
+ upscale: dict,
+ remove_background: bool,
+ seed: int,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, min_length=1, max_length=2560)
+ if not reference_images:
+ raise ValueError("At least one reference image is required.")
+ ref_base64_list = []
+ for key in reference_images:
+ ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
+ if len(ref_base64_list) > 6:
+ raise ValueError("Maximum 6 reference images are allowed.")
+ tts = model["test_time_scaling"]
+ ar = model["aspect_ratio"]
+ response = await sync_op_raw(
+ cls,
+ ApiEndpoint(
+ path="/proxy/reve/v1/image/remix",
+ method="POST",
+ headers={"Accept": "image/webp"},
+ ),
+ as_binary=True,
+ price_extractor=_reve_price_extractor,
+ response_header_validator=_reve_response_header_validator,
+ data=ReveImageRemixRequest(
+ prompt=prompt,
+ reference_images=ref_base64_list,
+ aspect_ratio=ar if ar != "auto" else None,
+ version=model["model"],
+ test_time_scaling=tts if tts and tts > 1 else None,
+ postprocessing=_build_postprocessing(upscale, remove_background),
+ ),
+ )
+ return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
+
+
+class ReveExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ ReveImageCreateNode,
+ ReveImageEditNode,
+ ReveImageRemixNode,
+ ]
+
+
+async def comfy_entrypoint() -> ReveExtension:
+ return ReveExtension()
diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py
index 79ffb77c1..9d730b81a 100644
--- a/comfy_api_nodes/util/client.py
+++ b/comfy_api_nodes/util/client.py
@@ -67,6 +67,7 @@ class _RequestConfig:
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
is_rate_limited: Callable[[int, Any], bool] | None = None
+ response_header_validator: Callable[[dict[str, str]], None] | None = None
@dataclass
@@ -202,11 +203,13 @@ async def sync_op_raw(
monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
+ response_header_validator: Callable[[dict[str, str]], None] | None = None,
) -> dict[str, Any] | bytes:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON).
- If as_binary=True: returns bytes.
+    - response_header_validator: optional callback that receives the response headers and may raise to reject the response.
"""
if isinstance(data, BaseModel):
data = data.model_dump(exclude_none=True)
@@ -232,6 +235,7 @@ async def sync_op_raw(
price_extractor=price_extractor,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
+ response_header_validator=response_header_validator,
)
return await _request_base(cfg, expect_binary=as_binary)
@@ -769,6 +773,12 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
)
bytes_payload = bytes(buff)
+ resp_headers = {k.lower(): v for k, v in resp.headers.items()}
+ if cfg.price_extractor:
+ with contextlib.suppress(Exception):
+ extracted_price = cfg.price_extractor(resp_headers)
+ if cfg.response_header_validator:
+ cfg.response_header_validator(resp_headers)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
@@ -776,7 +786,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_method=method,
request_url=url,
response_status_code=resp.status,
- response_headers=dict(resp.headers),
+ response_headers=resp_headers,
response_content=bytes_payload,
)
return bytes_payload
diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py
new file mode 100644
index 000000000..d455d08e8
--- /dev/null
+++ b/comfy_execution/cache_provider.py
@@ -0,0 +1,138 @@
+from typing import Any, Optional, Tuple, List
+import hashlib
+import json
+import logging
+import threading
+
+# Public types — source of truth is comfy_api.latest._caching
+from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)
+
+_logger = logging.getLogger(__name__)
+
+
+_providers: List[CacheProvider] = []
+_providers_lock = threading.Lock()
+_providers_snapshot: Tuple[CacheProvider, ...] = ()
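+# reads go through the immutable snapshot so provider lookups never take the lock; the lock only guards (un)registration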
+
+
+def register_cache_provider(provider: CacheProvider) -> None:
+ """Register an external cache provider. Providers are called in registration order."""
+ global _providers_snapshot
+ with _providers_lock:
+ if provider in _providers:
+ _logger.warning(f"Provider {provider.__class__.__name__} already registered")
+ return
+ _providers.append(provider)
+ _providers_snapshot = tuple(_providers)
+ _logger.debug(f"Registered cache provider: {provider.__class__.__name__}")
+
+
+def unregister_cache_provider(provider: CacheProvider) -> None:
+ global _providers_snapshot
+ with _providers_lock:
+ try:
+ _providers.remove(provider)
+ _providers_snapshot = tuple(_providers)
+ _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
+ except ValueError:
+ _logger.warning(f"Provider {provider.__class__.__name__} was not registered")
+
+
+def _get_cache_providers() -> Tuple[CacheProvider, ...]:
+ return _providers_snapshot
+
+
+def _has_cache_providers() -> bool:
+ return bool(_providers_snapshot)
+
+
+def _clear_cache_providers() -> None:
+ global _providers_snapshot
+ with _providers_lock:
+ _providers.clear()
+ _providers_snapshot = ()
+
+
+def _canonicalize(obj: Any) -> Any:
+ # Convert to canonical JSON-serializable form with deterministic ordering.
+ # Frozensets have non-deterministic iteration order between Python sessions.
+ # Raises ValueError for non-cacheable types (Unhashable, unknown) so that
+ # _serialize_cache_key returns None and external caching is skipped.
+ if isinstance(obj, frozenset):
+ return ("__frozenset__", sorted(
+ [_canonicalize(item) for item in obj],
+ key=lambda x: json.dumps(x, sort_keys=True)
+ ))
+ elif isinstance(obj, set):
+ return ("__set__", sorted(
+ [_canonicalize(item) for item in obj],
+ key=lambda x: json.dumps(x, sort_keys=True)
+ ))
+ elif isinstance(obj, tuple):
+ return ("__tuple__", [_canonicalize(item) for item in obj])
+ elif isinstance(obj, list):
+ return [_canonicalize(item) for item in obj]
+ elif isinstance(obj, dict):
+ return {"__dict__": sorted(
+ [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
+ key=lambda x: json.dumps(x, sort_keys=True)
+ )}
+ elif isinstance(obj, (int, float, str, bool, type(None))):
+ return (type(obj).__name__, obj)
+ elif isinstance(obj, bytes):
+ return ("__bytes__", obj.hex())
+ else:
+ raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
+
+
+def _serialize_cache_key(cache_key: Any) -> Optional[str]:
+ # Returns deterministic SHA256 hex digest, or None on failure.
+ # Uses JSON (not pickle) because pickle is non-deterministic across sessions.
+ try:
+ canonical = _canonicalize(cache_key)
+ json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
+ return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
+ except Exception as e:
+ _logger.warning(f"Failed to serialize cache key: {e}")
+ return None
+
+
+def _contains_self_unequal(obj: Any) -> bool:
+ # Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
+ # never hit locally, but serialized form would match externally. Skip these.
+ try:
+ if not (obj == obj):
+ return True
+ except Exception:
+ return True
+ if isinstance(obj, (frozenset, tuple, list, set)):
+ return any(_contains_self_unequal(item) for item in obj)
+ if isinstance(obj, dict):
+ return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
+ if hasattr(obj, 'value'):
+ return _contains_self_unequal(obj.value)
+ return False
+
+
+def _estimate_value_size(value: CacheValue) -> int:
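+    # rough byte count of all tensors reachable from the cached outputs; 0 if torch is unavailable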
+ try:
+ import torch
+ except ImportError:
+ return 0
+
+ total = 0
+
+ def estimate(obj):
+ nonlocal total
+ if isinstance(obj, torch.Tensor):
+ total += obj.numel() * obj.element_size()
+ elif isinstance(obj, dict):
+ for v in obj.values():
+ estimate(v)
+ elif isinstance(obj, (list, tuple)):
+ for item in obj:
+ estimate(item)
+
+ for output in value.outputs:
+ estimate(output)
+ return total
diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py
index 326a279fc..78212bde3 100644
--- a/comfy_execution/caching.py
+++ b/comfy_execution/caching.py
@@ -1,3 +1,4 @@
+import asyncio
import bisect
import gc
import itertools
@@ -147,13 +148,15 @@ class CacheKeySetInputSignature(CacheKeySet):
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
class BasicCache:
- def __init__(self, key_class):
+ def __init__(self, key_class, enable_providers=False):
self.key_class = key_class
self.initialized = False
+ self.enable_providers = enable_providers
self.dynprompt: DynamicPrompt
self.cache_key_set: CacheKeySet
self.cache = {}
self.subcaches = {}
+ self._pending_store_tasks: set = set()
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt
@@ -196,18 +199,138 @@ class BasicCache:
def poll(self, **kwargs):
pass
- def _set_immediate(self, node_id, value):
- assert self.initialized
- cache_key = self.cache_key_set.get_data_key(node_id)
- self.cache[cache_key] = value
-
- def _get_immediate(self, node_id):
+ def get_local(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
- else:
+ return None
+
+ def set_local(self, node_id, value):
+ assert self.initialized
+ cache_key = self.cache_key_set.get_data_key(node_id)
+ self.cache[cache_key] = value
+
+ async def _set_immediate(self, node_id, value):
+ assert self.initialized
+ cache_key = self.cache_key_set.get_data_key(node_id)
+ self.cache[cache_key] = value
+
+ await self._notify_providers_store(node_id, cache_key, value)
+
+ async def _get_immediate(self, node_id):
+ if not self.initialized:
+ return None
+ cache_key = self.cache_key_set.get_data_key(node_id)
+
+ if cache_key in self.cache:
+ return self.cache[cache_key]
+
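+        # local miss: consult registered external cache providers and adopt any hit into the local cache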
+ external_result = await self._check_providers_lookup(node_id, cache_key)
+ if external_result is not None:
+ self.cache[cache_key] = external_result
+ return external_result
+
+ return None
+
+ async def _notify_providers_store(self, node_id, cache_key, value):
+ from comfy_execution.cache_provider import (
+ _has_cache_providers, _get_cache_providers,
+ CacheValue, _contains_self_unequal, _logger
+ )
+
+ if not self.enable_providers:
+ return
+ if not _has_cache_providers():
+ return
+ if not self._is_external_cacheable_value(value):
+ return
+ if _contains_self_unequal(cache_key):
+ return
+
+ context = self._build_context(node_id, cache_key)
+ if context is None:
+ return
+ cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
+
+ for provider in _get_cache_providers():
+ try:
+ if provider.should_cache(context, cache_value):
+ task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
+ self._pending_store_tasks.add(task)
+ task.add_done_callback(self._pending_store_tasks.discard)
+ except Exception as e:
+ _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
+
+ @staticmethod
+ async def _safe_provider_store(provider, context, cache_value):
+ from comfy_execution.cache_provider import _logger
+ try:
+ await provider.on_store(context, cache_value)
+ except Exception as e:
+ _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
+
+ async def _check_providers_lookup(self, node_id, cache_key):
+ from comfy_execution.cache_provider import (
+ _has_cache_providers, _get_cache_providers,
+ CacheValue, _contains_self_unequal, _logger
+ )
+
+ if not self.enable_providers:
+ return None
+ if not _has_cache_providers():
+ return None
+ if _contains_self_unequal(cache_key):
+ return None
+
+ context = self._build_context(node_id, cache_key)
+ if context is None:
+ return None
+
+ for provider in _get_cache_providers():
+ try:
+ if not provider.should_cache(context):
+ continue
+ result = await provider.on_lookup(context)
+ if result is not None:
+ if not isinstance(result, CacheValue):
+ _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
+ continue
+ if not isinstance(result.outputs, (list, tuple)):
+ _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
+ continue
+ from execution import CacheEntry
+ return CacheEntry(ui=result.ui, outputs=list(result.outputs))
+ except Exception as e:
+ _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
+
+ return None
+
+ def _is_external_cacheable_value(self, value):
+ return hasattr(value, 'outputs') and hasattr(value, 'ui')
+
+ def _get_class_type(self, node_id):
+ if not self.initialized or not self.dynprompt:
+ return ''
+ try:
+ return self.dynprompt.get_node(node_id).get('class_type', '')
+ except Exception:
+ return ''
+
+ def _build_context(self, node_id, cache_key):
+ from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
+ try:
+ cache_key_hash = _serialize_cache_key(cache_key)
+ if cache_key_hash is None:
+ return None
+ return CacheContext(
+ node_id=node_id,
+ class_type=self._get_class_type(node_id),
+ cache_key_hash=cache_key_hash,
+ )
+ except Exception as e:
+ _logger.warning(f"Failed to build cache context for node {node_id}: {e}")
return None
async def _ensure_subcache(self, node_id, children_ids):
@@ -236,8 +359,8 @@ class BasicCache:
return result
class HierarchicalCache(BasicCache):
- def __init__(self, key_class):
- super().__init__(key_class)
+ def __init__(self, key_class, enable_providers=False):
+ super().__init__(key_class, enable_providers=enable_providers)
def _get_cache_for(self, node_id):
assert self.dynprompt is not None
@@ -257,16 +380,27 @@ class HierarchicalCache(BasicCache):
return None
return cache
- def get(self, node_id):
+ async def get(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
- return cache._get_immediate(node_id)
+ return await cache._get_immediate(node_id)
- def set(self, node_id, value):
+ def get_local(self, node_id):
+ cache = self._get_cache_for(node_id)
+ if cache is None:
+ return None
+ return BasicCache.get_local(cache, node_id)
+
+ async def set(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
- cache._set_immediate(node_id, value)
+ await cache._set_immediate(node_id, value)
+
+ def set_local(self, node_id, value):
+ cache = self._get_cache_for(node_id)
+ assert cache is not None
+ BasicCache.set_local(cache, node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id)
@@ -287,18 +421,24 @@ class NullCache:
def poll(self, **kwargs):
pass
- def get(self, node_id):
+ async def get(self, node_id):
return None
- def set(self, node_id, value):
+ def get_local(self, node_id):
+ return None
+
+ async def set(self, node_id, value):
+ pass
+
+ def set_local(self, node_id, value):
pass
async def ensure_subcache_for(self, node_id, children_ids):
return self
class LRUCache(BasicCache):
- def __init__(self, key_class, max_size=100):
- super().__init__(key_class)
+ def __init__(self, key_class, max_size=100, enable_providers=False):
+ super().__init__(key_class, enable_providers=enable_providers)
self.max_size = max_size
self.min_generation = 0
self.generation = 0
@@ -322,18 +462,18 @@ class LRUCache(BasicCache):
del self.children[key]
self._clean_subcaches()
- def get(self, node_id):
+ async def get(self, node_id):
self._mark_used(node_id)
- return self._get_immediate(node_id)
+ return await self._get_immediate(node_id)
def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None:
self.used_generation[cache_key] = self.generation
- def set(self, node_id, value):
+ async def set(self, node_id, value):
self._mark_used(node_id)
- return self._set_immediate(node_id, value)
+ return await self._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
@@ -366,20 +506,20 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache):
- def __init__(self, key_class):
- super().__init__(key_class, 0)
+ def __init__(self, key_class, enable_providers=False):
+ super().__init__(key_class, 0, enable_providers=enable_providers)
self.timestamps = {}
def clean_unused(self):
self._clean_subcaches()
- def set(self, node_id, value):
+ async def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
- super().set(node_id, value)
+ await super().set(node_id, value)
- def get(self, node_id):
+ async def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
- return super().get(node_id)
+ return await super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index 9d170b16e..c47f3c79b 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort):
self.execution_cache_listeners = {}
def is_cached(self, node_id):
- return self.output_cache.get(node_id) is not None
+ return self.output_cache.get_local(node_id) is not None
def cache_link(self, from_node_id, to_node_id):
if to_node_id not in self.execution_cache:
self.execution_cache[to_node_id] = {}
- self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
+ self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id)
if from_node_id not in self.execution_cache_listeners:
self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id)
@@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort):
if value is None:
return None
#Write back to the main cache on touch.
- self.output_cache.set(from_node_id, value)
+ self.output_cache.set_local(from_node_id, value)
return value
def cache_update(self, node_id, value):
diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py
index fe9552022..3a23c7d04 100644
--- a/comfy_extras/nodes_flux.py
+++ b/comfy_extras/nodes_flux.py
@@ -6,6 +6,7 @@ import comfy.model_management
import torch
import math
import nodes
+import comfy.ldm.flux.math
class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod
@@ -231,6 +232,68 @@ class Flux2Scheduler(io.ComfyNode):
sigmas = get_schedule(steps, round(seq_len))
return io.NodeOutput(sigmas)
+class KV_Attn_Input:
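+    """attn1 patch that caches the K/V slices of the reference-image tokens on the first pass per block
+    and concatenates them back on later passes, so the reference tokens do not need to be recomputed."""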
+ def __init__(self):
+ self.cache = {}
+
+ def __call__(self, q, k, v, extra_options, **kwargs):
+ reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
+ if len(reference_image_num_tokens) == 0:
+ return {}
+
+ ref_toks = sum(reference_image_num_tokens)
+ cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
+ if cache_key in self.cache:
+ kk, vv = self.cache[cache_key]
+ self.set_cache = False
+ return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
+
+ self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone())
+ self.set_cache = True
+ return {"q": q, "k": k, "v": v}
+
+ def cleanup(self):
+ self.cache = {}
+
+
+class FluxKVCache(io.ComfyNode):
+ @classmethod
+ def define_schema(cls) -> io.Schema:
+ return io.Schema(
+ node_id="FluxKVCache",
+ display_name="Flux KV Cache",
+ description="Enables KV Cache optimization for reference images on Flux family models.",
+ category="",
+ is_experimental=True,
+ inputs=[
+ io.Model.Input("model", tooltip="The model to use KV Cache on."),
+ ],
+ outputs=[
+ io.Model.Output(tooltip="The patched model with KV Cache enabled."),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, model: io.Model.Type) -> io.NodeOutput:
+ m = model.clone()
+ input_patch_obj = KV_Attn_Input()
+
+ def model_input_patch(inputs):
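+ # Once the reference K/V has been cached, strip the reference-image tokens from the model input so later steps only process the latent being denoised.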
+ if len(input_patch_obj.cache) > 0:
+ ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
+ if ref_image_tokens > 0:
+ img = inputs["img"]
+ inputs["img"] = img[:, :-ref_image_tokens]
+ return inputs
+
+ m.set_model_attn1_patch(input_patch_obj)
+ m.set_model_post_input_patch(model_input_patch)
+ if hasattr(model.model.diffusion_model, "params"):
+ m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
+ else:
+ m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
+
+ return io.NodeOutput(m)
class FluxExtension(ComfyExtension):
@override
@@ -243,6 +306,7 @@ class FluxExtension(ComfyExtension):
FluxKontextMultiReferenceLatentMethod,
EmptyFlux2LatentImage,
Flux2Scheduler,
+ FluxKVCache,
]
diff --git a/comfy_extras/nodes_painter.py b/comfy_extras/nodes_painter.py
new file mode 100644
index 000000000..b9ecdf5ea
--- /dev/null
+++ b/comfy_extras/nodes_painter.py
@@ -0,0 +1,127 @@
+from __future__ import annotations
+
+import hashlib
+import os
+
+import numpy as np
+import torch
+from PIL import Image
+
+import folder_paths
+import node_helpers
+from comfy_api.latest import ComfyExtension, io, UI
+from typing_extensions import override
+
+
+def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
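+ """Convert a '#RRGGBB' hex string to normalized (r, g, b) floats in [0, 1]; strings that are not six characters long fall back to black."""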
+ hex_color = hex_color.lstrip("#")
+ if len(hex_color) != 6:
+ return (0.0, 0.0, 0.0)
+ r = int(hex_color[0:2], 16) / 255.0
+ g = int(hex_color[2:4], 16) / 255.0
+ b = int(hex_color[4:6], 16) / 255.0
+ return (r, g, b)
+
+
+class PainterNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="Painter",
+ display_name="Painter",
+ category="image",
+ inputs=[
+ io.Image.Input(
+ "image",
+ optional=True,
+ tooltip="Optional base image to paint over",
+ ),
+ io.String.Input(
+ "mask",
+ default="",
+ socketless=True,
+ extra_dict={"widgetType": "PAINTER", "image_upload": True},
+ ),
+ io.Int.Input(
+ "width",
+ default=512,
+ min=64,
+ max=4096,
+ step=64,
+ socketless=True,
+ extra_dict={"hidden": True},
+ ),
+ io.Int.Input(
+ "height",
+ default=512,
+ min=64,
+ max=4096,
+ step=64,
+ socketless=True,
+ extra_dict={"hidden": True},
+ ),
+ io.Color.Input("bg_color", default="#000000"),
+ ],
+ outputs=[
+ io.Image.Output("IMAGE"),
+ io.Mask.Output("MASK"),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
+ if image is not None:
+ base_image = image[:1]
+ h, w = base_image.shape[1], base_image.shape[2]
+ else:
+ h, w = height, width
+ r, g, b = hex_to_rgb(bg_color)
+ base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
+ base_image[0, :, :, 0] = r
+ base_image[0, :, :, 1] = g
+ base_image[0, :, :, 2] = b
+
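+ # Composite the painted RGBA layer over the base image using its alpha channel; the same alpha channel is returned as the MASK output.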
+ if mask and mask.strip():
+ mask_path = folder_paths.get_annotated_filepath(mask)
+ painter_img = node_helpers.pillow(Image.open, mask_path)
+ painter_img = painter_img.convert("RGBA")
+
+ if painter_img.size != (w, h):
+ painter_img = painter_img.resize((w, h), Image.LANCZOS)
+
+ painter_np = np.array(painter_img).astype(np.float32) / 255.0
+ painter_rgb = painter_np[:, :, :3]
+ painter_alpha = painter_np[:, :, 3:4]
+
+ mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0)
+
+ base_np = base_image[0].cpu().numpy()
+ composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha)
+ out_image = torch.from_numpy(composited).unsqueeze(0)
+ else:
+ mask_tensor = torch.zeros((1, h, w), dtype=torch.float32)
+ out_image = base_image
+
+ return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))
+
+ @classmethod
+ def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
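+ # Hash the painted mask file so the node re-executes whenever the drawing changes.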
+ if mask and mask.strip():
+ mask_path = folder_paths.get_annotated_filepath(mask)
+ if os.path.exists(mask_path):
+ m = hashlib.sha256()
+ with open(mask_path, "rb") as f:
+ m.update(f.read())
+ return m.digest().hex()
+ return ""
+
+
+class PainterExtension(ComfyExtension):
+ @override
+ async def get_node_list(self):
+ return [PainterNode]
+
+
+async def comfy_entrypoint():
+ return PainterExtension()
diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py
index 97b9e948d..db4f9d231 100644
--- a/comfy_extras/nodes_upscale_model.py
+++ b/comfy_extras/nodes_upscale_model.py
@@ -86,7 +86,8 @@ class ImageUpscaleWithModel(io.ComfyNode):
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
- except model_management.OOM_EXCEPTION as e:
+ except Exception as e:
+ model_management.raise_non_oom(e)
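+ # raise_non_oom re-raises anything that isn't an out-of-memory error; an OOM falls through so the tile size is halved and retried.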
tile //= 2
if tile < 128:
raise e
diff --git a/comfyui_version.py b/comfyui_version.py
index 2723d02e7..701f4d66a 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.16.4"
+__version__ = "0.17.0"
diff --git a/execution.py b/execution.py
index 7ccdbf93e..1a6c3429c 100644
--- a/execution.py
+++ b/execution.py
@@ -40,6 +40,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io
+from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
class ExecutionResult(Enum):
@@ -126,15 +127,15 @@ class CacheSet:
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
- self.outputs = HierarchicalCache(CacheKeySetInputSignature)
+ self.outputs = HierarchicalCache(CacheKeySetInputSignature, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size):
- self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+ self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom):
- self.outputs = RAMPressureCache(CacheKeySetInputSignature)
+ self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self):
@@ -418,7 +419,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
- cached = caches.outputs.get(unique_id)
+ cached = await caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None:
cached_ui = cached.ui or {}
@@ -474,10 +475,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
- obj = caches.objects.get(unique_id)
+ obj = await caches.objects.get(unique_id)
if obj is None:
obj = class_def()
- caches.objects.set(unique_id, obj)
+ await caches.objects.set(unique_id, obj)
if issubclass(class_def, _ComfyNodeInternal):
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
@@ -588,7 +589,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, cache_entry)
- caches.outputs.set(unique_id, cache_entry)
+ await caches.outputs.set(unique_id, cache_entry)
except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted")
@@ -612,7 +613,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
logging.error(traceback.format_exc())
tips = ""
- if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
+ if comfy.model_management.is_oom(ex):
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
logging.error("Got an OOM, unloading all loaded models.")
@@ -684,6 +685,19 @@ class PromptExecutor:
}
self.add_message("execution_error", mes, broadcast=False)
+ def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
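+ # Fan out prompt start/end events to the registered cache providers; provider errors are logged as warnings and never interrupt execution.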
+ if not _has_cache_providers():
+ return
+
+ for provider in _get_cache_providers():
+ try:
+ if event == "start":
+ provider.on_prompt_start(prompt_id)
+ elif event == "end":
+ provider.on_prompt_end(prompt_id)
+ except Exception as e:
+ _cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
+
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
@@ -700,66 +714,75 @@ class PromptExecutor:
self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
- with torch.inference_mode():
- dynamic_prompt = DynamicPrompt(prompt)
- reset_progress_state(prompt_id, dynamic_prompt)
- add_progress_handler(WebUIProgressHandler(self.server))
- is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
- for cache in self.caches.all:
- await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
- cache.clean_unused()
+ self._notify_prompt_lifecycle("start", prompt_id)
- cached_nodes = []
- for node_id in prompt:
- if self.caches.outputs.get(node_id) is not None:
- cached_nodes.append(node_id)
+ try:
+ with torch.inference_mode():
+ dynamic_prompt = DynamicPrompt(prompt)
+ reset_progress_state(prompt_id, dynamic_prompt)
+ add_progress_handler(WebUIProgressHandler(self.server))
+ is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
+ for cache in self.caches.all:
+ await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
+ cache.clean_unused()
- comfy.model_management.cleanup_models_gc()
- self.add_message("execution_cached",
- { "nodes": cached_nodes, "prompt_id": prompt_id},
- broadcast=False)
- pending_subgraph_results = {}
- pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
- ui_node_outputs = {}
- executed = set()
- execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
- current_outputs = self.caches.outputs.all_node_ids()
- for node_id in list(execute_outputs):
- execution_list.add_node(node_id)
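+ # Probe the output cache for every node up front; the lookups are awaited via asyncio.gather so asynchronous cache providers can be queried concurrently.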
+ node_ids = list(prompt.keys())
+ cache_results = await asyncio.gather(
+ *(self.caches.outputs.get(node_id) for node_id in node_ids)
+ )
+ cached_nodes = [
+ node_id for node_id, result in zip(node_ids, cache_results)
+ if result is not None
+ ]
- while not execution_list.is_empty():
- node_id, error, ex = await execution_list.stage_node_execution()
- if error is not None:
- self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
- break
+ comfy.model_management.cleanup_models_gc()
+ self.add_message("execution_cached",
+ { "nodes": cached_nodes, "prompt_id": prompt_id},
+ broadcast=False)
+ pending_subgraph_results = {}
+ pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
+ ui_node_outputs = {}
+ executed = set()
+ execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
+ current_outputs = self.caches.outputs.all_node_ids()
+ for node_id in list(execute_outputs):
+ execution_list.add_node(node_id)
- assert node_id is not None, "Node ID should not be None at this point"
- result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
- self.success = result != ExecutionResult.FAILURE
- if result == ExecutionResult.FAILURE:
- self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
- break
- elif result == ExecutionResult.PENDING:
- execution_list.unstage_node_execution()
- else: # result == ExecutionResult.SUCCESS:
- execution_list.complete_node_execution()
- self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
- else:
- # Only execute when the while-loop ends without break
- self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+ while not execution_list.is_empty():
+ node_id, error, ex = await execution_list.stage_node_execution()
+ if error is not None:
+ self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+ break
- ui_outputs = {}
- meta_outputs = {}
- for node_id, ui_info in ui_node_outputs.items():
- ui_outputs[node_id] = ui_info["output"]
- meta_outputs[node_id] = ui_info["meta"]
- self.history_result = {
- "outputs": ui_outputs,
- "meta": meta_outputs,
- }
- self.server.last_node_id = None
- if comfy.model_management.DISABLE_SMART_MEMORY:
- comfy.model_management.unload_all_models()
+ assert node_id is not None, "Node ID should not be None at this point"
+ result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
+ self.success = result != ExecutionResult.FAILURE
+ if result == ExecutionResult.FAILURE:
+ self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+ break
+ elif result == ExecutionResult.PENDING:
+ execution_list.unstage_node_execution()
+ else: # result == ExecutionResult.SUCCESS:
+ execution_list.complete_node_execution()
+ self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
+ else:
+ # Only execute when the while-loop ends without break
+ self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+
+ ui_outputs = {}
+ meta_outputs = {}
+ for node_id, ui_info in ui_node_outputs.items():
+ ui_outputs[node_id] = ui_info["output"]
+ meta_outputs[node_id] = ui_info["meta"]
+ self.history_result = {
+ "outputs": ui_outputs,
+ "meta": meta_outputs,
+ }
+ self.server.last_node_id = None
+ if comfy.model_management.DISABLE_SMART_MEMORY:
+ comfy.model_management.unload_all_models()
+ finally:
+ self._notify_prompt_lifecycle("end", prompt_id)
async def validate_inputs(prompt_id, prompt, item, validated):
diff --git a/main.py b/main.py
index 1977f9362..8905fd09a 100644
--- a/main.py
+++ b/main.py
@@ -3,6 +3,7 @@ comfy.options.enable_args_parsing()
import os
import importlib.util
+import shutil
import importlib.metadata
import folder_paths
import time
@@ -11,6 +12,7 @@ from app.logger import setup_logger
import itertools
import utils.extra_config
from utils.mime_types import init_mime_types
+import faulthandler
import logging
import sys
from comfy_execution.progress import get_progress_state
@@ -25,6 +27,8 @@ if __name__ == "__main__":
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
+faulthandler.enable(file=sys.stderr, all_threads=False)
+
import comfy_aimdo.control
if enables_dynamic_vram():
@@ -64,8 +68,15 @@ if __name__ == "__main__":
def handle_comfyui_manager_unavailable():
- if not args.windows_standalone_build:
- logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
+ manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt")
+ uv_available = shutil.which("uv") is not None
+
+ pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}"
+ msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}"
+ if uv_available:
+ msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}"
+ msg += "\n"
+ logging.warning(msg)
args.enable_manager = False
@@ -173,7 +184,6 @@ execute_prestartup_script()
# Main code
import asyncio
-import shutil
import threading
import gc
diff --git a/manager_requirements.txt b/manager_requirements.txt
index c420cc48e..6bcc3fb50 100644
--- a/manager_requirements.txt
+++ b/manager_requirements.txt
@@ -1 +1 @@
-comfyui_manager==4.1b1
+comfyui_manager==4.1b2
\ No newline at end of file
diff --git a/nodes.py b/nodes.py
index 0ef23b640..eb63f9d44 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2450,6 +2450,7 @@ async def init_builtin_extra_nodes():
"nodes_nag.py",
"nodes_sdpose.py",
"nodes_math.py",
+ "nodes_painter.py",
]
import_failed = []
diff --git a/pyproject.toml b/pyproject.toml
index 753b219b3..e2ca79be7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.16.4"
+version = "0.17.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
diff --git a/requirements.txt b/requirements.txt
index b1db1cf24..511c62fee 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-comfyui-frontend-package==1.39.19
-comfyui-workflow-templates==0.9.11
+comfyui-frontend-package==1.41.18
+comfyui-workflow-templates==0.9.21
comfyui-embedded-docs==0.4.3
torch
torchsde
@@ -22,8 +22,8 @@ alembic
SQLAlchemy
filelock
av>=14.2.0
-comfy-kitchen>=0.2.7
-comfy-aimdo>=0.2.9
+comfy-kitchen>=0.2.8
+comfy-aimdo>=0.2.10
requests
simpleeval>=1.0.0
blake3
diff --git a/tests-unit/execution_test/test_cache_provider.py b/tests-unit/execution_test/test_cache_provider.py
new file mode 100644
index 000000000..ac3814746
--- /dev/null
+++ b/tests-unit/execution_test/test_cache_provider.py
@@ -0,0 +1,403 @@
+"""Tests for external cache provider API."""
+
+import importlib.util
+import pytest
+from typing import Optional
+
+from comfy_execution.cache_provider import (
+ CacheProvider,
+ CacheContext,
+ CacheValue,
+ register_cache_provider,
+ unregister_cache_provider,
+ _get_cache_providers,
+ _has_cache_providers,
+ _clear_cache_providers,
+ _serialize_cache_key,
+ _contains_self_unequal,
+ _estimate_value_size,
+ _canonicalize,
+)
+
+
+def _torch_available() -> bool:
+ """Check if PyTorch is available."""
+ return importlib.util.find_spec("torch") is not None
+
+
+class TestCanonicalize:
+ """Test _canonicalize function for deterministic ordering."""
+
+ def test_frozenset_ordering_is_deterministic(self):
+ """Frozensets should produce consistent canonical form regardless of iteration order."""
+ # Create two frozensets with same content
+ fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)])
+ fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)])
+
+ result1 = _canonicalize(fs1)
+ result2 = _canonicalize(fs2)
+
+ assert result1 == result2
+
+ def test_nested_frozenset_ordering(self):
+ """Nested frozensets should also be deterministically ordered."""
+ inner1 = frozenset([1, 2, 3])
+ inner2 = frozenset([3, 2, 1])
+
+ fs1 = frozenset([("key", inner1)])
+ fs2 = frozenset([("key", inner2)])
+
+ result1 = _canonicalize(fs1)
+ result2 = _canonicalize(fs2)
+
+ assert result1 == result2
+
+ def test_dict_ordering(self):
+ """Dicts should be sorted by key."""
+ d1 = {"z": 1, "a": 2, "m": 3}
+ d2 = {"a": 2, "m": 3, "z": 1}
+
+ result1 = _canonicalize(d1)
+ result2 = _canonicalize(d2)
+
+ assert result1 == result2
+
+ def test_tuple_preserved(self):
+ """Tuples should be marked and preserved."""
+ t = (1, 2, 3)
+ result = _canonicalize(t)
+
+ assert result[0] == "__tuple__"
+
+ def test_list_preserved(self):
+ """Lists should be recursively canonicalized."""
+ lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
+ result = _canonicalize(lst)
+
+ # First element should be canonicalized dict
+ assert "__dict__" in result[0]
+ # Second element should be canonicalized frozenset
+ assert result[1][0] == "__frozenset__"
+
+ def test_primitives_include_type(self):
+ """Primitive types should include type name for disambiguation."""
+ assert _canonicalize(42) == ("int", 42)
+ assert _canonicalize(3.14) == ("float", 3.14)
+ assert _canonicalize("hello") == ("str", "hello")
+ assert _canonicalize(True) == ("bool", True)
+ assert _canonicalize(None) == ("NoneType", None)
+
+ def test_int_and_str_distinguished(self):
+ """int 7 and str '7' must produce different canonical forms."""
+ assert _canonicalize(7) != _canonicalize("7")
+
+ def test_bytes_converted(self):
+ """Bytes should be converted to hex string."""
+ b = b"\x00\xff"
+ result = _canonicalize(b)
+
+ assert result[0] == "__bytes__"
+ assert result[1] == "00ff"
+
+ def test_set_ordering(self):
+ """Sets should be sorted like frozensets."""
+ s1 = {3, 1, 2}
+ s2 = {1, 2, 3}
+
+ result1 = _canonicalize(s1)
+ result2 = _canonicalize(s2)
+
+ assert result1 == result2
+ assert result1[0] == "__set__"
+
+ def test_unknown_type_raises(self):
+ """Unknown types should raise ValueError (fail-closed)."""
+ class CustomObj:
+ pass
+ with pytest.raises(ValueError):
+ _canonicalize(CustomObj())
+
+ def test_object_with_value_attr_raises(self):
+ """Objects with .value attribute (Unhashable-like) should raise ValueError."""
+ class FakeUnhashable:
+ def __init__(self):
+ self.value = float('nan')
+ with pytest.raises(ValueError):
+ _canonicalize(FakeUnhashable())
+
+
+class TestSerializeCacheKey:
+ """Test _serialize_cache_key for deterministic hashing."""
+
+ def test_same_content_same_hash(self):
+ """Same content should produce same hash."""
+ key1 = frozenset([("node_1", frozenset([("input", "value")]))])
+ key2 = frozenset([("node_1", frozenset([("input", "value")]))])
+
+ hash1 = _serialize_cache_key(key1)
+ hash2 = _serialize_cache_key(key2)
+
+ assert hash1 == hash2
+
+ def test_different_content_different_hash(self):
+ """Different content should produce different hash."""
+ key1 = frozenset([("node_1", "value_a")])
+ key2 = frozenset([("node_1", "value_b")])
+
+ hash1 = _serialize_cache_key(key1)
+ hash2 = _serialize_cache_key(key2)
+
+ assert hash1 != hash2
+
+ def test_returns_hex_string(self):
+ """Should return hex string (SHA256 hex digest)."""
+ key = frozenset([("test", 123)])
+ result = _serialize_cache_key(key)
+
+ assert isinstance(result, str)
+ assert len(result) == 64 # SHA256 hex digest is 64 chars
+
+ def test_complex_nested_structure(self):
+ """Complex nested structures should hash deterministically."""
+ # Note: frozensets can only contain hashable types, so we use
+ # nested frozensets of tuples to represent dict-like structures
+ key = frozenset([
+ ("node_1", frozenset([
+ ("input_a", ("tuple", "value")),
+ ("input_b", frozenset([("nested", "dict")])),
+ ])),
+ ("node_2", frozenset([
+ ("param", 42),
+ ])),
+ ])
+
+ # Hash twice to verify determinism
+ hash1 = _serialize_cache_key(key)
+ hash2 = _serialize_cache_key(key)
+
+ assert hash1 == hash2
+
+ def test_dict_in_cache_key(self):
+ """Dicts passed directly to _serialize_cache_key should work."""
+ key = {"node_1": {"input": "value"}, "node_2": 42}
+
+ hash1 = _serialize_cache_key(key)
+ hash2 = _serialize_cache_key(key)
+
+ assert hash1 == hash2
+ assert isinstance(hash1, str)
+ assert len(hash1) == 64
+
+ def test_unknown_type_returns_none(self):
+ """Non-cacheable types should return None (fail-closed)."""
+ class CustomObj:
+ pass
+ assert _serialize_cache_key(CustomObj()) is None
+
+
+class TestContainsSelfUnequal:
+ """Test _contains_self_unequal utility function."""
+
+ def test_nan_float_detected(self):
+ """NaN floats should be detected (not equal to itself)."""
+ assert _contains_self_unequal(float('nan')) is True
+
+ def test_regular_float_not_detected(self):
+ """Regular floats are equal to themselves."""
+ assert _contains_self_unequal(3.14) is False
+ assert _contains_self_unequal(0.0) is False
+ assert _contains_self_unequal(-1.5) is False
+
+ def test_infinity_not_detected(self):
+ """Infinity is equal to itself."""
+ assert _contains_self_unequal(float('inf')) is False
+ assert _contains_self_unequal(float('-inf')) is False
+
+ def test_nan_in_list(self):
+ """NaN in list should be detected."""
+ assert _contains_self_unequal([1, 2, float('nan'), 4]) is True
+ assert _contains_self_unequal([1, 2, 3, 4]) is False
+
+ def test_nan_in_tuple(self):
+ """NaN in tuple should be detected."""
+ assert _contains_self_unequal((1, float('nan'))) is True
+ assert _contains_self_unequal((1, 2, 3)) is False
+
+ def test_nan_in_frozenset(self):
+ """NaN in frozenset should be detected."""
+ assert _contains_self_unequal(frozenset([1, float('nan')])) is True
+ assert _contains_self_unequal(frozenset([1, 2, 3])) is False
+
+ def test_nan_in_dict_value(self):
+ """NaN in dict value should be detected."""
+ assert _contains_self_unequal({"key": float('nan')}) is True
+ assert _contains_self_unequal({"key": 42}) is False
+
+ def test_nan_in_nested_structure(self):
+ """NaN in deeply nested structure should be detected."""
+ nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
+ assert _contains_self_unequal(nested) is True
+
+ def test_non_numeric_types(self):
+ """Non-numeric types should not be self-unequal."""
+ assert _contains_self_unequal("string") is False
+ assert _contains_self_unequal(None) is False
+ assert _contains_self_unequal(True) is False
+
+ def test_object_with_nan_value_attribute(self):
+ """Objects wrapping NaN in .value should be detected."""
+ class NanWrapper:
+ def __init__(self):
+ self.value = float('nan')
+ assert _contains_self_unequal(NanWrapper()) is True
+
+ def test_custom_self_unequal_object(self):
+ """Custom objects where not (x == x) should be detected."""
+ class NeverEqual:
+ def __eq__(self, other):
+ return False
+ assert _contains_self_unequal(NeverEqual()) is True
+
+
+class TestEstimateValueSize:
+ """Test _estimate_value_size utility function."""
+
+ def test_empty_outputs(self):
+ """Empty outputs should have zero size."""
+ value = CacheValue(outputs=[])
+ assert _estimate_value_size(value) == 0
+
+ @pytest.mark.skipif(
+ not _torch_available(),
+ reason="PyTorch not available"
+ )
+ def test_tensor_size_estimation(self):
+ """Tensor size should be estimated correctly."""
+ import torch
+
+ # 1000 float32 elements = 4000 bytes
+ tensor = torch.zeros(1000, dtype=torch.float32)
+ value = CacheValue(outputs=[[tensor]])
+
+ size = _estimate_value_size(value)
+ assert size == 4000
+
+ @pytest.mark.skipif(
+ not _torch_available(),
+ reason="PyTorch not available"
+ )
+ def test_nested_tensor_in_dict(self):
+ """Tensors nested in dicts should be counted."""
+ import torch
+
+ tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes
+ value = CacheValue(outputs=[[{"samples": tensor}]])
+
+ size = _estimate_value_size(value)
+ assert size == 400
+
+
+class TestProviderRegistry:
+ """Test cache provider registration and retrieval."""
+
+ def setup_method(self):
+ """Clear providers before each test."""
+ _clear_cache_providers()
+
+ def teardown_method(self):
+ """Clear providers after each test."""
+ _clear_cache_providers()
+
+ def test_register_provider(self):
+ """Provider should be registered successfully."""
+ provider = MockCacheProvider()
+ register_cache_provider(provider)
+
+ assert _has_cache_providers() is True
+ providers = _get_cache_providers()
+ assert len(providers) == 1
+ assert providers[0] is provider
+
+ def test_unregister_provider(self):
+ """Provider should be unregistered successfully."""
+ provider = MockCacheProvider()
+ register_cache_provider(provider)
+ unregister_cache_provider(provider)
+
+ assert _has_cache_providers() is False
+
+ def test_multiple_providers(self):
+ """Multiple providers can be registered."""
+ provider1 = MockCacheProvider()
+ provider2 = MockCacheProvider()
+
+ register_cache_provider(provider1)
+ register_cache_provider(provider2)
+
+ providers = _get_cache_providers()
+ assert len(providers) == 2
+
+ def test_duplicate_registration_ignored(self):
+ """Registering same provider twice should be ignored."""
+ provider = MockCacheProvider()
+
+ register_cache_provider(provider)
+ register_cache_provider(provider) # Should be ignored
+
+ providers = _get_cache_providers()
+ assert len(providers) == 1
+
+ def test_clear_providers(self):
+ """_clear_cache_providers should remove all providers."""
+ provider1 = MockCacheProvider()
+ provider2 = MockCacheProvider()
+
+ register_cache_provider(provider1)
+ register_cache_provider(provider2)
+ _clear_cache_providers()
+
+ assert _has_cache_providers() is False
+ assert len(_get_cache_providers()) == 0
+
+
+class TestCacheContext:
+ """Test CacheContext dataclass."""
+
+ def test_context_creation(self):
+ """CacheContext should be created with all fields."""
+ context = CacheContext(
+ node_id="node-456",
+ class_type="KSampler",
+ cache_key_hash="a" * 64,
+ )
+
+ assert context.node_id == "node-456"
+ assert context.class_type == "KSampler"
+ assert context.cache_key_hash == "a" * 64
+
+
+class TestCacheValue:
+ """Test CacheValue dataclass."""
+
+ def test_value_creation(self):
+ """CacheValue should be created with outputs."""
+ outputs = [[{"samples": "tensor_data"}]]
+ value = CacheValue(outputs=outputs)
+
+ assert value.outputs == outputs
+
+
+class MockCacheProvider(CacheProvider):
+ """Mock cache provider for testing."""
+
+ def __init__(self):
+ self.lookups = []
+ self.stores = []
+
+ async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
+ self.lookups.append(context)
+ return None
+
+ async def on_store(self, context: CacheContext, value: CacheValue) -> None:
+ self.stores.append((context, value))