diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 92b1acbd5..57126fa4a 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict): """COMBO type only. Specifies the configuration for a multi-select widget. Available after ComfyUI frontend v1.13.4 https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987""" - gradient_stops: NotRequired[list[list[float]]] - """Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``).""" + gradient_stops: NotRequired[list[dict]] + """Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}.""" class HiddenInputTypeDict(TypedDict): diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 8b3f500d7..e28d704b4 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): return tensor * m_mult else: for d in modulation_dims: - tensor[:, d[0]:d[1]] *= m_mult[:, d[2]] + tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1] if m_add is not None: - tensor[:, d[0]:d[1]] += m_add[:, d[2]] + tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1] return tensor @@ -223,12 +223,19 @@ class DoubleStreamBlock(nn.Module): del txt_k, img_k v = torch.cat((txt_v, img_v), dim=2) del txt_v, img_v + + extra_options["img_slice"] = [txt.shape[1], q.shape[2]] + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options) + q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask) + # run actual attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v if "attn1_output_patch" in transformer_patches: - extra_options["img_slice"] = [txt.shape[1], attn.shape[1]] patch = transformer_patches["attn1_output_patch"] for p in patch: attn = p(attn, extra_options) @@ -321,6 +328,12 @@ class SingleStreamBlock(nn.Module): del qkv q, k = self.norm(q, k, v) + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options) + q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask) + # compute attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 5e764bb46..824daf5e6 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -31,6 +31,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: def _apply_rope1(x: Tensor, freqs_cis: Tensor): x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]: + freqs_cis = freqs_cis[:, :, :x_.shape[2]] x_out = freqs_cis[..., 0] * x_[..., 0] x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index ef4dcf7c5..8e7912e6d 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -44,6 +44,22 @@ class FluxParams: txt_norm: bool = False +def invert_slices(slices, length): + sorted_slices = sorted(slices) + result = [] + current = 0 + + for start, end in sorted_slices: + if current < start: + result.append((current, start)) + current = max(current, end) + + if current < length: + result.append((current, length)) + + return result + + class Flux(nn.Module): """ Transformer model for flow matching on sequences. @@ -138,6 +154,7 @@ class Flux(nn.Module): y: Tensor, guidance: Tensor = None, control = None, + timestep_zero_index=None, transformer_options={}, attn_mask: Tensor = None, ) -> Tensor: @@ -164,13 +181,9 @@ class Flux(nn.Module): txt = self.txt_norm(txt) txt = self.txt_in(txt) - vec_orig = vec - if self.params.global_modulation: - vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig)) - if "post_input" in patches: for p in patches["post_input"]: - out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) + out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) img = out["img"] txt = out["txt"] img_ids = out["img_ids"] @@ -182,6 +195,24 @@ class Flux(nn.Module): else: pe = None + vec_orig = vec + txt_vec = vec + extra_kwargs = {} + if timestep_zero_index is not None: + modulation_dims = [] + batch = vec.shape[0] // 2 + vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1) + invert = invert_slices(timestep_zero_index, img.shape[1]) + for s in invert: + modulation_dims.append((s[0], s[1], 0)) + for s in timestep_zero_index: + modulation_dims.append((s[0], s[1], 1)) + extra_kwargs["modulation_dims_img"] = modulation_dims + txt_vec = vec[:batch] + + if self.params.global_modulation: + vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec)) + blocks_replace = patches_replace.get("dit", {}) transformer_options["total_blocks"] = len(self.double_blocks) transformer_options["block_type"] = "double" @@ -195,7 +226,8 @@ class Flux(nn.Module): vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask"), - transformer_options=args.get("transformer_options")) + transformer_options=args.get("transformer_options"), + **extra_kwargs) return out out = blocks_replace[("double_block", i)]({"img": img, @@ -213,7 +245,8 @@ class Flux(nn.Module): vec=vec, pe=pe, attn_mask=attn_mask, - transformer_options=transformer_options) + transformer_options=transformer_options, + **extra_kwargs) if control is not None: # Controlnet control_i = control.get("input") @@ -230,6 +263,12 @@ class Flux(nn.Module): if self.params.global_modulation: vec, _ = self.single_stream_modulation(vec_orig) + extra_kwargs = {} + if timestep_zero_index is not None: + lambda a: 0 if a == 0 else a + txt.shape[1] + modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims)) + extra_kwargs["modulation_dims"] = modulation_dims_combined + transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["block_type"] = "single" transformer_options["img_slice"] = [txt.shape[1], img.shape[1]] @@ -242,7 +281,8 @@ class Flux(nn.Module): vec=args["vec"], pe=args["pe"], attn_mask=args.get("attn_mask"), - transformer_options=args.get("transformer_options")) + transformer_options=args.get("transformer_options"), + **extra_kwargs) return out out = blocks_replace[("single_block", i)]({"img": img, @@ -253,7 +293,7 @@ class Flux(nn.Module): {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs) if control is not None: # Controlnet control_o = control.get("output") @@ -264,7 +304,11 @@ class Flux(nn.Module): img = img[:, txt.shape[1] :, ...] - img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels) + extra_kwargs = {} + if timestep_zero_index is not None: + extra_kwargs["modulation_dims"] = modulation_dims + + img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels) return img def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): @@ -312,13 +356,16 @@ class Flux(nn.Module): w_len = ((w_orig + (patch_size // 2)) // patch_size) img, img_ids = self.process_img(x, transformer_options=transformer_options) img_tokens = img.shape[1] + timestep_zero_index = None if ref_latents is not None: + ref_num_tokens = [] h = 0 w = 0 index = 0 ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method) + timestep_zero = ref_latents_method == "index_timestep_zero" for ref in ref_latents: - if ref_latents_method == "index": + if ref_latents_method in ("index", "index_timestep_zero"): index += self.params.ref_index_scale h_offset = 0 w_offset = 0 @@ -342,6 +389,13 @@ class Flux(nn.Module): kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + ref_num_tokens.append(kontext.shape[1]) + if timestep_zero: + if index > 0: + timestep = torch.cat([timestep, timestep * 0], dim=0) + timestep_zero_index = [[img_tokens, img_ids.shape[1]]] + transformer_options = transformer_options.copy() + transformer_options["reference_image_num_tokens"] = ref_num_tokens txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) @@ -349,6 +403,6 @@ class Flux(nn.Module): for i in self.params.txt_ids_dims: txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) - out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig] diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 10d051325..b193fe5e8 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 break - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) if first_op_done == False: model_management.soft_empty_cache(True) if cleared_cache == False: diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 805592aa5..fcbaa074f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -258,7 +258,8 @@ def slice_attention(q, k, v): r1[:, :, i:end] = torch.bmm(v, s2) del s2 break - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) model_management.soft_empty_cache(True) steps *= 2 if steps > 128: @@ -314,7 +315,8 @@ def pytorch_attention(q, k, v): try: out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") oom_fallback = True if oom_fallback: diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index fab145f1c..f982afc2b 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking( try: attn_probs = attn_scores.softmax(dim=-1) del attn_scores - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined torch.exp(attn_scores, out=attn_scores) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 6eb744286..0862f72f7 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -149,6 +149,9 @@ class Attention(nn.Module): seq_img = hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1] + transformer_patches = transformer_options.get("patches", {}) + extra_options = transformer_options.copy() + # Project and reshape to BHND format (batch, heads, seq, dim) img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() @@ -167,15 +170,22 @@ class Attention(nn.Module): joint_key = torch.cat([txt_key, img_key], dim=2) joint_value = torch.cat([txt_value, img_value], dim=2) - joint_query = apply_rope1(joint_query, image_rotary_emb) - joint_key = apply_rope1(joint_key, image_rotary_emb) - if encoder_hidden_states_mask is not None: attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device) attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask else: attn_mask = None + extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]] + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options) + joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask) + + joint_query = apply_rope1(joint_query, image_rotary_emb) + joint_key = apply_rope1(joint_key, image_rotary_emb) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attn_mask, transformer_options=transformer_options, skip_reshape=True) @@ -444,6 +454,7 @@ class QwenImageTransformer2DModel(nn.Module): timestep_zero_index = None if ref_latents is not None: + ref_num_tokens = [] h = 0 w = 0 index = 0 @@ -474,16 +485,16 @@ class QwenImageTransformer2DModel(nn.Module): kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + ref_num_tokens.append(kontext.shape[1]) if timestep_zero: if index > 0: timestep = torch.cat([timestep, timestep * 0], dim=0) timestep_zero_index = num_embeds + transformer_options = transformer_options.copy() + transformer_options["reference_image_num_tokens"] = ref_num_tokens txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) - ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() - del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states) @@ -495,6 +506,18 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) + if "post_input" in patches: + for p in patches["post_input"]: + out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + img_ids = out["img_ids"] + txt_ids = out["txt_ids"] + + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() + del ids, txt_ids, img_ids + transformer_options["total_blocks"] = len(self.transformer_blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.transformer_blocks): diff --git a/comfy/lora.py b/comfy/lora.py index f36ddb046..63ee85323 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}): for k in sdk: if k.endswith(".weight"): key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names + tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models + if tp > 0 and not k.startswith("clip_"): + key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" clip_l_present = False diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 6eace4628..35a6822e3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1,4 +1,5 @@ import json +import comfy.memory_management import comfy.supported_models import comfy.supported_models_base import comfy.utils @@ -1118,8 +1119,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): new[:old_weight.shape[0]] = old_weight old_weight = new + if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled: + old_weight = old_weight.clone() + w = old_weight.narrow(offset[0], offset[1], offset[2]) else: + if comfy.memory_management.aimdo_enabled: + weight = weight.clone() old_weight = weight w = weight w[:] = fun(weight) diff --git a/comfy/model_management.py b/comfy/model_management.py index 07bc8ad67..81c89b180 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -270,6 +270,23 @@ try: except: OOM_EXCEPTION = Exception +try: + ACCELERATOR_ERROR = torch.AcceleratorError +except AttributeError: + ACCELERATOR_ERROR = RuntimeError + +def is_oom(e): + if isinstance(e, OOM_EXCEPTION): + return True + if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()): + discard_cuda_async_error() + return True + return False + +def raise_non_oom(e): + if not is_oom(e): + raise e + XFORMERS_VERSION = "" XFORMERS_ENABLED_VAE = True if args.disable_xformers: @@ -1263,7 +1280,7 @@ def discard_cuda_async_error(): b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) _ = a + b synchronize() - except torch.AcceleratorError: + except RuntimeError: #Dump it! We already know about it from the synchronous return pass diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 745384271..bc3a8f446 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -599,6 +599,27 @@ class ModelPatcher: return models + def model_patches_call_function(self, function_name="cleanup", arguments={}): + to = self.model_options["transformer_options"] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], function_name): + getattr(patch_list[i], function_name)(**arguments) + if "patches_replace" in to: + patches = to["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], function_name): + getattr(patch_list[k], function_name)(**arguments) + if "model_function_wrapper" in self.model_options: + wrap_func = self.model_options["model_function_wrapper"] + if hasattr(wrap_func, function_name): + getattr(wrap_func, function_name)(**arguments) + def model_dtype(self): if hasattr(self.model, "get_dtype"): return self.model.get_dtype() @@ -1062,6 +1083,7 @@ class ModelPatcher: return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype) def cleanup(self): + self.model_patches_call_function(function_name="cleanup") self.clean_hooks() if hasattr(self.model, "current_patcher"): self.model.current_patcher = None diff --git a/comfy/sd.py b/comfy/sd.py index 888ef1e77..adcd67767 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -954,7 +954,8 @@ class VAE: if pixel_samples is None: pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples[x:x+batch_number] = out - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") #NOTE: We don't know what tensors were allocated to stack variables at the time of the #exception and the exception itself refs them all until we get out of this except block. @@ -1029,7 +1030,8 @@ class VAE: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device) samples[x:x + batch_number] = out - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") #NOTE: We don't know what tensors were allocated to stack variables at the time of the #exception and the exception itself refs them all until we get out of this except block. diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index f2399422b..04973fea0 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -25,6 +25,7 @@ class ComfyAPI_latest(ComfyAPIBase): super().__init__() self.node_replacement = self.NodeReplacement() self.execution = self.Execution() + self.caching = self.Caching() class NodeReplacement(ProxiedSingleton): async def register(self, node_replace: io.NodeReplace) -> None: @@ -84,6 +85,36 @@ class ComfyAPI_latest(ComfyAPIBase): image=to_display, ) + class Caching(ProxiedSingleton): + """ + External cache provider API for sharing cached node outputs + across ComfyUI instances. + + Example:: + + from comfy_api.latest import Caching + + class MyCacheProvider(Caching.CacheProvider): + async def on_lookup(self, context): + ... # check external storage + + async def on_store(self, context, value): + ... # store to external storage + + Caching.register_provider(MyCacheProvider()) + """ + from ._caching import CacheProvider, CacheContext, CacheValue + + async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None: + """Register an external cache provider. Providers are called in registration order.""" + from comfy_execution.cache_provider import register_cache_provider + register_cache_provider(provider) + + async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None: + """Unregister a previously registered cache provider.""" + from comfy_execution.cache_provider import unregister_cache_provider + unregister_cache_provider(provider) + class ComfyExtension(ABC): async def on_load(self) -> None: """ @@ -116,6 +147,9 @@ class Types: VOXEL = VOXEL File3D = File3D + +Caching = ComfyAPI_latest.Caching + ComfyAPI = ComfyAPI_latest # Create a synchronous version of the API @@ -135,6 +169,7 @@ __all__ = [ "Input", "InputImpl", "Types", + "Caching", "ComfyExtension", "io", "IO", diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py new file mode 100644 index 000000000..30c8848cd --- /dev/null +++ b/comfy_api/latest/_caching.py @@ -0,0 +1,42 @@ +from abc import ABC, abstractmethod +from typing import Optional +from dataclasses import dataclass + + +@dataclass +class CacheContext: + node_id: str + class_type: str + cache_key_hash: str # SHA256 hex digest + + +@dataclass +class CacheValue: + outputs: list + ui: dict = None + + +class CacheProvider(ABC): + """Abstract base class for external cache providers. + Exceptions from provider methods are caught by the caller and never break execution. + """ + + @abstractmethod + async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + """Called on local cache miss. Return CacheValue if found, None otherwise.""" + pass + + @abstractmethod + async def on_store(self, context: CacheContext, value: CacheValue) -> None: + """Called after local store. Dispatched via asyncio.create_task.""" + pass + + def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool: + """Return False to skip external caching for this node. Default: True.""" + return True + + def on_prompt_start(self, prompt_id: str) -> None: + pass + + def on_prompt_end(self, prompt_id: str) -> None: + pass diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 58a37c9e8..1b4993aa7 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -272,7 +272,7 @@ class VideoFromFile(VideoInput): has_first_frame = False for frame in frames: offset_seconds = start_time - frame.pts * audio_stream.time_base - to_skip = int(offset_seconds * audio_stream.sample_rate) + to_skip = max(0, int(offset_seconds * audio_stream.sample_rate)) if to_skip < frame.samples: has_first_frame = True break @@ -280,7 +280,7 @@ class VideoFromFile(VideoInput): audio_frames.append(frame.to_ndarray()[..., to_skip:]) for frame in frames: - if frame.time > start_time + self.__duration: + if self.__duration and frame.time > start_time + self.__duration: break audio_frames.append(frame.to_ndarray()) # shape: (channels, samples) if len(audio_frames) > 0: diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 050031dc0..7ca8f4e0c 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -297,7 +297,7 @@ class Float(ComfyTypeIO): '''Float input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, - display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None, + display_mode: NumberDisplay=None, gradient_stops: list[dict]=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.min = min diff --git a/comfy_api_nodes/apis/reve.py b/comfy_api_nodes/apis/reve.py new file mode 100644 index 000000000..c6b5a69d8 --- /dev/null +++ b/comfy_api_nodes/apis/reve.py @@ -0,0 +1,68 @@ +from pydantic import BaseModel, Field + + +class RevePostprocessingOperation(BaseModel): + process: str = Field(..., description="The postprocessing operation: upscale or remove_background.") + upscale_factor: int | None = Field( + None, + description="Upscale factor (2, 3, or 4). Only used when process is upscale.", + ge=2, + le=4, + ) + + +class ReveImageCreateRequest(BaseModel): + prompt: str = Field(...) + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageEditRequest(BaseModel): + edit_instruction: str = Field(...) + reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.") + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int | None = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageRemixRequest(BaseModel): + prompt: str = Field(...) + reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.") + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int | None = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageResponse(BaseModel): + image: str | None = Field(None, description="The base64 encoded image data.") + request_id: str | None = Field(None, description="A unique id for the request.") + credits_used: float | None = Field(None, description="The number of credits used for this request.") + version: str | None = Field(None, description="The specific model version used.") + content_violation: bool | None = Field( + None, description="Indicates whether the generated image violates the content policy." + ) diff --git a/comfy_api_nodes/nodes_reve.py b/comfy_api_nodes/nodes_reve.py new file mode 100644 index 000000000..608d9f058 --- /dev/null +++ b/comfy_api_nodes/nodes_reve.py @@ -0,0 +1,395 @@ +from io import BytesIO + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.reve import ( + ReveImageCreateRequest, + ReveImageEditRequest, + ReveImageRemixRequest, + RevePostprocessingOperation, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + bytesio_to_image_tensor, + sync_op_raw, + tensor_to_base64_string, + validate_string, +) + + +def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None: + ops = [] + if upscale["upscale"] == "enabled": + ops.append( + RevePostprocessingOperation( + process="upscale", + upscale_factor=upscale["upscale_factor"], + ) + ) + if remove_background: + ops.append(RevePostprocessingOperation(process="remove_background")) + return ops or None + + +def _postprocessing_inputs(): + return [ + IO.DynamicCombo.Input( + "upscale", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option( + "enabled", + [ + IO.Int.Input( + "upscale_factor", + default=2, + min=2, + max=4, + step=1, + tooltip="Upscale factor (2x, 3x, or 4x).", + ), + ], + ), + ], + tooltip="Upscale the generated image. May add additional cost.", + ), + IO.Boolean.Input( + "remove_background", + default=False, + tooltip="Remove the background from the generated image. May add additional cost.", + ), + ] + + +def _reve_price_extractor(headers: dict) -> float | None: + credits_used = headers.get("x-reve-credits-used") + if credits_used is not None: + return float(credits_used) / 524.48 + return None + + +def _reve_response_header_validator(headers: dict) -> None: + error_code = headers.get("x-reve-error-code") + if error_code: + raise ValueError(f"Reve API error: {error_code}") + if headers.get("x-reve-content-violation", "").lower() == "true": + raise ValueError("The generated image was flagged for content policy violation.") + + +def _model_inputs(versions: list[str], aspect_ratios: list[str]): + return [ + IO.DynamicCombo.Option( + version, + [ + IO.Combo.Input( + "aspect_ratio", + options=aspect_ratios, + tooltip="Aspect ratio of the output image.", + ), + IO.Int.Input( + "test_time_scaling", + default=1, + min=1, + max=5, + step=1, + tooltip="Higher values produce better images but cost more credits.", + advanced=True, + ), + ], + ) + for version in versions + ] + + +class ReveImageCreateNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageCreateNode", + display_name="Reve Image Create", + category="api node/image/Reve", + description="Generate images from text descriptions using Reve.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired image. Maximum 2560 characters.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-create@20250915"], + aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for generation.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""", + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2560) + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/create", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageCreateRequest( + prompt=prompt, + aspect_ratio=model["aspect_ratio"], + version=model["model"], + test_time_scaling=model["test_time_scaling"], + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageEditNode", + display_name="Reve Image Edit", + category="api node/image/Reve", + description="Edit images using natural language instructions with Reve.", + inputs=[ + IO.Image.Input("image", tooltip="The image to edit."), + IO.String.Input( + "edit_instruction", + multiline=True, + default="", + tooltip="Text description of how to edit the image. Maximum 2560 characters.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-edit@20250915", "reve-edit-fast@20251030"], + aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for editing.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model"], + ), + expr=""" + ( + $isFast := $contains(widgets.model, "fast"); + $base := $isFast ? 0.01001 : 0.0572; + {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + edit_instruction: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(edit_instruction, min_length=1, max_length=2560) + tts = model["test_time_scaling"] + ar = model["aspect_ratio"] + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/edit", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageEditRequest( + edit_instruction=edit_instruction, + reference_image=tensor_to_base64_string(image), + aspect_ratio=ar if ar != "auto" else None, + version=model["model"], + test_time_scaling=tts if tts and tts > 1 else None, + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveImageRemixNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageRemixNode", + display_name="Reve Image Remix", + category="api node/image/Reve", + description="Combine reference images with text prompts to create new images using Reve.", + inputs=[ + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplatePrefix( + IO.Image.Input("image"), + prefix="image_", + min=1, + max=6, + ), + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired image. " + "May include XML img tags to reference specific images by index, " + "e.g. 0, 1, etc.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-remix@20250915", "reve-remix-fast@20251030"], + aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for remixing.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model"], + ), + expr=""" + ( + $isFast := $contains(widgets.model, "fast"); + $base := $isFast ? 0.01001 : 0.0572; + {"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + reference_images: IO.Autogrow.Type, + prompt: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2560) + if not reference_images: + raise ValueError("At least one reference image is required.") + ref_base64_list = [] + for key in reference_images: + ref_base64_list.append(tensor_to_base64_string(reference_images[key])) + if len(ref_base64_list) > 6: + raise ValueError("Maximum 6 reference images are allowed.") + tts = model["test_time_scaling"] + ar = model["aspect_ratio"] + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/remix", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageRemixRequest( + prompt=prompt, + reference_images=ref_base64_list, + aspect_ratio=ar if ar != "auto" else None, + version=model["model"], + test_time_scaling=tts if tts and tts > 1 else None, + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ReveImageCreateNode, + ReveImageEditNode, + ReveImageRemixNode, + ] + + +async def comfy_entrypoint() -> ReveExtension: + return ReveExtension() diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 79ffb77c1..9d730b81a 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -67,6 +67,7 @@ class _RequestConfig: progress_origin_ts: float | None = None price_extractor: Callable[[dict[str, Any]], float | None] | None = None is_rate_limited: Callable[[int, Any], bool] | None = None + response_header_validator: Callable[[dict[str, str]], None] | None = None @dataclass @@ -202,11 +203,13 @@ async def sync_op_raw( monitor_progress: bool = True, max_retries_on_rate_limit: int = 16, is_rate_limited: Callable[[int, Any], bool] | None = None, + response_header_validator: Callable[[dict[str, str]], None] | None = None, ) -> dict[str, Any] | bytes: """ Make a single network request. - If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). - If as_binary=True: returns bytes. + - response_header_validator: optional callback receiving response headers dict """ if isinstance(data, BaseModel): data = data.model_dump(exclude_none=True) @@ -232,6 +235,7 @@ async def sync_op_raw( price_extractor=price_extractor, max_retries_on_rate_limit=max_retries_on_rate_limit, is_rate_limited=is_rate_limited, + response_header_validator=response_header_validator, ) return await _request_base(cfg, expect_binary=as_binary) @@ -769,6 +773,12 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total ) bytes_payload = bytes(buff) + resp_headers = {k.lower(): v for k, v in resp.headers.items()} + if cfg.price_extractor: + with contextlib.suppress(Exception): + extracted_price = cfg.price_extractor(resp_headers) + if cfg.response_header_validator: + cfg.response_header_validator(resp_headers) operation_succeeded = True final_elapsed_seconds = int(time.monotonic() - start_time) request_logger.log_request_response( @@ -776,7 +786,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): request_method=method, request_url=url, response_status_code=resp.status, - response_headers=dict(resp.headers), + response_headers=resp_headers, response_content=bytes_payload, ) return bytes_payload diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py new file mode 100644 index 000000000..d455d08e8 --- /dev/null +++ b/comfy_execution/cache_provider.py @@ -0,0 +1,138 @@ +from typing import Any, Optional, Tuple, List +import hashlib +import json +import logging +import threading + +# Public types — source of truth is comfy_api.latest._caching +from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported) + +_logger = logging.getLogger(__name__) + + +_providers: List[CacheProvider] = [] +_providers_lock = threading.Lock() +_providers_snapshot: Tuple[CacheProvider, ...] = () + + +def register_cache_provider(provider: CacheProvider) -> None: + """Register an external cache provider. Providers are called in registration order.""" + global _providers_snapshot + with _providers_lock: + if provider in _providers: + _logger.warning(f"Provider {provider.__class__.__name__} already registered") + return + _providers.append(provider) + _providers_snapshot = tuple(_providers) + _logger.debug(f"Registered cache provider: {provider.__class__.__name__}") + + +def unregister_cache_provider(provider: CacheProvider) -> None: + global _providers_snapshot + with _providers_lock: + try: + _providers.remove(provider) + _providers_snapshot = tuple(_providers) + _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}") + except ValueError: + _logger.warning(f"Provider {provider.__class__.__name__} was not registered") + + +def _get_cache_providers() -> Tuple[CacheProvider, ...]: + return _providers_snapshot + + +def _has_cache_providers() -> bool: + return bool(_providers_snapshot) + + +def _clear_cache_providers() -> None: + global _providers_snapshot + with _providers_lock: + _providers.clear() + _providers_snapshot = () + + +def _canonicalize(obj: Any) -> Any: + # Convert to canonical JSON-serializable form with deterministic ordering. + # Frozensets have non-deterministic iteration order between Python sessions. + # Raises ValueError for non-cacheable types (Unhashable, unknown) so that + # _serialize_cache_key returns None and external caching is skipped. + if isinstance(obj, frozenset): + return ("__frozenset__", sorted( + [_canonicalize(item) for item in obj], + key=lambda x: json.dumps(x, sort_keys=True) + )) + elif isinstance(obj, set): + return ("__set__", sorted( + [_canonicalize(item) for item in obj], + key=lambda x: json.dumps(x, sort_keys=True) + )) + elif isinstance(obj, tuple): + return ("__tuple__", [_canonicalize(item) for item in obj]) + elif isinstance(obj, list): + return [_canonicalize(item) for item in obj] + elif isinstance(obj, dict): + return {"__dict__": sorted( + [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()], + key=lambda x: json.dumps(x, sort_keys=True) + )} + elif isinstance(obj, (int, float, str, bool, type(None))): + return (type(obj).__name__, obj) + elif isinstance(obj, bytes): + return ("__bytes__", obj.hex()) + else: + raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}") + + +def _serialize_cache_key(cache_key: Any) -> Optional[str]: + # Returns deterministic SHA256 hex digest, or None on failure. + # Uses JSON (not pickle) because pickle is non-deterministic across sessions. + try: + canonical = _canonicalize(cache_key) + json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':')) + return hashlib.sha256(json_str.encode('utf-8')).hexdigest() + except Exception as e: + _logger.warning(f"Failed to serialize cache key: {e}") + return None + + +def _contains_self_unequal(obj: Any) -> bool: + # Local cache matches by ==. Values where not (x == x) (NaN, etc.) will + # never hit locally, but serialized form would match externally. Skip these. + try: + if not (obj == obj): + return True + except Exception: + return True + if isinstance(obj, (frozenset, tuple, list, set)): + return any(_contains_self_unequal(item) for item in obj) + if isinstance(obj, dict): + return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items()) + if hasattr(obj, 'value'): + return _contains_self_unequal(obj.value) + return False + + +def _estimate_value_size(value: CacheValue) -> int: + try: + import torch + except ImportError: + return 0 + + total = 0 + + def estimate(obj): + nonlocal total + if isinstance(obj, torch.Tensor): + total += obj.numel() * obj.element_size() + elif isinstance(obj, dict): + for v in obj.values(): + estimate(v) + elif isinstance(obj, (list, tuple)): + for item in obj: + estimate(item) + + for output in value.outputs: + estimate(output) + return total diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 326a279fc..78212bde3 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,3 +1,4 @@ +import asyncio import bisect import gc import itertools @@ -147,13 +148,15 @@ class CacheKeySetInputSignature(CacheKeySet): self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) class BasicCache: - def __init__(self, key_class): + def __init__(self, key_class, enable_providers=False): self.key_class = key_class self.initialized = False + self.enable_providers = enable_providers self.dynprompt: DynamicPrompt self.cache_key_set: CacheKeySet self.cache = {} self.subcaches = {} + self._pending_store_tasks: set = set() async def set_prompt(self, dynprompt, node_ids, is_changed_cache): self.dynprompt = dynprompt @@ -196,18 +199,138 @@ class BasicCache: def poll(self, **kwargs): pass - def _set_immediate(self, node_id, value): - assert self.initialized - cache_key = self.cache_key_set.get_data_key(node_id) - self.cache[cache_key] = value - - def _get_immediate(self, node_id): + def get_local(self, node_id): if not self.initialized: return None cache_key = self.cache_key_set.get_data_key(node_id) if cache_key in self.cache: return self.cache[cache_key] - else: + return None + + def set_local(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + async def _set_immediate(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + await self._notify_providers_store(node_id, cache_key, value) + + async def _get_immediate(self, node_id): + if not self.initialized: + return None + cache_key = self.cache_key_set.get_data_key(node_id) + + if cache_key in self.cache: + return self.cache[cache_key] + + external_result = await self._check_providers_lookup(node_id, cache_key) + if external_result is not None: + self.cache[cache_key] = external_result + return external_result + + return None + + async def _notify_providers_store(self, node_id, cache_key, value): + from comfy_execution.cache_provider import ( + _has_cache_providers, _get_cache_providers, + CacheValue, _contains_self_unequal, _logger + ) + + if not self.enable_providers: + return + if not _has_cache_providers(): + return + if not self._is_external_cacheable_value(value): + return + if _contains_self_unequal(cache_key): + return + + context = self._build_context(node_id, cache_key) + if context is None: + return + cache_value = CacheValue(outputs=value.outputs, ui=value.ui) + + for provider in _get_cache_providers(): + try: + if provider.should_cache(context, cache_value): + task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value)) + self._pending_store_tasks.add(task) + task.add_done_callback(self._pending_store_tasks.discard) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}") + + @staticmethod + async def _safe_provider_store(provider, context, cache_value): + from comfy_execution.cache_provider import _logger + try: + await provider.on_store(context, cache_value) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}") + + async def _check_providers_lookup(self, node_id, cache_key): + from comfy_execution.cache_provider import ( + _has_cache_providers, _get_cache_providers, + CacheValue, _contains_self_unequal, _logger + ) + + if not self.enable_providers: + return None + if not _has_cache_providers(): + return None + if _contains_self_unequal(cache_key): + return None + + context = self._build_context(node_id, cache_key) + if context is None: + return None + + for provider in _get_cache_providers(): + try: + if not provider.should_cache(context): + continue + result = await provider.on_lookup(context) + if result is not None: + if not isinstance(result, CacheValue): + _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type") + continue + if not isinstance(result.outputs, (list, tuple)): + _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs") + continue + from execution import CacheEntry + return CacheEntry(ui=result.ui, outputs=list(result.outputs)) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}") + + return None + + def _is_external_cacheable_value(self, value): + return hasattr(value, 'outputs') and hasattr(value, 'ui') + + def _get_class_type(self, node_id): + if not self.initialized or not self.dynprompt: + return '' + try: + return self.dynprompt.get_node(node_id).get('class_type', '') + except Exception: + return '' + + def _build_context(self, node_id, cache_key): + from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger + try: + cache_key_hash = _serialize_cache_key(cache_key) + if cache_key_hash is None: + return None + return CacheContext( + node_id=node_id, + class_type=self._get_class_type(node_id), + cache_key_hash=cache_key_hash, + ) + except Exception as e: + _logger.warning(f"Failed to build cache context for node {node_id}: {e}") return None async def _ensure_subcache(self, node_id, children_ids): @@ -236,8 +359,8 @@ class BasicCache: return result class HierarchicalCache(BasicCache): - def __init__(self, key_class): - super().__init__(key_class) + def __init__(self, key_class, enable_providers=False): + super().__init__(key_class, enable_providers=enable_providers) def _get_cache_for(self, node_id): assert self.dynprompt is not None @@ -257,16 +380,27 @@ class HierarchicalCache(BasicCache): return None return cache - def get(self, node_id): + async def get(self, node_id): cache = self._get_cache_for(node_id) if cache is None: return None - return cache._get_immediate(node_id) + return await cache._get_immediate(node_id) - def set(self, node_id, value): + def get_local(self, node_id): + cache = self._get_cache_for(node_id) + if cache is None: + return None + return BasicCache.get_local(cache, node_id) + + async def set(self, node_id, value): cache = self._get_cache_for(node_id) assert cache is not None - cache._set_immediate(node_id, value) + await cache._set_immediate(node_id, value) + + def set_local(self, node_id, value): + cache = self._get_cache_for(node_id) + assert cache is not None + BasicCache.set_local(cache, node_id, value) async def ensure_subcache_for(self, node_id, children_ids): cache = self._get_cache_for(node_id) @@ -287,18 +421,24 @@ class NullCache: def poll(self, **kwargs): pass - def get(self, node_id): + async def get(self, node_id): return None - def set(self, node_id, value): + def get_local(self, node_id): + return None + + async def set(self, node_id, value): + pass + + def set_local(self, node_id, value): pass async def ensure_subcache_for(self, node_id, children_ids): return self class LRUCache(BasicCache): - def __init__(self, key_class, max_size=100): - super().__init__(key_class) + def __init__(self, key_class, max_size=100, enable_providers=False): + super().__init__(key_class, enable_providers=enable_providers) self.max_size = max_size self.min_generation = 0 self.generation = 0 @@ -322,18 +462,18 @@ class LRUCache(BasicCache): del self.children[key] self._clean_subcaches() - def get(self, node_id): + async def get(self, node_id): self._mark_used(node_id) - return self._get_immediate(node_id) + return await self._get_immediate(node_id) def _mark_used(self, node_id): cache_key = self.cache_key_set.get_data_key(node_id) if cache_key is not None: self.used_generation[cache_key] = self.generation - def set(self, node_id, value): + async def set(self, node_id, value): self._mark_used(node_id) - return self._set_immediate(node_id, value) + return await self._set_immediate(node_id, value) async def ensure_subcache_for(self, node_id, children_ids): # Just uses subcaches for tracking 'live' nodes @@ -366,20 +506,20 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3 class RAMPressureCache(LRUCache): - def __init__(self, key_class): - super().__init__(key_class, 0) + def __init__(self, key_class, enable_providers=False): + super().__init__(key_class, 0, enable_providers=enable_providers) self.timestamps = {} def clean_unused(self): self._clean_subcaches() - def set(self, node_id, value): + async def set(self, node_id, value): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() - super().set(node_id, value) + await super().set(node_id, value) - def get(self, node_id): + async def get(self, node_id): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() - return super().get(node_id) + return await super().get(node_id) def poll(self, ram_headroom): def _ram_gb(): diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 9d170b16e..c47f3c79b 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort): self.execution_cache_listeners = {} def is_cached(self, node_id): - return self.output_cache.get(node_id) is not None + return self.output_cache.get_local(node_id) is not None def cache_link(self, from_node_id, to_node_id): if to_node_id not in self.execution_cache: self.execution_cache[to_node_id] = {} - self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) + self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id) if from_node_id not in self.execution_cache_listeners: self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id].add(to_node_id) @@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort): if value is None: return None #Write back to the main cache on touch. - self.output_cache.set(from_node_id, value) + self.output_cache.set_local(from_node_id, value) return value def cache_update(self, node_id, value): diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index fe9552022..3a23c7d04 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -6,6 +6,7 @@ import comfy.model_management import torch import math import nodes +import comfy.ldm.flux.math class CLIPTextEncodeFlux(io.ComfyNode): @classmethod @@ -231,6 +232,68 @@ class Flux2Scheduler(io.ComfyNode): sigmas = get_schedule(steps, round(seq_len)) return io.NodeOutput(sigmas) +class KV_Attn_Input: + def __init__(self): + self.cache = {} + + def __call__(self, q, k, v, extra_options, **kwargs): + reference_image_num_tokens = extra_options.get("reference_image_num_tokens", []) + if len(reference_image_num_tokens) == 0: + return {} + + ref_toks = sum(reference_image_num_tokens) + cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"]) + if cache_key in self.cache: + kk, vv = self.cache[cache_key] + self.set_cache = False + return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)} + + self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone()) + self.set_cache = True + return {"q": q, "k": k, "v": v} + + def cleanup(self): + self.cache = {} + + +class FluxKVCache(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="FluxKVCache", + display_name="Flux KV Cache", + description="Enables KV Cache optimization for reference images on Flux family models.", + category="", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to use KV Cache on."), + ], + outputs=[ + io.Model.Output(tooltip="The patched model with KV Cache enabled."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type) -> io.NodeOutput: + m = model.clone() + input_patch_obj = KV_Attn_Input() + + def model_input_patch(inputs): + if len(input_patch_obj.cache) > 0: + ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", [])) + if ref_image_tokens > 0: + img = inputs["img"] + inputs["img"] = img[:, :-ref_image_tokens] + return inputs + + m.set_model_attn1_patch(input_patch_obj) + m.set_model_post_input_patch(model_input_patch) + if hasattr(model.model.diffusion_model, "params"): + m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero") + else: + m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero") + + return io.NodeOutput(m) class FluxExtension(ComfyExtension): @override @@ -243,6 +306,7 @@ class FluxExtension(ComfyExtension): FluxKontextMultiReferenceLatentMethod, EmptyFlux2LatentImage, Flux2Scheduler, + FluxKVCache, ] diff --git a/comfy_extras/nodes_painter.py b/comfy_extras/nodes_painter.py new file mode 100644 index 000000000..b9ecdf5ea --- /dev/null +++ b/comfy_extras/nodes_painter.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import hashlib +import os + +import numpy as np +import torch +from PIL import Image + +import folder_paths +import node_helpers +from comfy_api.latest import ComfyExtension, io, UI +from typing_extensions import override + + +def hex_to_rgb(hex_color: str) -> tuple[float, float, float]: + hex_color = hex_color.lstrip("#") + if len(hex_color) != 6: + return (0.0, 0.0, 0.0) + r = int(hex_color[0:2], 16) / 255.0 + g = int(hex_color[2:4], 16) / 255.0 + b = int(hex_color[4:6], 16) / 255.0 + return (r, g, b) + + +class PainterNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Painter", + display_name="Painter", + category="image", + inputs=[ + io.Image.Input( + "image", + optional=True, + tooltip="Optional base image to paint over", + ), + io.String.Input( + "mask", + default="", + socketless=True, + extra_dict={"widgetType": "PAINTER", "image_upload": True}, + ), + io.Int.Input( + "width", + default=512, + min=64, + max=4096, + step=64, + socketless=True, + extra_dict={"hidden": True}, + ), + io.Int.Input( + "height", + default=512, + min=64, + max=4096, + step=64, + socketless=True, + extra_dict={"hidden": True}, + ), + io.Color.Input("bg_color", default="#000000"), + ], + outputs=[ + io.Image.Output("IMAGE"), + io.Mask.Output("MASK"), + ], + ) + + @classmethod + def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput: + if image is not None: + base_image = image[:1] + h, w = base_image.shape[1], base_image.shape[2] + else: + h, w = height, width + r, g, b = hex_to_rgb(bg_color) + base_image = torch.zeros((1, h, w, 3), dtype=torch.float32) + base_image[0, :, :, 0] = r + base_image[0, :, :, 1] = g + base_image[0, :, :, 2] = b + + if mask and mask.strip(): + mask_path = folder_paths.get_annotated_filepath(mask) + painter_img = node_helpers.pillow(Image.open, mask_path) + painter_img = painter_img.convert("RGBA") + + if painter_img.size != (w, h): + painter_img = painter_img.resize((w, h), Image.LANCZOS) + + painter_np = np.array(painter_img).astype(np.float32) / 255.0 + painter_rgb = painter_np[:, :, :3] + painter_alpha = painter_np[:, :, 3:4] + + mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0) + + base_np = base_image[0].cpu().numpy() + composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha) + out_image = torch.from_numpy(composited).unsqueeze(0) + else: + mask_tensor = torch.zeros((1, h, w), dtype=torch.float32) + out_image = base_image + + return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image)) + + @classmethod + def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None): + if mask and mask.strip(): + mask_path = folder_paths.get_annotated_filepath(mask) + if os.path.exists(mask_path): + m = hashlib.sha256() + with open(mask_path, "rb") as f: + m.update(f.read()) + return m.digest().hex() + return "" + + + +class PainterExtension(ComfyExtension): + @override + async def get_node_list(self): + return [PainterNode] + + +async def comfy_entrypoint(): + return PainterExtension() diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 97b9e948d..db4f9d231 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -86,7 +86,8 @@ class ImageUpscaleWithModel(io.ComfyNode): pbar = comfy.utils.ProgressBar(steps) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) oom = False - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) tile //= 2 if tile < 128: raise e diff --git a/comfyui_version.py b/comfyui_version.py index 2723d02e7..701f4d66a 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.16.4" +__version__ = "0.17.0" diff --git a/execution.py b/execution.py index 7ccdbf93e..1a6c3429c 100644 --- a/execution.py +++ b/execution.py @@ -40,6 +40,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io +from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger class ExecutionResult(Enum): @@ -126,15 +127,15 @@ class CacheSet: # Performs like the old cache -- dump data ASAP def init_classic_cache(self): - self.outputs = HierarchicalCache(CacheKeySetInputSignature) + self.outputs = HierarchicalCache(CacheKeySetInputSignature, enable_providers=True) self.objects = HierarchicalCache(CacheKeySetID) def init_lru_cache(self, cache_size): - self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size, enable_providers=True) self.objects = HierarchicalCache(CacheKeySetID) def init_ram_cache(self, min_headroom): - self.outputs = RAMPressureCache(CacheKeySetInputSignature) + self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True) self.objects = HierarchicalCache(CacheKeySetID) def init_null_cache(self): @@ -418,7 +419,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - cached = caches.outputs.get(unique_id) + cached = await caches.outputs.get(unique_id) if cached is not None: if server.client_id is not None: cached_ui = cached.ui or {} @@ -474,10 +475,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) - obj = caches.objects.get(unique_id) + obj = await caches.objects.get(unique_id) if obj is None: obj = class_def() - caches.objects.set(unique_id, obj) + await caches.objects.set(unique_id, obj) if issubclass(class_def, _ComfyNodeInternal): lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None @@ -588,7 +589,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data) execution_list.cache_update(unique_id, cache_entry) - caches.outputs.set(unique_id, cache_entry) + await caches.outputs.set(unique_id, cache_entry) except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") @@ -612,7 +613,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, logging.error(traceback.format_exc()) tips = "" - if isinstance(ex, comfy.model_management.OOM_EXCEPTION): + if comfy.model_management.is_oom(ex): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) logging.error("Got an OOM, unloading all loaded models.") @@ -684,6 +685,19 @@ class PromptExecutor: } self.add_message("execution_error", mes, broadcast=False) + def _notify_prompt_lifecycle(self, event: str, prompt_id: str): + if not _has_cache_providers(): + return + + for provider in _get_cache_providers(): + try: + if event == "start": + provider.on_prompt_start(prompt_id) + elif event == "end": + provider.on_prompt_end(prompt_id) + except Exception as e: + _cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}") + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) @@ -700,66 +714,75 @@ class PromptExecutor: self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) - with torch.inference_mode(): - dynamic_prompt = DynamicPrompt(prompt) - reset_progress_state(prompt_id, dynamic_prompt) - add_progress_handler(WebUIProgressHandler(self.server)) - is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) - for cache in self.caches.all: - await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) - cache.clean_unused() + self._notify_prompt_lifecycle("start", prompt_id) - cached_nodes = [] - for node_id in prompt: - if self.caches.outputs.get(node_id) is not None: - cached_nodes.append(node_id) + try: + with torch.inference_mode(): + dynamic_prompt = DynamicPrompt(prompt) + reset_progress_state(prompt_id, dynamic_prompt) + add_progress_handler(WebUIProgressHandler(self.server)) + is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) + for cache in self.caches.all: + await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + cache.clean_unused() - comfy.model_management.cleanup_models_gc() - self.add_message("execution_cached", - { "nodes": cached_nodes, "prompt_id": prompt_id}, - broadcast=False) - pending_subgraph_results = {} - pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results - ui_node_outputs = {} - executed = set() - execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) - current_outputs = self.caches.outputs.all_node_ids() - for node_id in list(execute_outputs): - execution_list.add_node(node_id) + node_ids = list(prompt.keys()) + cache_results = await asyncio.gather( + *(self.caches.outputs.get(node_id) for node_id in node_ids) + ) + cached_nodes = [ + node_id for node_id, result in zip(node_ids, cache_results) + if result is not None + ] - while not execution_list.is_empty(): - node_id, error, ex = await execution_list.stage_node_execution() - if error is not None: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break + comfy.model_management.cleanup_models_gc() + self.add_message("execution_cached", + { "nodes": cached_nodes, "prompt_id": prompt_id}, + broadcast=False) + pending_subgraph_results = {} + pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + ui_node_outputs = {} + executed = set() + execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) + current_outputs = self.caches.outputs.all_node_ids() + for node_id in list(execute_outputs): + execution_list.add_node(node_id) - assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) - self.success = result != ExecutionResult.FAILURE - if result == ExecutionResult.FAILURE: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) - break - elif result == ExecutionResult.PENDING: - execution_list.unstage_node_execution() - else: # result == ExecutionResult.SUCCESS: - execution_list.complete_node_execution() - self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) - else: - # Only execute when the while-loop ends without break - self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + while not execution_list.is_empty(): + node_id, error, ex = await execution_list.stage_node_execution() + if error is not None: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break - ui_outputs = {} - meta_outputs = {} - for node_id, ui_info in ui_node_outputs.items(): - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] - self.history_result = { - "outputs": ui_outputs, - "meta": meta_outputs, - } - self.server.last_node_id = None - if comfy.model_management.DISABLE_SMART_MEMORY: - comfy.model_management.unload_all_models() + assert node_id is not None, "Node ID should not be None at this point" + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) + self.success = result != ExecutionResult.FAILURE + if result == ExecutionResult.FAILURE: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + elif result == ExecutionResult.PENDING: + execution_list.unstage_node_execution() + else: # result == ExecutionResult.SUCCESS: + execution_list.complete_node_execution() + self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) + else: + # Only execute when the while-loop ends without break + self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + + ui_outputs = {} + meta_outputs = {} + for node_id, ui_info in ui_node_outputs.items(): + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] + self.history_result = { + "outputs": ui_outputs, + "meta": meta_outputs, + } + self.server.last_node_id = None + if comfy.model_management.DISABLE_SMART_MEMORY: + comfy.model_management.unload_all_models() + finally: + self._notify_prompt_lifecycle("end", prompt_id) async def validate_inputs(prompt_id, prompt, item, validated): diff --git a/main.py b/main.py index 1977f9362..8905fd09a 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ comfy.options.enable_args_parsing() import os import importlib.util +import shutil import importlib.metadata import folder_paths import time @@ -11,6 +12,7 @@ from app.logger import setup_logger import itertools import utils.extra_config from utils.mime_types import init_mime_types +import faulthandler import logging import sys from comfy_execution.progress import get_progress_state @@ -25,6 +27,8 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +faulthandler.enable(file=sys.stderr, all_threads=False) + import comfy_aimdo.control if enables_dynamic_vram(): @@ -64,8 +68,15 @@ if __name__ == "__main__": def handle_comfyui_manager_unavailable(): - if not args.windows_standalone_build: - logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n") + manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt") + uv_available = shutil.which("uv") is not None + + pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}" + msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}" + if uv_available: + msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}" + msg += "\n" + logging.warning(msg) args.enable_manager = False @@ -173,7 +184,6 @@ execute_prestartup_script() # Main code import asyncio -import shutil import threading import gc diff --git a/manager_requirements.txt b/manager_requirements.txt index c420cc48e..6bcc3fb50 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.1b1 +comfyui_manager==4.1b2 \ No newline at end of file diff --git a/nodes.py b/nodes.py index 0ef23b640..eb63f9d44 100644 --- a/nodes.py +++ b/nodes.py @@ -2450,6 +2450,7 @@ async def init_builtin_extra_nodes(): "nodes_nag.py", "nodes_sdpose.py", "nodes_math.py", + "nodes_painter.py", ] import_failed = [] diff --git a/pyproject.toml b/pyproject.toml index 753b219b3..e2ca79be7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.16.4" +version = "0.17.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" diff --git a/requirements.txt b/requirements.txt index b1db1cf24..511c62fee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.11 +comfyui-frontend-package==1.41.18 +comfyui-workflow-templates==0.9.21 comfyui-embedded-docs==0.4.3 torch torchsde @@ -22,8 +22,8 @@ alembic SQLAlchemy filelock av>=14.2.0 -comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.9 +comfy-kitchen>=0.2.8 +comfy-aimdo>=0.2.10 requests simpleeval>=1.0.0 blake3 diff --git a/tests-unit/execution_test/test_cache_provider.py b/tests-unit/execution_test/test_cache_provider.py new file mode 100644 index 000000000..ac3814746 --- /dev/null +++ b/tests-unit/execution_test/test_cache_provider.py @@ -0,0 +1,403 @@ +"""Tests for external cache provider API.""" + +import importlib.util +import pytest +from typing import Optional + + +def _torch_available() -> bool: + """Check if PyTorch is available.""" + return importlib.util.find_spec("torch") is not None + + +from comfy_execution.cache_provider import ( + CacheProvider, + CacheContext, + CacheValue, + register_cache_provider, + unregister_cache_provider, + _get_cache_providers, + _has_cache_providers, + _clear_cache_providers, + _serialize_cache_key, + _contains_self_unequal, + _estimate_value_size, + _canonicalize, +) + + +class TestCanonicalize: + """Test _canonicalize function for deterministic ordering.""" + + def test_frozenset_ordering_is_deterministic(self): + """Frozensets should produce consistent canonical form regardless of iteration order.""" + # Create two frozensets with same content + fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)]) + fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)]) + + result1 = _canonicalize(fs1) + result2 = _canonicalize(fs2) + + assert result1 == result2 + + def test_nested_frozenset_ordering(self): + """Nested frozensets should also be deterministically ordered.""" + inner1 = frozenset([1, 2, 3]) + inner2 = frozenset([3, 2, 1]) + + fs1 = frozenset([("key", inner1)]) + fs2 = frozenset([("key", inner2)]) + + result1 = _canonicalize(fs1) + result2 = _canonicalize(fs2) + + assert result1 == result2 + + def test_dict_ordering(self): + """Dicts should be sorted by key.""" + d1 = {"z": 1, "a": 2, "m": 3} + d2 = {"a": 2, "m": 3, "z": 1} + + result1 = _canonicalize(d1) + result2 = _canonicalize(d2) + + assert result1 == result2 + + def test_tuple_preserved(self): + """Tuples should be marked and preserved.""" + t = (1, 2, 3) + result = _canonicalize(t) + + assert result[0] == "__tuple__" + + def test_list_preserved(self): + """Lists should be recursively canonicalized.""" + lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])] + result = _canonicalize(lst) + + # First element should be canonicalized dict + assert "__dict__" in result[0] + # Second element should be canonicalized frozenset + assert result[1][0] == "__frozenset__" + + def test_primitives_include_type(self): + """Primitive types should include type name for disambiguation.""" + assert _canonicalize(42) == ("int", 42) + assert _canonicalize(3.14) == ("float", 3.14) + assert _canonicalize("hello") == ("str", "hello") + assert _canonicalize(True) == ("bool", True) + assert _canonicalize(None) == ("NoneType", None) + + def test_int_and_str_distinguished(self): + """int 7 and str '7' must produce different canonical forms.""" + assert _canonicalize(7) != _canonicalize("7") + + def test_bytes_converted(self): + """Bytes should be converted to hex string.""" + b = b"\x00\xff" + result = _canonicalize(b) + + assert result[0] == "__bytes__" + assert result[1] == "00ff" + + def test_set_ordering(self): + """Sets should be sorted like frozensets.""" + s1 = {3, 1, 2} + s2 = {1, 2, 3} + + result1 = _canonicalize(s1) + result2 = _canonicalize(s2) + + assert result1 == result2 + assert result1[0] == "__set__" + + def test_unknown_type_raises(self): + """Unknown types should raise ValueError (fail-closed).""" + class CustomObj: + pass + with pytest.raises(ValueError): + _canonicalize(CustomObj()) + + def test_object_with_value_attr_raises(self): + """Objects with .value attribute (Unhashable-like) should raise ValueError.""" + class FakeUnhashable: + def __init__(self): + self.value = float('nan') + with pytest.raises(ValueError): + _canonicalize(FakeUnhashable()) + + +class TestSerializeCacheKey: + """Test _serialize_cache_key for deterministic hashing.""" + + def test_same_content_same_hash(self): + """Same content should produce same hash.""" + key1 = frozenset([("node_1", frozenset([("input", "value")]))]) + key2 = frozenset([("node_1", frozenset([("input", "value")]))]) + + hash1 = _serialize_cache_key(key1) + hash2 = _serialize_cache_key(key2) + + assert hash1 == hash2 + + def test_different_content_different_hash(self): + """Different content should produce different hash.""" + key1 = frozenset([("node_1", "value_a")]) + key2 = frozenset([("node_1", "value_b")]) + + hash1 = _serialize_cache_key(key1) + hash2 = _serialize_cache_key(key2) + + assert hash1 != hash2 + + def test_returns_hex_string(self): + """Should return hex string (SHA256 hex digest).""" + key = frozenset([("test", 123)]) + result = _serialize_cache_key(key) + + assert isinstance(result, str) + assert len(result) == 64 # SHA256 hex digest is 64 chars + + def test_complex_nested_structure(self): + """Complex nested structures should hash deterministically.""" + # Note: frozensets can only contain hashable types, so we use + # nested frozensets of tuples to represent dict-like structures + key = frozenset([ + ("node_1", frozenset([ + ("input_a", ("tuple", "value")), + ("input_b", frozenset([("nested", "dict")])), + ])), + ("node_2", frozenset([ + ("param", 42), + ])), + ]) + + # Hash twice to verify determinism + hash1 = _serialize_cache_key(key) + hash2 = _serialize_cache_key(key) + + assert hash1 == hash2 + + def test_dict_in_cache_key(self): + """Dicts passed directly to _serialize_cache_key should work.""" + key = {"node_1": {"input": "value"}, "node_2": 42} + + hash1 = _serialize_cache_key(key) + hash2 = _serialize_cache_key(key) + + assert hash1 == hash2 + assert isinstance(hash1, str) + assert len(hash1) == 64 + + def test_unknown_type_returns_none(self): + """Non-cacheable types should return None (fail-closed).""" + class CustomObj: + pass + assert _serialize_cache_key(CustomObj()) is None + + +class TestContainsSelfUnequal: + """Test _contains_self_unequal utility function.""" + + def test_nan_float_detected(self): + """NaN floats should be detected (not equal to itself).""" + assert _contains_self_unequal(float('nan')) is True + + def test_regular_float_not_detected(self): + """Regular floats are equal to themselves.""" + assert _contains_self_unequal(3.14) is False + assert _contains_self_unequal(0.0) is False + assert _contains_self_unequal(-1.5) is False + + def test_infinity_not_detected(self): + """Infinity is equal to itself.""" + assert _contains_self_unequal(float('inf')) is False + assert _contains_self_unequal(float('-inf')) is False + + def test_nan_in_list(self): + """NaN in list should be detected.""" + assert _contains_self_unequal([1, 2, float('nan'), 4]) is True + assert _contains_self_unequal([1, 2, 3, 4]) is False + + def test_nan_in_tuple(self): + """NaN in tuple should be detected.""" + assert _contains_self_unequal((1, float('nan'))) is True + assert _contains_self_unequal((1, 2, 3)) is False + + def test_nan_in_frozenset(self): + """NaN in frozenset should be detected.""" + assert _contains_self_unequal(frozenset([1, float('nan')])) is True + assert _contains_self_unequal(frozenset([1, 2, 3])) is False + + def test_nan_in_dict_value(self): + """NaN in dict value should be detected.""" + assert _contains_self_unequal({"key": float('nan')}) is True + assert _contains_self_unequal({"key": 42}) is False + + def test_nan_in_nested_structure(self): + """NaN in deeply nested structure should be detected.""" + nested = {"level1": [{"level2": (1, 2, float('nan'))}]} + assert _contains_self_unequal(nested) is True + + def test_non_numeric_types(self): + """Non-numeric types should not be self-unequal.""" + assert _contains_self_unequal("string") is False + assert _contains_self_unequal(None) is False + assert _contains_self_unequal(True) is False + + def test_object_with_nan_value_attribute(self): + """Objects wrapping NaN in .value should be detected.""" + class NanWrapper: + def __init__(self): + self.value = float('nan') + assert _contains_self_unequal(NanWrapper()) is True + + def test_custom_self_unequal_object(self): + """Custom objects where not (x == x) should be detected.""" + class NeverEqual: + def __eq__(self, other): + return False + assert _contains_self_unequal(NeverEqual()) is True + + +class TestEstimateValueSize: + """Test _estimate_value_size utility function.""" + + def test_empty_outputs(self): + """Empty outputs should have zero size.""" + value = CacheValue(outputs=[]) + assert _estimate_value_size(value) == 0 + + @pytest.mark.skipif( + not _torch_available(), + reason="PyTorch not available" + ) + def test_tensor_size_estimation(self): + """Tensor size should be estimated correctly.""" + import torch + + # 1000 float32 elements = 4000 bytes + tensor = torch.zeros(1000, dtype=torch.float32) + value = CacheValue(outputs=[[tensor]]) + + size = _estimate_value_size(value) + assert size == 4000 + + @pytest.mark.skipif( + not _torch_available(), + reason="PyTorch not available" + ) + def test_nested_tensor_in_dict(self): + """Tensors nested in dicts should be counted.""" + import torch + + tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes + value = CacheValue(outputs=[[{"samples": tensor}]]) + + size = _estimate_value_size(value) + assert size == 400 + + +class TestProviderRegistry: + """Test cache provider registration and retrieval.""" + + def setup_method(self): + """Clear providers before each test.""" + _clear_cache_providers() + + def teardown_method(self): + """Clear providers after each test.""" + _clear_cache_providers() + + def test_register_provider(self): + """Provider should be registered successfully.""" + provider = MockCacheProvider() + register_cache_provider(provider) + + assert _has_cache_providers() is True + providers = _get_cache_providers() + assert len(providers) == 1 + assert providers[0] is provider + + def test_unregister_provider(self): + """Provider should be unregistered successfully.""" + provider = MockCacheProvider() + register_cache_provider(provider) + unregister_cache_provider(provider) + + assert _has_cache_providers() is False + + def test_multiple_providers(self): + """Multiple providers can be registered.""" + provider1 = MockCacheProvider() + provider2 = MockCacheProvider() + + register_cache_provider(provider1) + register_cache_provider(provider2) + + providers = _get_cache_providers() + assert len(providers) == 2 + + def test_duplicate_registration_ignored(self): + """Registering same provider twice should be ignored.""" + provider = MockCacheProvider() + + register_cache_provider(provider) + register_cache_provider(provider) # Should be ignored + + providers = _get_cache_providers() + assert len(providers) == 1 + + def test_clear_providers(self): + """_clear_cache_providers should remove all providers.""" + provider1 = MockCacheProvider() + provider2 = MockCacheProvider() + + register_cache_provider(provider1) + register_cache_provider(provider2) + _clear_cache_providers() + + assert _has_cache_providers() is False + assert len(_get_cache_providers()) == 0 + + +class TestCacheContext: + """Test CacheContext dataclass.""" + + def test_context_creation(self): + """CacheContext should be created with all fields.""" + context = CacheContext( + node_id="node-456", + class_type="KSampler", + cache_key_hash="a" * 64, + ) + + assert context.node_id == "node-456" + assert context.class_type == "KSampler" + assert context.cache_key_hash == "a" * 64 + + +class TestCacheValue: + """Test CacheValue dataclass.""" + + def test_value_creation(self): + """CacheValue should be created with outputs.""" + outputs = [[{"samples": "tensor_data"}]] + value = CacheValue(outputs=outputs) + + assert value.outputs == outputs + + +class MockCacheProvider(CacheProvider): + """Mock cache provider for testing.""" + + def __init__(self): + self.lookups = [] + self.stores = [] + + async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + self.lookups.append(context) + return None + + async def on_store(self, context: CacheContext, value: CacheValue) -> None: + self.stores.append((context, value))