diff --git a/backend/diffusion_engine/sd35.py b/backend/diffusion_engine/sd35.py new file mode 100644 index 00000000..711fb442 --- /dev/null +++ b/backend/diffusion_engine/sd35.py @@ -0,0 +1,149 @@ +import torch + +from huggingface_guess import model_list +from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects +from backend.patcher.clip import CLIP +from backend.patcher.vae import VAE +from backend.patcher.unet import UnetPatcher +from backend.text_processing.classic_engine import ClassicTextProcessingEngine +from backend.text_processing.t5_engine import T5TextProcessingEngine +from backend.args import dynamic_args +from backend import memory_management +from backend.modules.k_prediction import PredictionDiscreteFlow + +from modules.shared import opts + + +## patch SD3 Class in huggingface_guess.model_list +def SD3_clip_target(self, state_dict={}): + return {'clip_l': 'text_encoder', 'clip_g': 'text_encoder_2', 't5xxl': 'text_encoder_3'} + +model_list.SD3.unet_target = 'transformer' +model_list.SD3.clip_target = SD3_clip_target +## end patch + +class StableDiffusion3(ForgeDiffusionEngine): + matched_guesses = [model_list.SD3] + + def __init__(self, estimated_config, huggingface_components): + super().__init__(estimated_config, huggingface_components) + self.is_inpaint = False + + clip = CLIP( + model_dict={ + 'clip_l': huggingface_components['text_encoder'], + 'clip_g': huggingface_components['text_encoder_2'], + 't5xxl' : huggingface_components['text_encoder_3'] + }, + tokenizer_dict={ + 'clip_l': huggingface_components['tokenizer'], + 'clip_g': huggingface_components['tokenizer_2'], + 't5xxl' : huggingface_components['tokenizer_3'] + } + ) + + k_predictor = PredictionDiscreteFlow(shift=3.0) + + vae = VAE(model=huggingface_components['vae']) + + unet = UnetPatcher.from_model( + model=huggingface_components['transformer'], + diffusers_scheduler= None, + k_predictor=k_predictor, + config=estimated_config + ) + + self.text_processing_engine_l = ClassicTextProcessingEngine( + text_encoder=clip.cond_stage_model.clip_l, + tokenizer=clip.tokenizer.clip_l, + embedding_dir=dynamic_args['embedding_dir'], + embedding_key='clip_l', + embedding_expected_shape=768, + emphasis_name=dynamic_args['emphasis_name'], + text_projection=True, + minimal_clip_skip=1, + clip_skip=1, + return_pooled=True, + final_layer_norm=False, + ) + + self.text_processing_engine_g = ClassicTextProcessingEngine( + text_encoder=clip.cond_stage_model.clip_g, + tokenizer=clip.tokenizer.clip_g, + embedding_dir=dynamic_args['embedding_dir'], + embedding_key='clip_g', + embedding_expected_shape=1280, + emphasis_name=dynamic_args['emphasis_name'], + text_projection=True, + minimal_clip_skip=1, + clip_skip=1, + return_pooled=True, + final_layer_norm=False, + ) + + self.text_processing_engine_t5 = T5TextProcessingEngine( + text_encoder=clip.cond_stage_model.t5xxl, + tokenizer=clip.tokenizer.t5xxl, + emphasis_name=dynamic_args['emphasis_name'], + ) + + self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None) + self.forge_objects_original = self.forge_objects.shallow_copy() + self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy() + + # WebUI Legacy + self.is_sd3 = True + + def set_clip_skip(self, clip_skip): + self.text_processing_engine_l.clip_skip = clip_skip + self.text_processing_engine_g.clip_skip = clip_skip + + @torch.inference_mode() + def get_learned_conditioning(self, prompt: list[str]): + memory_management.load_model_gpu(self.forge_objects.clip.patcher) + + cond_g, g_pooled = self.text_processing_engine_g(prompt) + cond_l, l_pooled = self.text_processing_engine_l(prompt) + if opts.sd3_enable_t5: + cond_t5 = self.text_processing_engine_t5(prompt) + else: + cond_t5 = torch.zeros([len(prompt), 256, 4096]) + + is_negative_prompt = getattr(prompt, 'is_negative_prompt', False) + + force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt) + + if force_zero_negative_prompt: + l_pooled = torch.zeros_like(l_pooled) + g_pooled = torch.zeros_like(g_pooled) + cond_l = torch.zeros_like(cond_l) + cond_g = torch.zeros_like(cond_g) + cond_t5 = torch.zeros_like(cond_t5) + + cond_lg = torch.cat([cond_l, cond_g], dim=-1) + cond_lg = torch.nn.functional.pad(cond_lg, (0, 4096 - cond_lg.shape[-1])) + + cond = dict( + crossattn=torch.cat([cond_lg, cond_t5], dim=-2), + vector=torch.cat([l_pooled, g_pooled], dim=-1), + ) + + return cond + + @torch.inference_mode() + def get_prompt_lengths_on_ui(self, prompt): + token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0]) + return token_count, max(255, token_count) + + @torch.inference_mode() + def encode_first_stage(self, x): + sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5) + sample = self.forge_objects.vae.first_stage_model.process_in(sample) + return sample.to(x) + + @torch.inference_mode() + def decode_first_stage(self, x): + sample = self.forge_objects.vae.first_stage_model.process_out(x) + sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0 + + return sample.to(x) diff --git a/backend/loader.py b/backend/loader.py index 50803245..e824f6b8 100644 --- a/backend/loader.py +++ b/backend/loader.py @@ -20,10 +20,11 @@ from backend.nn.unet import IntegratedUNet2DConditionModel from backend.diffusion_engine.sd15 import StableDiffusion from backend.diffusion_engine.sd20 import StableDiffusion2 from backend.diffusion_engine.sdxl import StableDiffusionXL, StableDiffusionXLRefiner +from backend.diffusion_engine.sd35 import StableDiffusion3 from backend.diffusion_engine.flux import Flux -possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, Flux] +possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Flux] logging.getLogger("diffusers").setLevel(logging.ERROR) @@ -107,15 +108,18 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale']) return model - if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']: + if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel']: assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!' model_loader = None if cls_name == 'UNet2DConditionModel': model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c) - if cls_name == 'FluxTransformer2DModel': + elif cls_name == 'FluxTransformer2DModel': from backend.nn.flux import IntegratedFluxTransformer2DModel model_loader = lambda c: IntegratedFluxTransformer2DModel(**c) + elif cls_name == 'SD3Transformer2DModel': + from backend.nn.mmditx import MMDiTX + model_loader = lambda c: MMDiTX(**c) unet_config = guess.unet_config.copy() state_dict_parameters = memory_management.state_dict_parameters(state_dict) @@ -246,10 +250,10 @@ def replace_state_dict(sd, asd, guess): "-" : None, "sd1" : None, "sd2" : None, - "xlrf": "conditioner.embedders.0.model.", - "sdxl": "conditioner.embedders.1.model.", + "xlrf": "conditioner.embedders.0.model.transformer.", + "sdxl": "conditioner.embedders.1.model.transformer.", "flux": None, - "sd3" : "text_encoders.clip_g.", + "sd3" : "text_encoders.clip_g.transformer.", } ## prefixes used by various model types for CLIP-H prefix_H = { @@ -292,10 +296,10 @@ def replace_state_dict(sd, asd, guess): ## CLIP-G CLIP_G = { # key to identify source model old_prefix - 'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias' : 'conditioner.embedders.1.model.', - 'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_g.', + 'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias' : 'conditioner.embedders.1.model.transformer.', + 'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_g.transformer.', 'text_model.encoder.layers.0.layer_norm1.bias' : '', - 'transformer.resblocks.0.ln_1.bias' : '' + 'transformer.resblocks.0.ln_1.bias' : 'transformer.' } for CLIP_key in CLIP_G.keys(): if CLIP_key in asd and asd[CLIP_key].shape[0] == 1280: @@ -303,7 +307,7 @@ def replace_state_dict(sd, asd, guess): old_prefix = CLIP_G[CLIP_key] if new_prefix is not None: - if "resblocks" not in CLIP_key: # need to convert + if "resblocks" not in CLIP_key and model_type != "sd3": # need to convert def convert_transformers(statedict, prefix_from, prefix_to, number): keys_to_replace = { "{}text_model.embeddings.position_embedding.weight" : "{}positional_embedding", @@ -320,15 +324,15 @@ def replace_state_dict(sd, asd, guess): "self_attn.out_proj" : "attn.out_proj" , } - for x in keys_to_replace: + for x in keys_to_replace: # remove trailing 'transformer.' from new prefix k = x.format(prefix_from) - statedict[keys_to_replace[x].format(prefix_to)] = statedict.pop(k) + statedict[keys_to_replace[x].format(prefix_to[:-12])] = statedict.pop(k) for resblock in range(number): for y in ["weight", "bias"]: for x in resblock_to_replace: k = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, x, y) - k_to = "{}transformer.resblocks.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y) + k_to = "{}resblocks.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y) statedict[k_to] = statedict.pop(k) k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.q_proj", y) @@ -338,14 +342,16 @@ def replace_state_dict(sd, asd, guess): k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.v_proj", y) weightsV = statedict.pop(k_from) - k_to = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_to, resblock, y) + k_to = "{}resblocks.{}.attn.in_proj_{}".format(prefix_to, resblock, y) + statedict[k_to] = torch.cat((weightsQ, weightsK, weightsV)) return statedict asd = convert_transformers(asd, old_prefix, new_prefix, 32) - new_prefix = "" + for k, v in asd.items(): + sd[k] = v - if old_prefix == "": + elif old_prefix == "": for k, v in asd.items(): new_k = new_prefix + k sd[new_k] = v @@ -360,7 +366,7 @@ def replace_state_dict(sd, asd, guess): 'conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'conditioner.embedders.0.transformer.', 'text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_l.transformer.', 'text_model.encoder.layers.0.layer_norm1.bias' : '', - 'transformer.resblocks.0.ln_1.bias' : '' + 'transformer.resblocks.0.ln_1.bias' : 'transformer.' } for CLIP_key in CLIP_L.keys(): @@ -376,6 +382,7 @@ def replace_state_dict(sd, asd, guess): "token_embedding.weight": "{}text_model.embeddings.token_embedding.weight", "ln_final.weight" : "{}text_model.final_layer_norm.weight", "ln_final.bias" : "{}text_model.final_layer_norm.bias", + "text_projection" : "text_projection.weight", } resblock_to_replace = { "ln_1" : "layer_norm1", @@ -391,11 +398,11 @@ def replace_state_dict(sd, asd, guess): for resblock in range(number): for y in ["weight", "bias"]: for x in resblock_to_replace: - k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y) + k = "{}resblocks.{}.{}.{}".format(prefix_from, resblock, x, y) k_to = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y) statedict[k_to] = statedict.pop(k) - k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y) + k_from = "{}resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y) weights = statedict.pop(k_from) shape_from = weights.shape[0] // 3 for x in range(3): @@ -405,9 +412,10 @@ def replace_state_dict(sd, asd, guess): return statedict asd = transformers_convert(asd, old_prefix, new_prefix, 12) - new_prefix = "" + for k, v in asd.items(): + sd[k] = v - if old_prefix == "": + elif old_prefix == "": for k, v in asd.items(): new_k = new_prefix + k sd[new_k] = v diff --git a/backend/modules/k_prediction.py b/backend/modules/k_prediction.py index e6c3c150..573e56d8 100644 --- a/backend/modules/k_prediction.py +++ b/backend/modules/k_prediction.py @@ -250,6 +250,38 @@ class PredictionFlow(AbstractPrediction): return 1.0 - percent +class PredictionDiscreteFlow(AbstractPrediction): + def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.0, timesteps = 1000): + super().__init__(sigma_data=sigma_data, prediction_type=prediction_type) + self.shift = shift + ts = self.sigma(torch.arange(1, timesteps + 1, 1)) + self.register_buffer("sigmas", ts) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma * 1000 + + def sigma(self, timestep: torch.Tensor): + timestep = timestep / 1000.0 + if self.shift == 1.0: + return timestep + return self.shift * timestep / (1 + (self.shift - 1) * timestep) + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 1.0 + if percent >= 1.0: + return 0.0 + return 1.0 - percent + + class PredictionFlux(AbstractPrediction): def __init__(self, seq_len=4096, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15, pseudo_timestep_range=10000, mu=None): super().__init__(sigma_data=1.0, prediction_type='const') diff --git a/backend/nn/mmditx.py b/backend/nn/mmditx.py new file mode 100644 index 00000000..7f87948d --- /dev/null +++ b/backend/nn/mmditx.py @@ -0,0 +1,971 @@ +### This file contains impls for MM-DiT, the core model component of SD3 + +## source https://github.com/Stability-AI/sd3.5 +## attention, Mlp : other_impls.py +## all else : mmditx.py + +## minor modifications to MMDiTX.__init__() and MMDiTX.forward() + +import math +from typing import Dict, List, Optional + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange, repeat + +def attention(q, k, v, heads, mask=None): + """Convenience wrapper around a basic attention operation""" + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v)) + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + return out.transpose(1, 2).reshape(b, -1, heads * dim_head) + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + bias=True, + dtype=None, + device=None, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear( + in_features, hidden_features, bias=bias, dtype=dtype, device=device + ) + self.act = act_layer + self.fc2 = nn.Linear( + hidden_features, out_features, bias=bias, dtype=dtype, device=device + ) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + img_size: Optional[int] = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + flatten: bool = True, + bias: bool = True, + strict_img_size: bool = True, + dynamic_img_pad: bool = False, + dtype=None, + device=None, + ): + super().__init__() + + self.patch_size = (patch_size, patch_size) + if img_size is not None: + self.img_size = (img_size, img_size) + self.grid_size = tuple( + [s // p for s, p in zip(self.img_size, self.patch_size)] + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] + else: + self.img_size = None + self.grid_size = None + self.num_patches = None + + # flatten spatial dim and transpose to channels last, kept for bwd compat + self.flatten = flatten + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=bias, + dtype=dtype, + device=device, + ) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # NCHW -> NLC + return x + + +def modulate(x, shift, scale): + if shift is None: + shift = torch.zeros_like(scale) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# + + +def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + cls_token=False, + extra_tokens=0, + scaling_factor=None, + offset=None, +): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + if scaling_factor is not None: + grid = grid / scaling_factor + if offset is not None: + grid = grid - offset + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate( + [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + + +class TimestepEmbedder(nn.Module): + """Embeds scalar timesteps into vector representations.""" + + def __init__( + self, hidden_size, frequency_embedding_size=256, dtype=None, device=None + ): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, + hidden_size, + bias=True, + dtype=dtype, + device=device, + ), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + if torch.is_floating_point(t): + embedding = embedding.to(dtype=t.dtype) + return embedding + + def forward(self, t, dtype, **kwargs): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class VectorEmbedder(nn.Module): + """Embeds a flat vector of dimension input_dim""" + + def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +################################################################################# +# Core DiT Model # +################################################################################# + + +def split_qkv(qkv, head_dim): + qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0) + return qkv[0], qkv[1], qkv[2] + + +def optimized_attention(qkv, num_heads): + return attention(qkv[0], qkv[1], qkv[2], num_heads) + + +class SelfAttention(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + pre_only: bool = False, + qk_norm: Optional[str] = None, + rmsnorm: bool = False, + dtype=None, + device=None, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + if not pre_only: + self.proj = nn.Linear(dim, dim, dtype=dtype, device=device) + self.pre_only = pre_only + + if qk_norm == "rms": + self.ln_q = RMSNorm( + self.head_dim, + elementwise_affine=True, + eps=1.0e-6, + dtype=dtype, + device=device, + ) + self.ln_k = RMSNorm( + self.head_dim, + elementwise_affine=True, + eps=1.0e-6, + dtype=dtype, + device=device, + ) + elif qk_norm == "ln": + self.ln_q = nn.LayerNorm( + self.head_dim, + elementwise_affine=True, + eps=1.0e-6, + dtype=dtype, + device=device, + ) + self.ln_k = nn.LayerNorm( + self.head_dim, + elementwise_affine=True, + eps=1.0e-6, + dtype=dtype, + device=device, + ) + elif qk_norm is None: + self.ln_q = nn.Identity() + self.ln_k = nn.Identity() + else: + raise ValueError(qk_norm) + + def pre_attention(self, x: torch.Tensor): + B, L, C = x.shape + qkv = self.qkv(x) + q, k, v = split_qkv(qkv, self.head_dim) + q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1) + k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1) + return (q, k, v) + + def post_attention(self, x: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + x = self.proj(x) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + (q, k, v) = self.pre_attention(x) + x = attention(q, k, v, self.num_heads) + x = self.post_attention(x) + return x + + +class RMSNorm(torch.nn.Module): + def __init__( + self, + dim: int, + elementwise_affine: bool = False, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + """ + super().__init__() + self.eps = eps + self.learnable_scale = elementwise_affine + if self.learnable_scale: + self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + else: + self.register_parameter("weight", None) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The normalized tensor. + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + """ + x = self._norm(x) + if self.learnable_scale: + return x * self.weight.to(device=x.device, dtype=x.dtype) + else: + return x + + +class SwiGLUFeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float] = None, + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. + + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +class DismantledBlock(nn.Module): + """A DiT block with gated adaptive layer norm (adaLN) conditioning.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + pre_only: bool = False, + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + qk_norm: Optional[str] = None, + x_block_self_attn: bool = False, + dtype=None, + device=None, + **block_kwargs, + ): + super().__init__() + if not rmsnorm: + self.norm1 = nn.LayerNorm( + hidden_size, + elementwise_affine=False, + eps=1e-6, + dtype=dtype, + device=device, + ) + else: + self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + pre_only=pre_only, + qk_norm=qk_norm, + rmsnorm=rmsnorm, + dtype=dtype, + device=device, + ) + if x_block_self_attn: + assert not pre_only + assert not scale_mod_only + self.x_block_self_attn = True + self.attn2 = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + pre_only=False, + qk_norm=qk_norm, + rmsnorm=rmsnorm, + dtype=dtype, + device=device, + ) + else: + self.x_block_self_attn = False + if not pre_only: + if not rmsnorm: + self.norm2 = nn.LayerNorm( + hidden_size, + elementwise_affine=False, + eps=1e-6, + dtype=dtype, + device=device, + ) + else: + self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + if not pre_only: + if not swiglu: + self.mlp = Mlp( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=nn.GELU(approximate="tanh"), + dtype=dtype, + device=device, + ) + else: + self.mlp = SwiGLUFeedForward( + dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256 + ) + self.scale_mod_only = scale_mod_only + if x_block_self_attn: + assert not pre_only + assert not scale_mod_only + n_mods = 9 + elif not scale_mod_only: + n_mods = 6 if not pre_only else 2 + else: + n_mods = 4 if not pre_only else 1 + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear( + hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device + ), + ) + self.pre_only = pre_only + + def pre_attention(self, x: torch.Tensor, c: torch.Tensor): + assert x is not None, "pre_attention called with None input" + if not self.pre_only: + if not self.scale_mod_only: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.adaLN_modulation(c).chunk(6, dim=1) + ) + else: + shift_msa = None + shift_mlp = None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation( + c + ).chunk(4, dim=1) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp) + else: + if not self.scale_mod_only: + shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1) + else: + shift_msa = None + scale_msa = self.adaLN_modulation(c) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, None + + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): + assert not self.pre_only + x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + gate_mlp.unsqueeze(1) * self.mlp( + modulate(self.norm2(x), shift_mlp, scale_mlp) + ) + return x + + def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + assert self.x_block_self_attn + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + shift_msa2, + scale_msa2, + gate_msa2, + ) = self.adaLN_modulation(c).chunk(9, dim=1) + x_norm = self.norm1(x) + qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa)) + qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2)) + return ( + qkv, + qkv2, + ( + x, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + gate_msa2, + ), + ) + + def post_attention_x( + self, + attn, + attn2, + x, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + gate_msa2, + attn1_dropout: float = 0.0, + ): + assert not self.pre_only + if attn1_dropout > 0.0: + # Use torch.bernoulli to implement dropout, only dropout the batch dimension + attn1_dropout = torch.bernoulli( + torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device) + ) + attn_ = ( + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout + ) + else: + attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + attn_ + attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2) + x = x + attn2_ + mlp_ = gate_mlp.unsqueeze(1) * self.mlp( + modulate(self.norm2(x), shift_mlp, scale_mlp) + ) + x = x + mlp_ + return x + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + if self.x_block_self_attn: + (q, k, v), (q2, k2, v2), intermediates = self.pre_attention_x(x, c) + attn = attention(q, k, v, self.attn.num_heads) + attn2 = attention(q2, k2, v2, self.attn2.num_heads) + return self.post_attention_x(attn, attn2, *intermediates) + else: + (q, k, v), intermediates = self.pre_attention(x, c) + attn = attention(q, k, v, self.attn.num_heads) + return self.post_attention(attn, *intermediates) + + +def block_mixing(context, x, context_block, x_block, c): + assert context is not None, "block_mixing called with None context" + context_qkv, context_intermediates = context_block.pre_attention(context, c) + + if x_block.x_block_self_attn: + x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c) + else: + x_qkv, x_intermediates = x_block.pre_attention(x, c) + + q, k, v = tuple( + torch.cat(tuple(qkv[i] for qkv in [context_qkv, x_qkv]), dim=1) + for i in range(3) + ) + attn = attention(q, k, v, x_block.attn.num_heads) + context_attn, x_attn = ( + attn[:, : context_qkv[0].shape[1]], + attn[:, context_qkv[0].shape[1] :], + ) + + if not context_block.pre_only: + context = context_block.post_attention(context_attn, *context_intermediates) + else: + context = None + + if x_block.x_block_self_attn: + x_q2, x_k2, x_v2 = x_qkv2 + attn2 = attention(x_q2, x_k2, x_v2, x_block.attn2.num_heads) + x = x_block.post_attention_x(x_attn, attn2, *x_intermediates) + else: + x = x_block.post_attention(x_attn, *x_intermediates) + + return context, x + + +class JointBlock(nn.Module): + """just a small wrapper to serve as a fsdp unit""" + + def __init__(self, *args, **kwargs): + super().__init__() + pre_only = kwargs.pop("pre_only") + qk_norm = kwargs.pop("qk_norm", None) + x_block_self_attn = kwargs.pop("x_block_self_attn", False) + self.context_block = DismantledBlock( + *args, pre_only=pre_only, qk_norm=qk_norm, **kwargs + ) + self.x_block = DismantledBlock( + *args, + pre_only=False, + qk_norm=qk_norm, + x_block_self_attn=x_block_self_attn, + **kwargs, + ) + + def forward(self, *args, **kwargs): + return block_mixing( + *args, context_block=self.context_block, x_block=self.x_block, **kwargs + ) + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + + def __init__( + self, + hidden_size: int, + patch_size: int, + out_channels: int, + total_out_channels: Optional[int] = None, + dtype=None, + device=None, + ): + super().__init__() + self.norm_final = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device + ) + self.linear = ( + nn.Linear( + hidden_size, + patch_size * patch_size * out_channels, + bias=True, + dtype=dtype, + device=device, + ) + if (total_out_channels is None) + else nn.Linear( + hidden_size, total_out_channels, bias=True, dtype=dtype, device=device + ) + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear( + hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device + ), + ) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class MMDiTX(nn.Module): + """Diffusion model with a Transformer backbone.""" + + def __init__( + self, + input_size: int = 32, + patch_size: int = 2, + in_channels: int = 4, + depth: int = 28, + mlp_ratio: float = 4.0, + learn_sigma: bool = False, + adm_in_channels: Optional[int] = None, + context_embedder_config: Optional[Dict] = None, + register_length: int = 0, + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + out_channels: Optional[int] = None, + pos_embed_scaling_factor: Optional[float] = None, + pos_embed_offset: Optional[float] = None, + pos_embed_max_size: Optional[int] = None, + num_patches=None, + qk_norm: Optional[str] = None, + x_block_self_attn_layers: Optional[List[int]] = [], + qkv_bias: bool = True, + dtype=None, + device=None, + verbose=False, + ): + super().__init__() + if verbose: + print( + f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}" + ) + self.dtype = dtype + self.learn_sigma = learn_sigma + in_channels = int(in_channels) + self.in_channels = in_channels + # default_out_channels = in_channels * 2 if learn_sigma else in_channels + # self.out_channels = ( + # out_channels if out_channels is not None else default_out_channels + # ) + self.out_channels = 16 # hard coded - detected value can be vastly wrong if nf4 + # but always 16 for sd3 and sd3.5 (learn_sigma always False) + patch_size = int(patch_size) + self.patch_size = patch_size + self.pos_embed_scaling_factor = pos_embed_scaling_factor + self.pos_embed_offset = pos_embed_offset + self.pos_embed_max_size = int(pos_embed_max_size) + self.x_block_self_attn_layers = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] if self.pos_embed_max_size == 384 else x_block_self_attn_layers + + # apply magic --> this defines a head_size of 64 + depth = int(depth) + hidden_size = int(64 * depth) + num_heads = depth + + self.num_heads = num_heads + + self.x_embedder = PatchEmbed( + input_size, + patch_size, + in_channels, + hidden_size, + bias=True, + strict_img_size=self.pos_embed_max_size is None, + dtype=dtype, + device=device, + ) + self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device) + + adm_in_channels = int(adm_in_channels) # 2048 + + if adm_in_channels is not None: + assert isinstance(adm_in_channels, int) + self.y_embedder = VectorEmbedder( + adm_in_channels, hidden_size, dtype=dtype, device=device + ) + + self.context_embedder = nn.Identity() + if context_embedder_config is not None: + if context_embedder_config["target"] == "torch.nn.Linear": + self.context_embedder = nn.Linear( + **context_embedder_config["params"], dtype=dtype, device=device + ) + + self.register_length = register_length + if self.register_length > 0: + self.register = nn.Parameter( + torch.randn(1, register_length, hidden_size, dtype=dtype, device=device) + ) + + # num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + # just use a buffer already + if num_patches is not None: + num_patches = int(num_patches) + self.register_buffer( + "pos_embed", + torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device), + ) + else: + self.pos_embed = None + + self.joint_blocks = nn.ModuleList( + [ + JointBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + pre_only=i == depth - 1, + rmsnorm=rmsnorm, + scale_mod_only=scale_mod_only, + swiglu=swiglu, + qk_norm=qk_norm, + x_block_self_attn=(i in self.x_block_self_attn_layers), + dtype=dtype, + device=device, + ) + for i in range(depth) + ] + ) + + self.final_layer = FinalLayer( + hidden_size, patch_size, self.out_channels, dtype=dtype, device=device + ) + + def cropped_pos_embed(self, hw): + assert self.pos_embed_max_size is not None + p = self.x_embedder.patch_size[0] + h, w = hw + # patched size + h = h // p + w = w // p + assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) + assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + spatial_pos_embed = rearrange( + self.pos_embed, + "1 (h w) c -> 1 h w c", + h=self.pos_embed_max_size, + w=self.pos_embed_max_size, + ) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c") + return spatial_pos_embed + + def unpatchify(self, x, hw=None): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, C, H, W) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + if hw is None: + h = w = int(x.shape[1] ** 0.5) + else: + h, w = hw + h = h // p + w = w // p + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs + + def forward_core_with_concat( + self, + x: torch.Tensor, + c_mod: torch.Tensor, + context: Optional[torch.Tensor] = None, + skip_layers: Optional[List] = [], + controlnet_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.register_length > 0: + context = torch.cat( + ( + repeat(self.register, "1 ... -> b ...", b=x.shape[0]), + context if context is not None else torch.Tensor([]).type_as(x), + ), + 1, + ) + + # context is B, L', D + # x is B, L, D + for i, block in enumerate(self.joint_blocks): + if i in skip_layers: + continue + context, x = block(context, x, c=c_mod) + if controlnet_hidden_states is not None: + controlnet_block_interval = len(self.joint_blocks) // len( + controlnet_hidden_states + ) + x = x + controlnet_hidden_states[i // controlnet_block_interval] + + x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels) + return x + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + control=None, transformer_options={}, **kwargs) -> torch.Tensor: + """ + Forward pass of DiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + """ + + skip_layers = transformer_options.get("skip_layers", []) + + hw = x.shape[-2:] + + # x = x[:,:16,:,:] + + x = self.x_embedder(x) + self.cropped_pos_embed(hw).to(x.device, x.dtype) + c = self.t_embedder(t, dtype=x.dtype) # (N, D) + if y is not None: + y = self.y_embedder(y) # (N, D) + c = c + y # (N, D) + + context = self.context_embedder(context) + + x = self.forward_core_with_concat(x, c, context, skip_layers, control) + + x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W) + return x \ No newline at end of file diff --git a/backend/nn/t5.py b/backend/nn/t5.py index d867f0de..74e0ab70 100644 --- a/backend/nn/t5.py +++ b/backend/nn/t5.py @@ -2,11 +2,13 @@ import torch import math from backend.attention import attention_pytorch as attention_function +from transformers.activations import NewGELUActivation activations = { "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"), "relu": torch.nn.functional.relu, + "gelu_new": lambda a: NewGELUActivation()(a), } diff --git a/backend/text_processing/classic_engine.py b/backend/text_processing/classic_engine.py index 3f82fcd6..57e385ed 100644 --- a/backend/text_processing/classic_engine.py +++ b/backend/text_processing/classic_engine.py @@ -141,7 +141,7 @@ class ClassicTextProcessingEngine: if self.return_pooled: pooled_output = outputs.pooler_output - if self.text_projection: + if self.text_projection and self.embedding_key != 'clip_l': pooled_output = self.text_encoder.transformer.text_projection(pooled_output) z.pooled = pooled_output