mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-18 22:20:03 +00:00
Implement NAG on all the models based on the Flux code. (#12500)
Use the Normalized Attention Guidance node. Flux, Flux2, Klein, Chroma, Chroma radiance, Hunyuan Video, etc..
This commit is contained in:
@@ -152,6 +152,7 @@ class Chroma(nn.Module):
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
transformer_options = transformer_options.copy()
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
# running on sequences img
|
||||
@@ -228,6 +229,7 @@ class Chroma(nn.Module):
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_dit:
|
||||
|
||||
@@ -196,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
|
||||
else:
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
extra_options = transformer_options.copy()
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||
@@ -224,6 +227,12 @@ class DoubleStreamBlock(nn.Module):
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
if "attn1_output_patch" in transformer_patches:
|
||||
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
|
||||
patch = transformer_patches["attn1_output_patch"]
|
||||
for p in patch:
|
||||
attn = p(attn, extra_options)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
@@ -303,6 +312,9 @@ class SingleStreamBlock(nn.Module):
|
||||
else:
|
||||
mod = vec
|
||||
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
extra_options = transformer_options.copy()
|
||||
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
@@ -312,6 +324,12 @@ class SingleStreamBlock(nn.Module):
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
if "attn1_output_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_output_patch"]
|
||||
for p in patch:
|
||||
attn = p(attn, extra_options)
|
||||
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
if self.yak_mlp:
|
||||
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
||||
|
||||
@@ -142,6 +142,7 @@ class Flux(nn.Module):
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
|
||||
transformer_options = transformer_options.copy()
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
@@ -231,6 +232,7 @@ class Flux(nn.Module):
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
|
||||
@@ -304,6 +304,7 @@ class HunyuanVideo(nn.Module):
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
transformer_options = transformer_options.copy()
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
initial_shape = list(img.shape)
|
||||
@@ -416,6 +417,7 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
|
||||
@@ -406,13 +406,16 @@ class ModelPatcher:
|
||||
def memory_required(self, input_shape):
|
||||
return self.model.memory_required(input_shape=input_shape)
|
||||
|
||||
def disable_model_cfg1_optimization(self):
|
||||
self.model_options["disable_cfg1_optimization"] = True
|
||||
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||
else:
|
||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||
if disable_cfg1_optimization:
|
||||
self.model_options["disable_cfg1_optimization"] = True
|
||||
self.disable_model_cfg1_optimization()
|
||||
|
||||
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
||||
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
|
||||
|
||||
99
comfy_extras/nodes_nag.py
Normal file
99
comfy_extras/nodes_nag.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import torch
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class NAGuidance(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="NAGuidance",
|
||||
display_name="Normalized Attention Guidance",
|
||||
description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
|
||||
category="",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to apply NAG to."),
|
||||
io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
|
||||
io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
|
||||
io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01),
|
||||
# io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
|
||||
# io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The patched model with NAG enabled."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
|
||||
# sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
|
||||
# sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)
|
||||
|
||||
def nag_attention_output_patch(out, extra_options):
|
||||
cond_or_uncond = extra_options.get("cond_or_uncond", None)
|
||||
if cond_or_uncond is None:
|
||||
return out
|
||||
|
||||
if not (1 in cond_or_uncond and 0 in cond_or_uncond):
|
||||
return out
|
||||
|
||||
# sigma = extra_options.get("sigmas", None)
|
||||
# if sigma is not None and len(sigma) > 0:
|
||||
# sigma = sigma[0].item()
|
||||
# if sigma > sigma_start or sigma < sigma_end:
|
||||
# return out
|
||||
|
||||
img_slice = extra_options.get("img_slice", None)
|
||||
|
||||
if img_slice is not None:
|
||||
orig_out = out
|
||||
out = out[:, img_slice[0]:img_slice[1]] # only apply on img part
|
||||
|
||||
batch_size = out.shape[0]
|
||||
half_size = batch_size // len(cond_or_uncond)
|
||||
|
||||
ind_neg = cond_or_uncond.index(1)
|
||||
ind_pos = cond_or_uncond.index(0)
|
||||
z_pos = out[half_size * ind_pos:half_size * (ind_pos + 1)]
|
||||
z_neg = out[half_size * ind_neg:half_size * (ind_neg + 1)]
|
||||
|
||||
guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)
|
||||
|
||||
eps = 1e-6
|
||||
norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
|
||||
norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)
|
||||
|
||||
ratio = norm_guided / norm_pos
|
||||
scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio
|
||||
|
||||
guided_normalized = guided * scale_factor
|
||||
|
||||
z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)
|
||||
|
||||
if img_slice is not None:
|
||||
orig_out[half_size * ind_neg:half_size * (ind_neg + 1), img_slice[0]:img_slice[1]] = z_final
|
||||
orig_out[half_size * ind_pos:half_size * (ind_pos + 1), img_slice[0]:img_slice[1]] = z_final
|
||||
return orig_out
|
||||
else:
|
||||
out[half_size * ind_pos:half_size * (ind_pos + 1)] = z_final
|
||||
return out
|
||||
|
||||
m.set_model_attn1_output_patch(nag_attention_output_patch)
|
||||
m.disable_model_cfg1_optimization()
|
||||
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
class NagExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
NAGuidance,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> NagExtension:
|
||||
return NagExtension()
|
||||
Reference in New Issue
Block a user