Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-04-28 18:31:31 +00:00

Use torch RMSNorm for flux models and refactor hunyuan video code.

@@ -3,7 +3,6 @@ from torch import Tensor, nn

 from comfy.ldm.flux.layers import (
     MLPEmbedder,
-    RMSNorm,
     ModulationOut,
 )

@@ -29,7 +28,7 @@ class Approximator(nn.Module):
         super().__init__()
         self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
         self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
-        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
+        self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
         self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)

     @property
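
Note: the Approximator now builds its norms from the operations namespace instead of the flux-layers RMSNorm wrapper, so the operations= argument disappears from the norm constructor. A minimal sketch of the pattern, assuming a bare torch-backed ops namespace (the real comfy.ops classes also handle dtype casting and skipped weight initialization):

from torch import nn

# Illustrative only: a bare-bones "operations" namespace, not ComfyUI's actual comfy.ops.
class ops:
    Linear = nn.Linear
    RMSNorm = nn.RMSNorm  # torch.nn.RMSNorm, available since PyTorch 2.4

def build_norms(hidden_dim, n_layers, operations=ops):
    # Mirrors the changed Approximator line: the norm comes straight from operations.
    return nn.ModuleList([operations.RMSNorm(hidden_dim) for _ in range(n_layers)])

print(build_norms(64, 2))
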
@@ -4,8 +4,6 @@ from functools import lru_cache
 import torch
 from torch import nn

-from comfy.ldm.flux.layers import RMSNorm
-

 class NerfEmbedder(nn.Module):
     """

@@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
         # We now need to generate parameters for 3 matrices.
         total_params = 3 * hidden_size_x**2 * mlp_ratio
         self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
-        self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
+        self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
         self.mlp_ratio = mlp_ratio

@@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
 class NerfFinalLayer(nn.Module):
     def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
         super().__init__()
-        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+        self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
         self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)

     def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
 class NerfFinalLayerConv(nn.Module):
     def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
         super().__init__()
-        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+        self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
         self.conv = operations.Conv2d(
             in_channels=hidden_size,
             out_channels=out_channels,

@@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
         operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
     )

-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
-
-    def forward(self, x: Tensor):
-        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
-

 class QKNorm(torch.nn.Module):
     def __init__(self, dim: int, dtype=None, device=None, operations=None):
         super().__init__()
-        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
-        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
+        self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
+        self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)

     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
         q = self.query_norm(q)
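
Note: the RMSNorm class deleted above kept its gain in a parameter named scale and normalized through comfy.ldm.common_dit.rms_norm; the models now rely on an operations-provided class built on torch.nn.RMSNorm, whose parameter is named weight. A hedged sketch of what such a class can look like (the actual comfy.ops implementation may differ):

import torch

class RMSNorm(torch.nn.RMSNorm):
    """Sketch of an ops-style RMSNorm: torch.nn.RMSNorm with initialization skipped,
    so weights can be filled in later from a checkpoint (assumed behaviour)."""
    def reset_parameters(self):
        # torch.nn.RMSNorm would fill self.weight with ones here; comfy's ops classes
        # typically leave parameters uninitialized and load them from the state dict.
        return None

norm = RMSNorm(64, eps=1e-6, dtype=torch.float32, device="cpu")
print(norm.weight.shape)  # torch.Size([64]) -- the gain is `weight`, not `scale`
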
@@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):


 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
         super().__init__()

         mlp_hidden_dim = int(hidden_size * mlp_ratio)

@@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):

         self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)

-        self.flipped_img_txt = flipped_img_txt
-
     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
         if self.modulation:
             img_mod1, img_mod2 = self.img_mod(vec)

@@ -224,32 +214,17 @@ class DoubleStreamBlock(nn.Module):
         del txt_qkv
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

-        if self.flipped_img_txt:
-            q = torch.cat((img_q, txt_q), dim=2)
-            del img_q, txt_q
-            k = torch.cat((img_k, txt_k), dim=2)
-            del img_k, txt_k
-            v = torch.cat((img_v, txt_v), dim=2)
-            del img_v, txt_v
-            # run actual attention
-            attn = attention(q, k, v,
-                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
-            del q, k, v
-
-            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
-        else:
-            q = torch.cat((txt_q, img_q), dim=2)
-            del txt_q, img_q
-            k = torch.cat((txt_k, img_k), dim=2)
-            del txt_k, img_k
-            v = torch.cat((txt_v, img_v), dim=2)
-            del txt_v, img_v
-            # run actual attention
-            attn = attention(q, k, v,
-                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
-            del q, k, v
-
-            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
+        q = torch.cat((txt_q, img_q), dim=2)
+        del txt_q, img_q
+        k = torch.cat((txt_k, img_k), dim=2)
+        del txt_k, img_k
+        v = torch.cat((txt_v, img_v), dim=2)
+        del txt_v, img_v
+        # run actual attention
+        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
+        del q, k, v
+
+        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks
         img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
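
Note: with flipped_img_txt gone, DoubleStreamBlock always concatenates the text tokens ahead of the image tokens and splits the attention output by the text length; the HunyuanVideo caller below is updated to match. A small sketch of the ordering, with made-up sizes (not taken from the diff):

import torch

# Made-up sizes purely to illustrate the txt-first layout.
B, H, D = 1, 2, 8
L_txt, L_img = 3, 5
txt_q = torch.randn(B, H, L_txt, D)
img_q = torch.randn(B, H, L_img, D)

q = torch.cat((txt_q, img_q), dim=2)        # text tokens first along the sequence dim
assert q.shape == (B, H, L_txt + L_img, D)

# Stand-in for the attention() output with heads merged back to (B, L, H*D); the
# combined sequence is then split by the text length, as in the new forward() above.
attn = q.permute(0, 2, 1, 3).reshape(B, L_txt + L_img, H * D)
txt_attn, img_attn = attn[:, :L_txt], attn[:, L_txt:]
assert txt_attn.shape[1] == L_txt and img_attn.shape[1] == L_img
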
@@ -16,7 +16,6 @@ from .layers import (
     SingleStreamBlock,
     timestep_embedding,
     Modulation,
-    RMSNorm
 )

 @dataclass

@@ -81,7 +80,7 @@ class Flux(nn.Module):
         self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)

         if params.txt_norm:
-            self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+            self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
         else:
             self.txt_norm = None

@@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
-                    flipped_img_txt=True,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)

@@ -378,14 +377,14 @@ class HunyuanVideo(nn.Module):
             extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
             txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)

-        ids = torch.cat((img_ids, txt_ids), dim=1)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)

         img_len = img.shape[1]
         if txt_mask is not None:
             attn_mask_len = img_len + txt.shape[1]
             attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
-            attn_mask[:, 0, img_len:] = txt_mask
+            attn_mask[:, 0, :txt.shape[1]] = txt_mask
         else:
             attn_mask = None

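Note: because the rotary ids and the token sequence are now text-first, the text padding mask is written to the front of the combined mask rather than after the image tokens. Illustration with made-up lengths (batch size 1, as the broadcast in the model assumes):

import torch

img_len, txt_len = 6, 4
txt_mask = torch.tensor([[1., 1., 1., 0.]])   # last text token is padding

attn_mask_len = img_len + txt_len
attn_mask = torch.zeros((1, 1, attn_mask_len))
attn_mask[:, 0, :txt_len] = txt_mask          # new: text occupies positions [0, txt_len)
# The old layout wrote the same mask to positions [img_len, img_len + txt_len).
print(attn_mask)
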
@@ -413,7 +412,7 @@ class HunyuanVideo(nn.Module):
                     if add is not None:
                         img += add

-        img = torch.cat((img, txt), 1)
+        img = torch.cat((txt, img), 1)

         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"

@@ -435,9 +434,9 @@ class HunyuanVideo(nn.Module):
                 if i < len(control_o):
                     add = control_o[i]
                     if add is not None:
-                        img[:, : img_len] += add
+                        img[:, txt.shape[1]: img_len + txt.shape[1]] += add

-        img = img[:, : img_len]
+        img = img[:, txt.shape[1]: img_len + txt.shape[1]]
         if ref_latent is not None:
             img = img[:, ref_latent.shape[1]:]

@@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE):

     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

+    def process_unet_state_dict(self, state_dict):
+        out_sd = {}
+        for k in list(state_dict.keys()):
+            key_out = k
+            if key_out.endswith("_norm.scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
+            out_sd[key_out] = state_dict[k]
+        return out_sd
+
     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]

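Note: since torch's RMSNorm registers its gain as weight, Flux checkpoints that stored it under the old *_norm.scale names are remapped at load time by the new process_unet_state_dict. A toy run of that renaming, with keys invented purely for illustration:

import torch

state_dict = {
    "txt_norm.scale": torch.ones(8),        # old-style RMSNorm gain
    "txt_in.weight": torch.zeros(8, 8),     # left untouched
}

out_sd = {}
for k in list(state_dict.keys()):
    key_out = k
    if key_out.endswith("_norm.scale"):
        key_out = "{}.weight".format(key_out[:-len(".scale")])
    out_sd[key_out] = state_dict[k]

print(sorted(out_sd))  # ['txt_in.weight', 'txt_norm.weight']
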
@@ -898,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE):
             key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
             key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
             key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
-            key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
-            key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
+            key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
+            key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
             key_out = key_out.replace("_attn_proj.", "_attn.proj.")
             key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
             key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
+            if key_out.endswith(".scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
             out_sd[key_out] = state_dict[k]
         return out_sd

@@ -1341,6 +1352,14 @@ class Chroma(supported_models_base.BASE):

     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

+    def process_unet_state_dict(self, state_dict):
+        out_sd = {}
+        for k in list(state_dict.keys()):
+            key_out = k
+            if key_out.endswith(".scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
+            out_sd[key_out] = state_dict[k]
+        return out_sd

     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Chroma(self, device=device)

@@ -675,10 +675,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
             "ff_context.linear_in.bias": "txt_mlp.0.bias",
             "ff_context.linear_out.weight": "txt_mlp.2.weight",
             "ff_context.linear_out.bias": "txt_mlp.2.bias",
-            "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
-            "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
-            "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
-            "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
+            "attn.norm_q.weight": "img_attn.norm.query_norm.weight",
+            "attn.norm_k.weight": "img_attn.norm.key_norm.weight",
+            "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
+            "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
         }

         for k in block_map:

@@ -701,8 +701,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
             "norm.linear.bias": "modulation.lin.bias",
             "proj_out.weight": "linear2.weight",
             "proj_out.bias": "linear2.bias",
-            "attn.norm_q.weight": "norm.query_norm.scale",
-            "attn.norm_k.weight": "norm.key_norm.scale",
+            "attn.norm_q.weight": "norm.query_norm.weight",
+            "attn.norm_k.weight": "norm.key_norm.weight",
             "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
             "attn.to_out.weight": "linear2.weight", # Flux 2
         }
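
Note: the diffusers-to-native key maps change accordingly: query/key norm weights from a diffusers checkpoint now land on the norm modules' weight instead of scale. A hedged example of applying one entry of the updated map; the prefixes and the lookup below are illustrative, not the exact logic of flux_to_diffusers:

# One entry from the updated double-block map, applied to an illustrative diffusers key.
block_map = {"attn.norm_q.weight": "img_attn.norm.query_norm.weight"}

diffusers_key = "transformer_blocks.0.attn.norm_q.weight"
prefix_from, prefix_to = "transformer_blocks.0.", "double_blocks.0."

suffix = diffusers_key[len(prefix_from):]
native_key = prefix_to + block_map[suffix]
print(native_key)  # double_blocks.0.img_attn.norm.query_norm.weight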