mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 19:09:45 +00:00
SD3+ (#2688)
Co-authored-by: graemeniedermayer graemeniedermayer@users.noreply.github.com
This commit is contained in:
149
backend/diffusion_engine/sd35.py
Normal file
149
backend/diffusion_engine/sd35.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import torch
|
||||
|
||||
from huggingface_guess import model_list
|
||||
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
|
||||
from backend.patcher.clip import CLIP
|
||||
from backend.patcher.vae import VAE
|
||||
from backend.patcher.unet import UnetPatcher
|
||||
from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
||||
from backend.text_processing.t5_engine import T5TextProcessingEngine
|
||||
from backend.args import dynamic_args
|
||||
from backend import memory_management
|
||||
from backend.modules.k_prediction import PredictionDiscreteFlow
|
||||
|
||||
from modules.shared import opts
|
||||
|
||||
|
||||
## patch SD3 Class in huggingface_guess.model_list
|
||||
def SD3_clip_target(self, state_dict={}):
|
||||
return {'clip_l': 'text_encoder', 'clip_g': 'text_encoder_2', 't5xxl': 'text_encoder_3'}
|
||||
|
||||
model_list.SD3.unet_target = 'transformer'
|
||||
model_list.SD3.clip_target = SD3_clip_target
|
||||
## end patch
|
||||
|
||||
class StableDiffusion3(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.SD3]
|
||||
|
||||
def __init__(self, estimated_config, huggingface_components):
|
||||
super().__init__(estimated_config, huggingface_components)
|
||||
self.is_inpaint = False
|
||||
|
||||
clip = CLIP(
|
||||
model_dict={
|
||||
'clip_l': huggingface_components['text_encoder'],
|
||||
'clip_g': huggingface_components['text_encoder_2'],
|
||||
't5xxl' : huggingface_components['text_encoder_3']
|
||||
},
|
||||
tokenizer_dict={
|
||||
'clip_l': huggingface_components['tokenizer'],
|
||||
'clip_g': huggingface_components['tokenizer_2'],
|
||||
't5xxl' : huggingface_components['tokenizer_3']
|
||||
}
|
||||
)
|
||||
|
||||
k_predictor = PredictionDiscreteFlow(shift=3.0)
|
||||
|
||||
vae = VAE(model=huggingface_components['vae'])
|
||||
|
||||
unet = UnetPatcher.from_model(
|
||||
model=huggingface_components['transformer'],
|
||||
diffusers_scheduler= None,
|
||||
k_predictor=k_predictor,
|
||||
config=estimated_config
|
||||
)
|
||||
|
||||
self.text_processing_engine_l = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_l,
|
||||
tokenizer=clip.tokenizer.clip_l,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_l',
|
||||
embedding_expected_shape=768,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=True,
|
||||
minimal_clip_skip=1,
|
||||
clip_skip=1,
|
||||
return_pooled=True,
|
||||
final_layer_norm=False,
|
||||
)
|
||||
|
||||
self.text_processing_engine_g = ClassicTextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.clip_g,
|
||||
tokenizer=clip.tokenizer.clip_g,
|
||||
embedding_dir=dynamic_args['embedding_dir'],
|
||||
embedding_key='clip_g',
|
||||
embedding_expected_shape=1280,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
text_projection=True,
|
||||
minimal_clip_skip=1,
|
||||
clip_skip=1,
|
||||
return_pooled=True,
|
||||
final_layer_norm=False,
|
||||
)
|
||||
|
||||
self.text_processing_engine_t5 = T5TextProcessingEngine(
|
||||
text_encoder=clip.cond_stage_model.t5xxl,
|
||||
tokenizer=clip.tokenizer.t5xxl,
|
||||
emphasis_name=dynamic_args['emphasis_name'],
|
||||
)
|
||||
|
||||
self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
|
||||
self.forge_objects_original = self.forge_objects.shallow_copy()
|
||||
self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()
|
||||
|
||||
# WebUI Legacy
|
||||
self.is_sd3 = True
|
||||
|
||||
def set_clip_skip(self, clip_skip):
|
||||
self.text_processing_engine_l.clip_skip = clip_skip
|
||||
self.text_processing_engine_g.clip_skip = clip_skip
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(self, prompt: list[str]):
|
||||
memory_management.load_model_gpu(self.forge_objects.clip.patcher)
|
||||
|
||||
cond_g, g_pooled = self.text_processing_engine_g(prompt)
|
||||
cond_l, l_pooled = self.text_processing_engine_l(prompt)
|
||||
if opts.sd3_enable_t5:
|
||||
cond_t5 = self.text_processing_engine_t5(prompt)
|
||||
else:
|
||||
cond_t5 = torch.zeros([len(prompt), 256, 4096])
|
||||
|
||||
is_negative_prompt = getattr(prompt, 'is_negative_prompt', False)
|
||||
|
||||
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in prompt)
|
||||
|
||||
if force_zero_negative_prompt:
|
||||
l_pooled = torch.zeros_like(l_pooled)
|
||||
g_pooled = torch.zeros_like(g_pooled)
|
||||
cond_l = torch.zeros_like(cond_l)
|
||||
cond_g = torch.zeros_like(cond_g)
|
||||
cond_t5 = torch.zeros_like(cond_t5)
|
||||
|
||||
cond_lg = torch.cat([cond_l, cond_g], dim=-1)
|
||||
cond_lg = torch.nn.functional.pad(cond_lg, (0, 4096 - cond_lg.shape[-1]))
|
||||
|
||||
cond = dict(
|
||||
crossattn=torch.cat([cond_lg, cond_t5], dim=-2),
|
||||
vector=torch.cat([l_pooled, g_pooled], dim=-1),
|
||||
)
|
||||
|
||||
return cond
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_prompt_lengths_on_ui(self, prompt):
|
||||
token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
|
||||
return token_count, max(255, token_count)
|
||||
|
||||
@torch.inference_mode()
|
||||
def encode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
|
||||
sample = self.forge_objects.vae.first_stage_model.process_in(sample)
|
||||
return sample.to(x)
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode_first_stage(self, x):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
|
||||
return sample.to(x)
|
||||
@@ -20,10 +20,11 @@ from backend.nn.unet import IntegratedUNet2DConditionModel
|
||||
from backend.diffusion_engine.sd15 import StableDiffusion
|
||||
from backend.diffusion_engine.sd20 import StableDiffusion2
|
||||
from backend.diffusion_engine.sdxl import StableDiffusionXL, StableDiffusionXLRefiner
|
||||
from backend.diffusion_engine.sd35 import StableDiffusion3
|
||||
from backend.diffusion_engine.flux import Flux
|
||||
|
||||
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, Flux]
|
||||
possible_models = [StableDiffusion, StableDiffusion2, StableDiffusionXLRefiner, StableDiffusionXL, StableDiffusion3, Flux]
|
||||
|
||||
|
||||
logging.getLogger("diffusers").setLevel(logging.ERROR)
|
||||
@@ -107,15 +108,18 @@ def load_huggingface_component(guess, component_name, lib_name, cls_name, repo_p
|
||||
load_state_dict(model, state_dict, log_name=cls_name, ignore_errors=['transformer.encoder.embed_tokens.weight', 'logit_scale'])
|
||||
|
||||
return model
|
||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel']:
|
||||
if cls_name in ['UNet2DConditionModel', 'FluxTransformer2DModel', 'SD3Transformer2DModel']:
|
||||
assert isinstance(state_dict, dict) and len(state_dict) > 16, 'You do not have model state dict!'
|
||||
|
||||
model_loader = None
|
||||
if cls_name == 'UNet2DConditionModel':
|
||||
model_loader = lambda c: IntegratedUNet2DConditionModel.from_config(c)
|
||||
if cls_name == 'FluxTransformer2DModel':
|
||||
elif cls_name == 'FluxTransformer2DModel':
|
||||
from backend.nn.flux import IntegratedFluxTransformer2DModel
|
||||
model_loader = lambda c: IntegratedFluxTransformer2DModel(**c)
|
||||
elif cls_name == 'SD3Transformer2DModel':
|
||||
from backend.nn.mmditx import MMDiTX
|
||||
model_loader = lambda c: MMDiTX(**c)
|
||||
|
||||
unet_config = guess.unet_config.copy()
|
||||
state_dict_parameters = memory_management.state_dict_parameters(state_dict)
|
||||
@@ -246,10 +250,10 @@ def replace_state_dict(sd, asd, guess):
|
||||
"-" : None,
|
||||
"sd1" : None,
|
||||
"sd2" : None,
|
||||
"xlrf": "conditioner.embedders.0.model.",
|
||||
"sdxl": "conditioner.embedders.1.model.",
|
||||
"xlrf": "conditioner.embedders.0.model.transformer.",
|
||||
"sdxl": "conditioner.embedders.1.model.transformer.",
|
||||
"flux": None,
|
||||
"sd3" : "text_encoders.clip_g.",
|
||||
"sd3" : "text_encoders.clip_g.transformer.",
|
||||
}
|
||||
## prefixes used by various model types for CLIP-H
|
||||
prefix_H = {
|
||||
@@ -292,10 +296,10 @@ def replace_state_dict(sd, asd, guess):
|
||||
|
||||
## CLIP-G
|
||||
CLIP_G = { # key to identify source model old_prefix
|
||||
'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias' : 'conditioner.embedders.1.model.',
|
||||
'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_g.',
|
||||
'conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias' : 'conditioner.embedders.1.model.transformer.',
|
||||
'text_encoders.clip_g.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_g.transformer.',
|
||||
'text_model.encoder.layers.0.layer_norm1.bias' : '',
|
||||
'transformer.resblocks.0.ln_1.bias' : ''
|
||||
'transformer.resblocks.0.ln_1.bias' : 'transformer.'
|
||||
}
|
||||
for CLIP_key in CLIP_G.keys():
|
||||
if CLIP_key in asd and asd[CLIP_key].shape[0] == 1280:
|
||||
@@ -303,7 +307,7 @@ def replace_state_dict(sd, asd, guess):
|
||||
old_prefix = CLIP_G[CLIP_key]
|
||||
|
||||
if new_prefix is not None:
|
||||
if "resblocks" not in CLIP_key: # need to convert
|
||||
if "resblocks" not in CLIP_key and model_type != "sd3": # need to convert
|
||||
def convert_transformers(statedict, prefix_from, prefix_to, number):
|
||||
keys_to_replace = {
|
||||
"{}text_model.embeddings.position_embedding.weight" : "{}positional_embedding",
|
||||
@@ -320,15 +324,15 @@ def replace_state_dict(sd, asd, guess):
|
||||
"self_attn.out_proj" : "attn.out_proj" ,
|
||||
}
|
||||
|
||||
for x in keys_to_replace:
|
||||
for x in keys_to_replace: # remove trailing 'transformer.' from new prefix
|
||||
k = x.format(prefix_from)
|
||||
statedict[keys_to_replace[x].format(prefix_to)] = statedict.pop(k)
|
||||
statedict[keys_to_replace[x].format(prefix_to[:-12])] = statedict.pop(k)
|
||||
|
||||
for resblock in range(number):
|
||||
for y in ["weight", "bias"]:
|
||||
for x in resblock_to_replace:
|
||||
k = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, x, y)
|
||||
k_to = "{}transformer.resblocks.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
|
||||
k_to = "{}resblocks.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
|
||||
statedict[k_to] = statedict.pop(k)
|
||||
|
||||
k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.q_proj", y)
|
||||
@@ -338,14 +342,16 @@ def replace_state_dict(sd, asd, guess):
|
||||
k_from = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_from, resblock, "self_attn.v_proj", y)
|
||||
weightsV = statedict.pop(k_from)
|
||||
|
||||
k_to = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_to, resblock, y)
|
||||
k_to = "{}resblocks.{}.attn.in_proj_{}".format(prefix_to, resblock, y)
|
||||
|
||||
statedict[k_to] = torch.cat((weightsQ, weightsK, weightsV))
|
||||
return statedict
|
||||
|
||||
asd = convert_transformers(asd, old_prefix, new_prefix, 32)
|
||||
new_prefix = ""
|
||||
for k, v in asd.items():
|
||||
sd[k] = v
|
||||
|
||||
if old_prefix == "":
|
||||
elif old_prefix == "":
|
||||
for k, v in asd.items():
|
||||
new_k = new_prefix + k
|
||||
sd[new_k] = v
|
||||
@@ -360,7 +366,7 @@ def replace_state_dict(sd, asd, guess):
|
||||
'conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'conditioner.embedders.0.transformer.',
|
||||
'text_encoders.clip_l.transformer.text_model.encoder.layers.0.layer_norm1.bias' : 'text_encoders.clip_l.transformer.',
|
||||
'text_model.encoder.layers.0.layer_norm1.bias' : '',
|
||||
'transformer.resblocks.0.ln_1.bias' : ''
|
||||
'transformer.resblocks.0.ln_1.bias' : 'transformer.'
|
||||
}
|
||||
|
||||
for CLIP_key in CLIP_L.keys():
|
||||
@@ -376,6 +382,7 @@ def replace_state_dict(sd, asd, guess):
|
||||
"token_embedding.weight": "{}text_model.embeddings.token_embedding.weight",
|
||||
"ln_final.weight" : "{}text_model.final_layer_norm.weight",
|
||||
"ln_final.bias" : "{}text_model.final_layer_norm.bias",
|
||||
"text_projection" : "text_projection.weight",
|
||||
}
|
||||
resblock_to_replace = {
|
||||
"ln_1" : "layer_norm1",
|
||||
@@ -391,11 +398,11 @@ def replace_state_dict(sd, asd, guess):
|
||||
for resblock in range(number):
|
||||
for y in ["weight", "bias"]:
|
||||
for x in resblock_to_replace:
|
||||
k = "{}transformer.resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
|
||||
k = "{}resblocks.{}.{}.{}".format(prefix_from, resblock, x, y)
|
||||
k_to = "{}text_model.encoder.layers.{}.{}.{}".format(prefix_to, resblock, resblock_to_replace[x], y)
|
||||
statedict[k_to] = statedict.pop(k)
|
||||
|
||||
k_from = "{}transformer.resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
|
||||
k_from = "{}resblocks.{}.attn.in_proj_{}".format(prefix_from, resblock, y)
|
||||
weights = statedict.pop(k_from)
|
||||
shape_from = weights.shape[0] // 3
|
||||
for x in range(3):
|
||||
@@ -405,9 +412,10 @@ def replace_state_dict(sd, asd, guess):
|
||||
return statedict
|
||||
|
||||
asd = transformers_convert(asd, old_prefix, new_prefix, 12)
|
||||
new_prefix = ""
|
||||
for k, v in asd.items():
|
||||
sd[k] = v
|
||||
|
||||
if old_prefix == "":
|
||||
elif old_prefix == "":
|
||||
for k, v in asd.items():
|
||||
new_k = new_prefix + k
|
||||
sd[new_k] = v
|
||||
|
||||
@@ -250,6 +250,38 @@ class PredictionFlow(AbstractPrediction):
|
||||
return 1.0 - percent
|
||||
|
||||
|
||||
class PredictionDiscreteFlow(AbstractPrediction):
|
||||
def __init__(self, sigma_data=1.0, prediction_type='const', shift=1.0, timesteps = 1000):
|
||||
super().__init__(sigma_data=sigma_data, prediction_type=prediction_type)
|
||||
self.shift = shift
|
||||
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
|
||||
self.register_buffer("sigmas", ts)
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
return sigma * 1000
|
||||
|
||||
def sigma(self, timestep: torch.Tensor):
|
||||
timestep = timestep / 1000.0
|
||||
if self.shift == 1.0:
|
||||
return timestep
|
||||
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
return 1.0
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
return 1.0 - percent
|
||||
|
||||
|
||||
class PredictionFlux(AbstractPrediction):
|
||||
def __init__(self, seq_len=4096, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15, pseudo_timestep_range=10000, mu=None):
|
||||
super().__init__(sigma_data=1.0, prediction_type='const')
|
||||
|
||||
971
backend/nn/mmditx.py
Normal file
971
backend/nn/mmditx.py
Normal file
@@ -0,0 +1,971 @@
|
||||
### This file contains impls for MM-DiT, the core model component of SD3
|
||||
|
||||
## source https://github.com/Stability-AI/sd3.5
|
||||
## attention, Mlp : other_impls.py
|
||||
## all else : mmditx.py
|
||||
|
||||
## minor modifications to MMDiTX.__init__() and MMDiTX.forward()
|
||||
|
||||
import math
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
def attention(q, k, v, heads, mask=None):
|
||||
"""Convenience wrapper around a basic attention operation"""
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
bias=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
|
||||
self.fc1 = nn.Linear(
|
||||
in_features, hidden_features, bias=bias, dtype=dtype, device=device
|
||||
)
|
||||
self.act = act_layer
|
||||
self.fc2 = nn.Linear(
|
||||
hidden_features, out_features, bias=bias, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: Optional[int] = 224,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
flatten: bool = True,
|
||||
bias: bool = True,
|
||||
strict_img_size: bool = True,
|
||||
dynamic_img_pad: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = (patch_size, patch_size)
|
||||
if img_size is not None:
|
||||
self.img_size = (img_size, img_size)
|
||||
self.grid_size = tuple(
|
||||
[s // p for s, p in zip(self.img_size, self.patch_size)]
|
||||
)
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
else:
|
||||
self.img_size = None
|
||||
self.grid_size = None
|
||||
self.num_patches = None
|
||||
|
||||
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
||||
self.flatten = flatten
|
||||
self.strict_img_size = strict_img_size
|
||||
self.dynamic_img_pad = dynamic_img_pad
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans,
|
||||
embed_dim,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||
return x
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
if shift is None:
|
||||
shift = torch.zeros_like(scale)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Sine/Cosine Positional Embedding Functions #
|
||||
#################################################################################
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(
|
||||
embed_dim,
|
||||
grid_size,
|
||||
cls_token=False,
|
||||
extra_tokens=0,
|
||||
scaling_factor=None,
|
||||
offset=None,
|
||||
):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
if scaling_factor is not None:
|
||||
grid = grid / scaling_factor
|
||||
if offset is not None:
|
||||
grid = grid - offset
|
||||
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate(
|
||||
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
|
||||
)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Embedding Layers for Timesteps and Class Labels #
|
||||
#################################################################################
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""Embeds scalar timesteps into vector representations."""
|
||||
|
||||
def __init__(
|
||||
self, hidden_size, frequency_embedding_size=256, dtype=None, device=None
|
||||
):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(
|
||||
frequency_embedding_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half
|
||||
).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
if torch.is_floating_point(t):
|
||||
embedding = embedding.to(dtype=t.dtype)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype, **kwargs):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class VectorEmbedder(nn.Module):
|
||||
"""Embeds a flat vector of dimension input_dim"""
|
||||
|
||||
def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.mlp(x)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Core DiT Model #
|
||||
#################################################################################
|
||||
|
||||
|
||||
def split_qkv(qkv, head_dim):
|
||||
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
|
||||
return qkv[0], qkv[1], qkv[2]
|
||||
|
||||
|
||||
def optimized_attention(qkv, num_heads):
|
||||
return attention(qkv[0], qkv[1], qkv[2], num_heads)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: Optional[float] = None,
|
||||
pre_only: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
rmsnorm: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if not pre_only:
|
||||
self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
|
||||
self.pre_only = pre_only
|
||||
|
||||
if qk_norm == "rms":
|
||||
self.ln_q = RMSNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.ln_k = RMSNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
elif qk_norm == "ln":
|
||||
self.ln_q = nn.LayerNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.ln_k = nn.LayerNorm(
|
||||
self.head_dim,
|
||||
elementwise_affine=True,
|
||||
eps=1.0e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
elif qk_norm is None:
|
||||
self.ln_q = nn.Identity()
|
||||
self.ln_k = nn.Identity()
|
||||
else:
|
||||
raise ValueError(qk_norm)
|
||||
|
||||
def pre_attention(self, x: torch.Tensor):
|
||||
B, L, C = x.shape
|
||||
qkv = self.qkv(x)
|
||||
q, k, v = split_qkv(qkv, self.head_dim)
|
||||
q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
|
||||
k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
|
||||
return (q, k, v)
|
||||
|
||||
def post_attention(self, x: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
(q, k, v) = self.pre_attention(x)
|
||||
x = attention(q, k, v, self.num_heads)
|
||||
x = self.post_attention(x)
|
||||
return x
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
elementwise_affine: bool = False,
|
||||
eps: float = 1e-6,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
Args:
|
||||
dim (int): The dimension of the input tensor.
|
||||
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||
Attributes:
|
||||
eps (float): A small value added to the denominator for numerical stability.
|
||||
weight (nn.Parameter): Learnable scaling parameter.
|
||||
"""
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.learnable_scale = elementwise_affine
|
||||
if self.learnable_scale:
|
||||
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
"""
|
||||
x = self._norm(x)
|
||||
if self.learnable_scale:
|
||||
return x * self.weight.to(device=x.device, dtype=x.dtype)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class SwiGLUFeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the FeedForward module.
|
||||
|
||||
Args:
|
||||
dim (int): Input dimension.
|
||||
hidden_dim (int): Hidden dimension of the feedforward layer.
|
||||
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
|
||||
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
|
||||
|
||||
Attributes:
|
||||
w1 (ColumnParallelLinear): Linear transformation for the first layer.
|
||||
w2 (RowParallelLinear): Linear transformation for the second layer.
|
||||
w3 (ColumnParallelLinear): Linear transformation for the third layer.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
# custom dim factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
|
||||
|
||||
|
||||
class DismantledBlock(nn.Module):
|
||||
"""A DiT block with gated adaptive layer norm (adaLN) conditioning."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = False,
|
||||
pre_only: bool = False,
|
||||
rmsnorm: bool = False,
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
qk_norm: Optional[str] = None,
|
||||
x_block_self_attn: bool = False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
**block_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if not rmsnorm:
|
||||
self.norm1 = nn.LayerNorm(
|
||||
hidden_size,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
pre_only=pre_only,
|
||||
qk_norm=qk_norm,
|
||||
rmsnorm=rmsnorm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
self.x_block_self_attn = True
|
||||
self.attn2 = SelfAttention(
|
||||
dim=hidden_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
rmsnorm=rmsnorm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.x_block_self_attn = False
|
||||
if not pre_only:
|
||||
if not rmsnorm:
|
||||
self.norm2 = nn.LayerNorm(
|
||||
hidden_size,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
if not pre_only:
|
||||
if not swiglu:
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=nn.GELU(approximate="tanh"),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
self.mlp = SwiGLUFeedForward(
|
||||
dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256
|
||||
)
|
||||
self.scale_mod_only = scale_mod_only
|
||||
if x_block_self_attn:
|
||||
assert not pre_only
|
||||
assert not scale_mod_only
|
||||
n_mods = 9
|
||||
elif not scale_mod_only:
|
||||
n_mods = 6 if not pre_only else 2
|
||||
else:
|
||||
n_mods = 4 if not pre_only else 1
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(
|
||||
hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device
|
||||
),
|
||||
)
|
||||
self.pre_only = pre_only
|
||||
|
||||
def pre_attention(self, x: torch.Tensor, c: torch.Tensor):
|
||||
assert x is not None, "pre_attention called with None input"
|
||||
if not self.pre_only:
|
||||
if not self.scale_mod_only:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.adaLN_modulation(c).chunk(6, dim=1)
|
||||
)
|
||||
else:
|
||||
shift_msa = None
|
||||
shift_mlp = None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(
|
||||
c
|
||||
).chunk(4, dim=1)
|
||||
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
|
||||
else:
|
||||
if not self.scale_mod_only:
|
||||
shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
else:
|
||||
shift_msa = None
|
||||
scale_msa = self.adaLN_modulation(c)
|
||||
qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
return qkv, None
|
||||
|
||||
def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
|
||||
assert not self.pre_only
|
||||
x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
return x
|
||||
|
||||
def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert self.x_block_self_attn
|
||||
(
|
||||
shift_msa,
|
||||
scale_msa,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
shift_msa2,
|
||||
scale_msa2,
|
||||
gate_msa2,
|
||||
) = self.adaLN_modulation(c).chunk(9, dim=1)
|
||||
x_norm = self.norm1(x)
|
||||
qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
|
||||
qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
|
||||
return (
|
||||
qkv,
|
||||
qkv2,
|
||||
(
|
||||
x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
gate_msa2,
|
||||
),
|
||||
)
|
||||
|
||||
def post_attention_x(
|
||||
self,
|
||||
attn,
|
||||
attn2,
|
||||
x,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
gate_msa2,
|
||||
attn1_dropout: float = 0.0,
|
||||
):
|
||||
assert not self.pre_only
|
||||
if attn1_dropout > 0.0:
|
||||
# Use torch.bernoulli to implement dropout, only dropout the batch dimension
|
||||
attn1_dropout = torch.bernoulli(
|
||||
torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device)
|
||||
)
|
||||
attn_ = (
|
||||
gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout
|
||||
)
|
||||
else:
|
||||
attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
|
||||
x = x + attn_
|
||||
attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
|
||||
x = x + attn2_
|
||||
mlp_ = gate_mlp.unsqueeze(1) * self.mlp(
|
||||
modulate(self.norm2(x), shift_mlp, scale_mlp)
|
||||
)
|
||||
x = x + mlp_
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
assert not self.pre_only
|
||||
if self.x_block_self_attn:
|
||||
(q, k, v), (q2, k2, v2), intermediates = self.pre_attention_x(x, c)
|
||||
attn = attention(q, k, v, self.attn.num_heads)
|
||||
attn2 = attention(q2, k2, v2, self.attn2.num_heads)
|
||||
return self.post_attention_x(attn, attn2, *intermediates)
|
||||
else:
|
||||
(q, k, v), intermediates = self.pre_attention(x, c)
|
||||
attn = attention(q, k, v, self.attn.num_heads)
|
||||
return self.post_attention(attn, *intermediates)
|
||||
|
||||
|
||||
def block_mixing(context, x, context_block, x_block, c):
|
||||
assert context is not None, "block_mixing called with None context"
|
||||
context_qkv, context_intermediates = context_block.pre_attention(context, c)
|
||||
|
||||
if x_block.x_block_self_attn:
|
||||
x_qkv, x_qkv2, x_intermediates = x_block.pre_attention_x(x, c)
|
||||
else:
|
||||
x_qkv, x_intermediates = x_block.pre_attention(x, c)
|
||||
|
||||
q, k, v = tuple(
|
||||
torch.cat(tuple(qkv[i] for qkv in [context_qkv, x_qkv]), dim=1)
|
||||
for i in range(3)
|
||||
)
|
||||
attn = attention(q, k, v, x_block.attn.num_heads)
|
||||
context_attn, x_attn = (
|
||||
attn[:, : context_qkv[0].shape[1]],
|
||||
attn[:, context_qkv[0].shape[1] :],
|
||||
)
|
||||
|
||||
if not context_block.pre_only:
|
||||
context = context_block.post_attention(context_attn, *context_intermediates)
|
||||
else:
|
||||
context = None
|
||||
|
||||
if x_block.x_block_self_attn:
|
||||
x_q2, x_k2, x_v2 = x_qkv2
|
||||
attn2 = attention(x_q2, x_k2, x_v2, x_block.attn2.num_heads)
|
||||
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
|
||||
else:
|
||||
x = x_block.post_attention(x_attn, *x_intermediates)
|
||||
|
||||
return context, x
|
||||
|
||||
|
||||
class JointBlock(nn.Module):
|
||||
"""just a small wrapper to serve as a fsdp unit"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
pre_only = kwargs.pop("pre_only")
|
||||
qk_norm = kwargs.pop("qk_norm", None)
|
||||
x_block_self_attn = kwargs.pop("x_block_self_attn", False)
|
||||
self.context_block = DismantledBlock(
|
||||
*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs
|
||||
)
|
||||
self.x_block = DismantledBlock(
|
||||
*args,
|
||||
pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=x_block_self_attn,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return block_mixing(
|
||||
*args, context_block=self.context_block, x_block=self.x_block, **kwargs
|
||||
)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
patch_size: int,
|
||||
out_channels: int,
|
||||
total_out_channels: Optional[int] = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(
|
||||
hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
|
||||
)
|
||||
self.linear = (
|
||||
nn.Linear(
|
||||
hidden_size,
|
||||
patch_size * patch_size * out_channels,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
if (total_out_channels is None)
|
||||
else nn.Linear(
|
||||
hidden_size, total_out_channels, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(
|
||||
hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class MMDiTX(nn.Module):
|
||||
"""Diffusion model with a Transformer backbone."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int = 32,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 4,
|
||||
depth: int = 28,
|
||||
mlp_ratio: float = 4.0,
|
||||
learn_sigma: bool = False,
|
||||
adm_in_channels: Optional[int] = None,
|
||||
context_embedder_config: Optional[Dict] = None,
|
||||
register_length: int = 0,
|
||||
rmsnorm: bool = False,
|
||||
scale_mod_only: bool = False,
|
||||
swiglu: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
pos_embed_scaling_factor: Optional[float] = None,
|
||||
pos_embed_offset: Optional[float] = None,
|
||||
pos_embed_max_size: Optional[int] = None,
|
||||
num_patches=None,
|
||||
qk_norm: Optional[str] = None,
|
||||
x_block_self_attn_layers: Optional[List[int]] = [],
|
||||
qkv_bias: bool = True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
verbose=False,
|
||||
):
|
||||
super().__init__()
|
||||
if verbose:
|
||||
print(
|
||||
f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}"
|
||||
)
|
||||
self.dtype = dtype
|
||||
self.learn_sigma = learn_sigma
|
||||
in_channels = int(in_channels)
|
||||
self.in_channels = in_channels
|
||||
# default_out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
# self.out_channels = (
|
||||
# out_channels if out_channels is not None else default_out_channels
|
||||
# )
|
||||
self.out_channels = 16 # hard coded - detected value can be vastly wrong if nf4
|
||||
# but always 16 for sd3 and sd3.5 (learn_sigma always False)
|
||||
patch_size = int(patch_size)
|
||||
self.patch_size = patch_size
|
||||
self.pos_embed_scaling_factor = pos_embed_scaling_factor
|
||||
self.pos_embed_offset = pos_embed_offset
|
||||
self.pos_embed_max_size = int(pos_embed_max_size)
|
||||
self.x_block_self_attn_layers = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] if self.pos_embed_max_size == 384 else x_block_self_attn_layers
|
||||
|
||||
# apply magic --> this defines a head_size of 64
|
||||
depth = int(depth)
|
||||
hidden_size = int(64 * depth)
|
||||
num_heads = depth
|
||||
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.x_embedder = PatchEmbed(
|
||||
input_size,
|
||||
patch_size,
|
||||
in_channels,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
strict_img_size=self.pos_embed_max_size is None,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)
|
||||
|
||||
adm_in_channels = int(adm_in_channels) # 2048
|
||||
|
||||
if adm_in_channels is not None:
|
||||
assert isinstance(adm_in_channels, int)
|
||||
self.y_embedder = VectorEmbedder(
|
||||
adm_in_channels, hidden_size, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
self.context_embedder = nn.Identity()
|
||||
if context_embedder_config is not None:
|
||||
if context_embedder_config["target"] == "torch.nn.Linear":
|
||||
self.context_embedder = nn.Linear(
|
||||
**context_embedder_config["params"], dtype=dtype, device=device
|
||||
)
|
||||
|
||||
self.register_length = register_length
|
||||
if self.register_length > 0:
|
||||
self.register = nn.Parameter(
|
||||
torch.randn(1, register_length, hidden_size, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
# num_patches = self.x_embedder.num_patches
|
||||
# Will use fixed sin-cos embedding:
|
||||
# just use a buffer already
|
||||
if num_patches is not None:
|
||||
num_patches = int(num_patches)
|
||||
self.register_buffer(
|
||||
"pos_embed",
|
||||
torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
|
||||
self.joint_blocks = nn.ModuleList(
|
||||
[
|
||||
JointBlock(
|
||||
hidden_size,
|
||||
num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
pre_only=i == depth - 1,
|
||||
rmsnorm=rmsnorm,
|
||||
scale_mod_only=scale_mod_only,
|
||||
swiglu=swiglu,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn=(i in self.x_block_self_attn_layers),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer(
|
||||
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
def cropped_pos_embed(self, hw):
|
||||
assert self.pos_embed_max_size is not None
|
||||
p = self.x_embedder.patch_size[0]
|
||||
h, w = hw
|
||||
# patched size
|
||||
h = h // p
|
||||
w = w // p
|
||||
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
|
||||
assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
|
||||
top = (self.pos_embed_max_size - h) // 2
|
||||
left = (self.pos_embed_max_size - w) // 2
|
||||
spatial_pos_embed = rearrange(
|
||||
self.pos_embed,
|
||||
"1 (h w) c -> 1 h w c",
|
||||
h=self.pos_embed_max_size,
|
||||
w=self.pos_embed_max_size,
|
||||
)
|
||||
spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
|
||||
spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c")
|
||||
return spatial_pos_embed
|
||||
|
||||
def unpatchify(self, x, hw=None):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, C, H, W)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
if hw is None:
|
||||
h = w = int(x.shape[1] ** 0.5)
|
||||
else:
|
||||
h, w = hw
|
||||
h = h // p
|
||||
w = w // p
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum("nhwpqc->nchpwq", x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
|
||||
|
||||
def forward_core_with_concat(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c_mod: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
skip_layers: Optional[List] = [],
|
||||
controlnet_hidden_states: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.register_length > 0:
|
||||
context = torch.cat(
|
||||
(
|
||||
repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
|
||||
context if context is not None else torch.Tensor([]).type_as(x),
|
||||
),
|
||||
1,
|
||||
)
|
||||
|
||||
# context is B, L', D
|
||||
# x is B, L, D
|
||||
for i, block in enumerate(self.joint_blocks):
|
||||
if i in skip_layers:
|
||||
continue
|
||||
context, x = block(context, x, c=c_mod)
|
||||
if controlnet_hidden_states is not None:
|
||||
controlnet_block_interval = len(self.joint_blocks) // len(
|
||||
controlnet_hidden_states
|
||||
)
|
||||
x = x + controlnet_hidden_states[i // controlnet_block_interval]
|
||||
|
||||
x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
control=None, transformer_options={}, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N,) tensor of class labels
|
||||
"""
|
||||
|
||||
skip_layers = transformer_options.get("skip_layers", [])
|
||||
|
||||
hw = x.shape[-2:]
|
||||
|
||||
# x = x[:,:16,:,:]
|
||||
|
||||
x = self.x_embedder(x) + self.cropped_pos_embed(hw).to(x.device, x.dtype)
|
||||
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||
if y is not None:
|
||||
y = self.y_embedder(y) # (N, D)
|
||||
c = c + y # (N, D)
|
||||
|
||||
context = self.context_embedder(context)
|
||||
|
||||
x = self.forward_core_with_concat(x, c, context, skip_layers, control)
|
||||
|
||||
x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W)
|
||||
return x
|
||||
@@ -2,11 +2,13 @@ import torch
|
||||
import math
|
||||
|
||||
from backend.attention import attention_pytorch as attention_function
|
||||
from transformers.activations import NewGELUActivation
|
||||
|
||||
|
||||
activations = {
|
||||
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
|
||||
"relu": torch.nn.functional.relu,
|
||||
"gelu_new": lambda a: NewGELUActivation()(a),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -141,7 +141,7 @@ class ClassicTextProcessingEngine:
|
||||
if self.return_pooled:
|
||||
pooled_output = outputs.pooler_output
|
||||
|
||||
if self.text_projection:
|
||||
if self.text_projection and self.embedding_key != 'clip_l':
|
||||
pooled_output = self.text_encoder.transformer.text_projection(pooled_output)
|
||||
|
||||
z.pooled = pooled_output
|
||||
|
||||
Reference in New Issue
Block a user