From c61f69acf50d849f493e403447a3de0dfd821462 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 12 Feb 2026 17:57:19 -0500 Subject: [PATCH] Use torch RMSNorm for flux models and refactor hunyuan video code. --- comfy/ldm/chroma/layers.py | 3 +- comfy/ldm/chroma_radiance/layers.py | 8 ++--- comfy/ldm/flux/layers.py | 51 ++++++++--------------------- comfy/ldm/flux/model.py | 3 +- comfy/ldm/hunyuan_video/model.py | 11 +++---- comfy/supported_models.py | 23 +++++++++++-- comfy/utils.py | 12 +++---- 7 files changed, 50 insertions(+), 61 deletions(-) diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 2d5684348..df348a8ed 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -3,7 +3,6 @@ from torch import Tensor, nn from comfy.ldm.flux.layers import ( MLPEmbedder, - RMSNorm, ModulationOut, ) @@ -29,7 +28,7 @@ class Approximator(nn.Module): super().__init__() self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) - self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) + self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)]) self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) @property diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py index 3c7bc9b6b..08d31e0ba 100644 --- a/comfy/ldm/chroma_radiance/layers.py +++ b/comfy/ldm/chroma_radiance/layers.py @@ -4,8 +4,6 @@ from functools import lru_cache import torch from torch import nn -from comfy.ldm.flux.layers import RMSNorm - class NerfEmbedder(nn.Module): """ @@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module): # We now need to generate parameters for 3 matrices. total_params = 3 * hidden_size_x**2 * mlp_ratio self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device) - self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations) + self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device) self.mlp_ratio = mlp_ratio @@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module): class NerfFinalLayer(nn.Module): def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): super().__init__() - self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device) self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module): class NerfFinalLayerConv(nn.Module): def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None): super().__init__() - self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device) self.conv = operations.Conv2d( in_channels=hidden_size, out_channels=out_channels, diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 60f2bdae2..c42f7c20e 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, dtype=None, device=None, operations=None): - super().__init__() - self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) - - def forward(self, x: Tensor): - return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6) - class QKNorm(torch.nn.Module): def __init__(self, dim: int, dtype=None, device=None, operations=None): super().__init__() - self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) - self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) + self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device) + self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device) def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple: q = self.query_norm(q) @@ -169,7 +161,7 @@ class SiLUActivation(nn.Module): class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) @@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module): self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) - self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): if self.modulation: img_mod1, img_mod2 = self.img_mod(vec) @@ -224,32 +214,17 @@ class DoubleStreamBlock(nn.Module): del txt_qkv txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) - if self.flipped_img_txt: - q = torch.cat((img_q, txt_q), dim=2) - del img_q, txt_q - k = torch.cat((img_k, txt_k), dim=2) - del img_k, txt_k - v = torch.cat((img_v, txt_v), dim=2) - del img_v, txt_v - # run actual attention - attn = attention(q, k, v, - pe=pe, mask=attn_mask, transformer_options=transformer_options) - del q, k, v + q = torch.cat((txt_q, img_q), dim=2) + del txt_q, img_q + k = torch.cat((txt_k, img_k), dim=2) + del txt_k, img_k + v = torch.cat((txt_v, img_v), dim=2) + del txt_v, img_v + # run actual attention + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v - img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:] - else: - q = torch.cat((txt_q, img_q), dim=2) - del txt_q, img_q - k = torch.cat((txt_k, img_k), dim=2) - del txt_k, img_k - v = torch.cat((txt_v, img_v), dim=2) - del txt_v, img_v - # run actual attention - attn = attention(q, k, v, - pe=pe, mask=attn_mask, transformer_options=transformer_options) - del q, k, v - - txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] # calculate the img bloks img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index f40c2a7a9..260ccad7e 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -16,7 +16,6 @@ from .layers import ( SingleStreamBlock, timestep_embedding, Modulation, - RMSNorm ) @dataclass @@ -81,7 +80,7 @@ class Flux(nn.Module): self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) if params.txt_norm: - self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations) + self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device) else: self.txt_norm = None diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index 55ab550f8..563f28f6b 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, - flipped_img_txt=True, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -378,14 +377,14 @@ class HunyuanVideo(nn.Module): extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1) - ids = torch.cat((img_ids, txt_ids), dim=1) + ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) img_len = img.shape[1] if txt_mask is not None: attn_mask_len = img_len + txt.shape[1] attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device) - attn_mask[:, 0, img_len:] = txt_mask + attn_mask[:, 0, :txt.shape[1]] = txt_mask else: attn_mask = None @@ -413,7 +412,7 @@ class HunyuanVideo(nn.Module): if add is not None: img += add - img = torch.cat((img, txt), 1) + img = torch.cat((txt, img), 1) transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["block_type"] = "single" @@ -435,9 +434,9 @@ class HunyuanVideo(nn.Module): if i < len(control_o): add = control_o[i] if add is not None: - img[:, : img_len] += add + img[:, txt.shape[1]: img_len + txt.shape[1]] += add - img = img[:, : img_len] + img = img[:, txt.shape[1]: img_len + txt.shape[1]] if ref_latent is not None: img = img[:, ref_latent.shape[1]:] diff --git a/comfy/supported_models.py b/comfy/supported_models.py index d33db7507..3be8f48dd 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE): supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + def process_unet_state_dict(self, state_dict): + out_sd = {} + for k in list(state_dict.keys()): + key_out = k + if key_out.endswith("_norm.scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) + out_sd[key_out] = state_dict[k] + return out_sd + vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] @@ -898,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE): key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.") key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.") key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.") - key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale") - key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale") + key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight") + key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight") key_out = key_out.replace("_attn_proj.", "_attn.proj.") key_out = key_out.replace(".modulation.linear.", ".modulation.lin.") key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.") + if key_out.endswith(".scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) out_sd[key_out] = state_dict[k] return out_sd @@ -1341,6 +1352,14 @@ class Chroma(supported_models_base.BASE): supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + def process_unet_state_dict(self, state_dict): + out_sd = {} + for k in list(state_dict.keys()): + key_out = k + if key_out.endswith(".scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) + out_sd[key_out] = state_dict[k] + return out_sd def get_model(self, state_dict, prefix="", device=None): out = model_base.Chroma(self, device=device) diff --git a/comfy/utils.py b/comfy/utils.py index e0a94e2e1..d553a7c1b 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -675,10 +675,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "ff_context.linear_in.bias": "txt_mlp.0.bias", "ff_context.linear_out.weight": "txt_mlp.2.weight", "ff_context.linear_out.bias": "txt_mlp.2.bias", - "attn.norm_q.weight": "img_attn.norm.query_norm.scale", - "attn.norm_k.weight": "img_attn.norm.key_norm.scale", - "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", - "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", + "attn.norm_q.weight": "img_attn.norm.query_norm.weight", + "attn.norm_k.weight": "img_attn.norm.key_norm.weight", + "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight", + "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight", } for k in block_map: @@ -701,8 +701,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "norm.linear.bias": "modulation.lin.bias", "proj_out.weight": "linear2.weight", "proj_out.bias": "linear2.bias", - "attn.norm_q.weight": "norm.query_norm.scale", - "attn.norm_k.weight": "norm.key_norm.scale", + "attn.norm_q.weight": "norm.query_norm.weight", + "attn.norm_k.weight": "norm.key_norm.weight", "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2 "attn.to_out.weight": "linear2.weight", # Flux 2 }