import torch

from comfy_api.latest import ComfyExtension, io
from typing_extensions import override


class NAGuidance(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="NAGuidance",
            display_name="Normalized Attention Guidance",
            description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
            category="",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to apply NAG to."),
                io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
                io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
                io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01),
                # io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
                # io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
            ],
            outputs=[
                io.Model.Output(tooltip="The patched model with NAG enabled."),
            ],
        )

    @classmethod
    def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
        m = model.clone()
        # sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
        # sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)

        def nag_attention_output_patch(out, extra_options):
            # Only apply NAG when both the positive (0) and negative (1) conds are batched together.
            cond_or_uncond = extra_options.get("cond_or_uncond", None)
            if cond_or_uncond is None:
                return out
            if not (1 in cond_or_uncond and 0 in cond_or_uncond):
                return out

            # sigma = extra_options.get("sigmas", None)
            # if sigma is not None and len(sigma) > 0:
            #     sigma = sigma[0].item()
            #     if sigma > sigma_start or sigma < sigma_end:
            #         return out

            img_slice = extra_options.get("img_slice", None)
            if img_slice is not None:
                orig_out = out
                out = out[:, img_slice[0]:img_slice[1]]  # only apply on img part

            # Split the batch into its positive (cond) and negative (uncond) halves.
            batch_size = out.shape[0]
            half_size = batch_size // len(cond_or_uncond)
            ind_neg = cond_or_uncond.index(1)
            ind_pos = cond_or_uncond.index(0)

            z_pos = out[half_size * ind_pos:half_size * (ind_pos + 1)]
            z_neg = out[half_size * ind_neg:half_size * (ind_neg + 1)]

            # Extrapolate the positive attention output away from the negative one.
            guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)

            # Normalize: clamp the L1-norm ratio of guided vs. positive attention at nag_tau
            # so the guided output stays bounded relative to the positive output.
            eps = 1e-6
            norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
            norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)

            ratio = norm_guided / norm_pos
            scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio
            guided_normalized = guided * scale_factor

            # Blend the normalized guided attention back with the positive attention.
            z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)

            if img_slice is not None:
                orig_out[half_size * ind_neg:half_size * (ind_neg + 1), img_slice[0]:img_slice[1]] = z_final
                orig_out[half_size * ind_pos:half_size * (ind_pos + 1), img_slice[0]:img_slice[1]] = z_final
                return orig_out
            else:
                out[half_size * ind_pos:half_size * (ind_pos + 1)] = z_final
                return out

        m.set_model_attn1_output_patch(nag_attention_output_patch)
        m.disable_model_cfg1_optimization()
        return io.NodeOutput(m)


class NagExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            NAGuidance,
        ]


async def comfy_entrypoint() -> NagExtension:
    return NagExtension()