mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)

Commit: merged in lumina2 branch
@@ -424,6 +424,7 @@ class ModelConfig:
        self.is_auraflow: bool = kwargs.get('is_auraflow', False)
        self.is_v3: bool = kwargs.get('is_v3', False)
        self.is_flux: bool = kwargs.get('is_flux', False)
        self.is_lumina2: bool = kwargs.get('is_lumina2', False)
        if self.is_pixart_sigma:
            self.is_pixart = True
        self.use_flux_cfg = kwargs.get('use_flux_cfg', False)
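The new is_lumina2 flag is parsed from kwargs like the other architecture switches, so routing a run through the Lumina2 code paths only requires setting it on the model config. A minimal sketch (field names other than is_lumina2 are illustrative and not checked against the full ModelConfig schema):

model_config = ModelConfig(
    name_or_path="Alpha-VLLM/Lumina-Image-2.0",  # hypothetical example path
    is_lumina2=True,   # enables the Lumina2 branches added throughout this commit
    quantize=True,     # optional: qfloat8-quantize the transformer at load time
)
assert model_config.is_lumina2 and not model_config.is_flux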
@@ -63,7 +63,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
        torch.nn.Module.__init__(self)
        self.lora_name = lora_name
        self.orig_module_ref = weakref.ref(org_module)
        self.scalar = torch.tensor(1.0)
        self.scalar = torch.tensor(1.0, device=org_module.weight.device)
        # check if parent has bias. if not force use_bias to False
        if org_module.bias is None:
            use_bias = False
@@ -163,6 +163,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            is_pixart: bool = False,
            is_auraflow: bool = False,
            is_flux: bool = False,
            is_lumina2: bool = False,
            use_bias: bool = False,
            is_lorm: bool = False,
            ignore_if_contains = None,
@@ -223,6 +224,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        self.is_pixart = is_pixart
        self.is_auraflow = is_auraflow
        self.is_flux = is_flux
        self.is_lumina2 = is_lumina2
        self.network_type = network_type
        self.is_assistant_adapter = is_assistant_adapter
        if self.network_type.lower() == "dora":
@@ -232,7 +234,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        self.peft_format = peft_format

        # always do peft for flux only for now
        if self.is_flux or self.is_v3:
        if self.is_flux or self.is_v3 or self.is_lumina2:
            self.peft_format = True

        if self.peft_format:
@@ -273,7 +275,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        unet_prefix = self.LORA_PREFIX_UNET
        if self.peft_format:
            unet_prefix = self.PEFT_PREFIX_UNET
        if is_pixart or is_v3 or is_auraflow or is_flux:
        if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2:
            unet_prefix = f"lora_transformer"
            if self.peft_format:
                unet_prefix = "transformer"
@@ -326,6 +328,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            if self.transformer_only and self.is_flux and is_unet:
                if "transformer_blocks" not in lora_name:
                    skip = True
            if self.transformer_only and self.is_lumina2 and is_unet:
                if "layers$$" not in lora_name and "noise_refiner$$" not in lora_name and "context_refiner$$" not in lora_name:
                    skip = True
            if self.transformer_only and self.is_v3 and is_unet:
                if "transformer_blocks" not in lora_name:
                    skip = True
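For transformer_only training, Lumina2 LoRA modules are kept only when their names fall under the core blocks (layers, noise_refiner, context_refiner); embedders and the output head are skipped. A small sketch of that predicate, assuming ai-toolkit's convention of joining module paths with "$$" (the sample names below are hypothetical):

KEEP_KEYS = ("layers$$", "noise_refiner$$", "context_refiner$$")

def keep_lumina2_module(lora_name: str) -> bool:
    # mirrors the filter above: skip anything outside the transformer blocks
    return any(key in lora_name for key in KEEP_KEYS)

print(keep_lumina2_module("transformer$$layers$$0$$attn$$to_q"))                    # True  -> gets a LoRA
print(keep_lumina2_module("transformer$$time_caption_embed$$caption_embedder$$1"))  # False -> skipped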
@@ -431,6 +436,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):

        if is_flux:
            target_modules = ["FluxTransformer2DModel"]

        if is_lumina2:
            target_modules = ["Lumina2Transformer2DModel"]

        if train_unet:
            self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
toolkit/models/lumina2.py (new file, 560 lines)
@@ -0,0 +1,560 @@
# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils import logging
from diffusers.models.attention import LuminaFeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
import torch
from torch.profiler import profile, record_function, ProfilerActivity

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

do_profile = False


class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    def __init__(
        self,
        hidden_size: int = 4096,
        cap_feat_dim: int = 2048,
        frequency_embedding_size: int = 256,
        norm_eps: float = 1e-5,
    ) -> None:
        super().__init__()

        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
        )

        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
        )

        self.caption_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True)
        )

    def forward(
        self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        timestep_proj = self.time_proj(timestep).type_as(hidden_states)
        time_embed = self.timestep_embedder(timestep_proj)
        caption_embed = self.caption_embedder(encoder_hidden_states)
        return time_embed, caption_embed


class Lumina2AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads
        kv_heads = inner_dim // head_dim

        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key Norm if needed
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        query, key = query.to(dtype), key.to(dtype)

        # Apply proportional attention if true
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # perform Grouped-query Attention (GQA)
        n_rep = attn.heads // kv_heads
        if n_rep >= 1:
            key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
            value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
        attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=softmax_scale
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.type_as(query)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class Lumina2TransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        modulation: bool = True,
    ) -> None:
        super().__init__()
        self.head_dim = dim // num_attention_heads
        self.modulation = modulation

        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="rms_norm",
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=Lumina2AttnProcessor2_0(),
        )

        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )

        if modulation:
            self.norm1 = LuminaRMSNormZero(
                embedding_dim=dim,
                norm_eps=norm_eps,
                norm_elementwise_affine=True,
            )
        else:
            self.norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)

        self.norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.modulation:
            norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
            attn_output = self.attn(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                image_rotary_emb=image_rotary_emb,
            )
            hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
            hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
        else:
            norm_hidden_states = self.norm1(hidden_states)
            attn_output = self.attn(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                image_rotary_emb=image_rotary_emb,
            )
            hidden_states = hidden_states + self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
            hidden_states = hidden_states + self.ffn_norm2(mlp_output)

        return hidden_states


class Lumina2RotaryPosEmbed(nn.Module):
    def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size

        self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)

    def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
        freqs_cis = []
        for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=torch.float64)
            freqs_cis.append(emb)
        return freqs_cis

    def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
        result = []
        for i in range(len(self.axes_dim)):
            freqs = self.freqs_cis[i].to(ids.device)
            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
            result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
        return torch.cat(result, dim=-1)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
        batch_size = len(hidden_states)
        p_h = p_w = self.patch_size
        device = hidden_states[0].device

        l_effective_cap_len = attention_mask.sum(dim=1).tolist()
        # TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape
        img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
        l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes]

        max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)))
        max_img_len = max(l_effective_img_len)

        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

        for i in range(batch_size):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p_h, W // p_w
            assert H_tokens * W_tokens == img_len

            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
            position_ids[i, cap_len : cap_len + img_len, 0] = cap_len
            row_ids = (
                torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
            )
            col_ids = (
                torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
            )
            position_ids[i, cap_len : cap_len + img_len, 1] = row_ids
            position_ids[i, cap_len : cap_len + img_len, 2] = col_ids

        freqs_cis = self._get_freqs_cis(position_ids)

        cap_freqs_cis_shape = list(freqs_cis.shape)
        cap_freqs_cis_shape[1] = attention_mask.shape[1]
        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        img_freqs_cis_shape = list(freqs_cis.shape)
        img_freqs_cis_shape[1] = max_img_len
        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)

        for i in range(batch_size):
            cap_len = l_effective_cap_len[i]
            img_len = l_effective_img_len[i]
            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]

        flat_hidden_states = []
        for i in range(batch_size):
            img = hidden_states[i]
            C, H, W = img.size()
            img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
            flat_hidden_states.append(img)
        hidden_states = flat_hidden_states
        padded_img_embed = torch.zeros(
            batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype
        )
        padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
        for i in range(batch_size):
            padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i]
            padded_img_mask[i, : l_effective_img_len[i]] = True

        return (
            padded_img_embed,
            padded_img_mask,
            img_sizes,
            l_effective_cap_len,
            l_effective_img_len,
            freqs_cis,
            cap_freqs_cis,
            img_freqs_cis,
            max_seq_len,
        )


class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    r"""
    Lumina2NextDiT: Diffusion model with a Transformer backbone.

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`, *optional*, defaults to 2):
            The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
        in_channels (`int`, *optional*, defaults to 4):
            The number of input channels for the model. Typically, this matches the number of channels in the input
            images.
        hidden_size (`int`, *optional*, defaults to 4096):
            The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
            hidden representations.
        num_layers (`int`, *optional*, defaults to 32):
            The number of layers in the model. This defines the depth of the neural network.
        num_attention_heads (`int`, *optional*, defaults to 32):
            The number of attention heads in each attention layer. This parameter specifies how many separate attention
            mechanisms are used.
        num_kv_heads (`int`, *optional*, defaults to 8):
            The number of key-value heads in the attention mechanism, if different from the number of attention heads.
            If None, it defaults to num_attention_heads.
        multiple_of (`int`, *optional*, defaults to 256):
            A factor that the hidden size should be a multiple of. This can help optimize certain hardware
            configurations.
        ffn_dim_multiplier (`float`, *optional*):
            A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
            the model configuration.
        norm_eps (`float`, *optional*, defaults to 1e-5):
            A small value added to the denominator for numerical stability in normalization layers.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
            overall scale of the model's operations.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["Lumina2TransformerBlock"]
    _skip_layerwise_casting_patterns = ["x_embedder", "norm"]

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 2304,
        num_layers: int = 26,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 24,
        num_kv_heads: int = 8,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        scaling_factor: float = 1.0,
        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        cap_feat_dim: int = 1024,
    ) -> None:
        super().__init__()
        self.out_channels = out_channels or in_channels

        # 1. Positional, patch & conditional embeddings
        self.rope_embedder = Lumina2RotaryPosEmbed(
            theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
        )

        self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size)

        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
        )

        # 2. Noise and context refinement blocks
        self.noise_refiner = nn.ModuleList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True,
                )
                for _ in range(num_refiner_layers)
            ]
        )

        self.context_refiner = nn.ModuleList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=False,
                )
                for _ in range(num_refiner_layers)
            ]
        )

        # 3. Transformer blocks
        self.layers = nn.ModuleList(
            [
                Lumina2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            bias=True,
            out_dim=patch_size * patch_size * self.out_channels,
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:

        batch_size = hidden_states.size(0)

        if do_profile:
            prof = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
            )

            prof.start()

        # 1. Condition, positional & patch embedding
        temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)

        (
            hidden_states,
            hidden_mask,
            hidden_sizes,
            encoder_hidden_len,
            hidden_len,
            joint_rotary_emb,
            encoder_rotary_emb,
            hidden_rotary_emb,
            max_seq_len,
        ) = self.rope_embedder(hidden_states, attention_mask)

        hidden_states = self.x_embedder(hidden_states)

        # 2. Context & noise refinement
        for layer in self.context_refiner:
            encoder_hidden_states = layer(encoder_hidden_states, attention_mask, encoder_rotary_emb)

        for layer in self.noise_refiner:
            hidden_states = layer(hidden_states, hidden_mask, hidden_rotary_emb, temb)

        # 3. Attention mask preparation
        mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
        padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
        for i in range(batch_size):
            cap_len = encoder_hidden_len[i]
            img_len = hidden_len[i]
            mask[i, : cap_len + img_len] = True
            padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len]
            padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len]
        hidden_states = padded_hidden_states

        # 4. Transformer blocks
        for layer in self.layers:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(layer, hidden_states, mask, joint_rotary_emb, temb)
            else:
                hidden_states = layer(hidden_states, mask, joint_rotary_emb, temb)

        # 5. Output norm & projection & unpatchify
        hidden_states = self.norm_out(hidden_states, temb)

        height_tokens = width_tokens = self.config.patch_size
        output = []
        for i in range(len(hidden_sizes)):
            height, width = hidden_sizes[i]
            begin = encoder_hidden_len[i]
            end = begin + (height // height_tokens) * (width // width_tokens)
            output.append(
                hidden_states[i][begin:end]
                .view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels)
                .permute(4, 0, 2, 1, 3)
                .flatten(3, 4)
                .flatten(1, 2)
            )
        output = torch.stack(output, dim=0)

        if do_profile:
            torch.cuda.synchronize()  # Make sure all CUDA ops are done
            prof.stop()

            print("\n==== Profile Results ====")
            print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000))

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
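A quick smoke test of the vendored transformer can be run entirely on CPU with a deliberately tiny configuration. This is a sketch, not the real checkpoint shape (the shipped model uses hidden_size=2304, 24 heads, 26 layers, and Gemma2 features); it assumes a diffusers release that provides LuminaFeedForward, LuminaRMSNormZero, and Attention(kv_heads=...). The only hard constraint to respect is that head_dim (hidden_size // num_attention_heads) equals sum(axes_dim_rope).

import torch
from toolkit.models.lumina2 import Lumina2Transformer2DModel

model = Lumina2Transformer2DModel(
    patch_size=2,
    in_channels=16,
    hidden_size=96,              # divisible by num_attention_heads
    num_layers=2,
    num_refiner_layers=1,
    num_attention_heads=2,       # head_dim = 96 / 2 = 48 ...
    num_kv_heads=1,
    multiple_of=32,
    axes_dim_rope=(16, 16, 16),  # ... which must equal sum(axes_dim_rope)
    axes_lens=(64, 64, 64),
    cap_feat_dim=32,
)

batch, cap_len = 2, 8
latents = torch.randn(batch, 16, 16, 16)         # (B, C, H, W) -> (H/2)*(W/2) = 64 image tokens
caption_feats = torch.randn(batch, cap_len, 32)  # text-encoder features, last dim = cap_feat_dim
caption_mask = torch.ones(batch, cap_len, dtype=torch.int64)
timestep = torch.rand(batch)                     # Lumina2 uses t in [0, 1]

with torch.no_grad():
    out = model(latents, timestep, caption_feats, caption_mask).sample
print(out.shape)  # torch.Size([2, 16, 16, 16])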
@@ -100,6 +100,23 @@ sd_flow_config = {
    "use_dynamic_shifting": False
}

lumina2_config = {
    "_class_name": "FlowMatchEulerDiscreteScheduler",
    "_diffusers_version": "0.33.0.dev0",
    "base_image_seq_len": 256,
    "base_shift": 0.5,
    "invert_sigmas": False,
    "max_image_seq_len": 4096,
    "max_shift": 1.15,
    "num_train_timesteps": 1000,
    "shift": 6.0,
    "shift_terminal": None,
    "use_beta_sigmas": False,
    "use_dynamic_shifting": False,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False
}


def get_sampler(
        sampler: str,
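Since lumina2_config is a plain diffusers scheduler config dict (note the Lumina2-specific shift of 6.0), it can also be handed to diffusers directly outside of get_sampler. A minimal sketch, assuming a recent diffusers release that accepts these flags:

from diffusers import FlowMatchEulerDiscreteScheduler

# from_config ignores bookkeeping keys such as _class_name / _diffusers_version
scheduler = FlowMatchEulerDiscreteScheduler.from_config(lumina2_config)
print(scheduler.config.shift)  # 6.0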
@@ -147,6 +164,13 @@ def get_sampler(
        config_to_use = copy.deepcopy(flux_config)
        if arch == "sd":
            config_to_use = copy.deepcopy(sd_flow_config)
        if arch == "flux":
            config_to_use = copy.deepcopy(flux_config)
        elif arch == "lumina2":
            config_to_use = copy.deepcopy(lumina2_config)
        else:
            # use flux by default
            config_to_use = copy.deepcopy(flux_config)
    else:
        raise ValueError(f"Sampler {sampler} not supported")
@@ -124,7 +124,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
            self.timesteps = timesteps.to(device=device)

            return timesteps
        elif timestep_type == 'flux_shift':
        elif timestep_type == 'flux_shift' or timestep_type == 'lumina2_shift':
            # matches inference dynamic shifting
            timesteps = np.linspace(
                self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps
@@ -49,7 +49,8 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda
    StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
    StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel
    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline
from toolkit.models.lumina2 import Lumina2Transformer2DModel
import diffusers
from diffusers import \
    AutoencoderKL, \
@@ -67,6 +68,7 @@ from toolkit.accelerator import get_accelerator, unwrap_model
from typing import TYPE_CHECKING
from toolkit.print import print_acc
from diffusers import FluxFillPipeline
from transformers import AutoModel, AutoTokenizer, Gemma2Model

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork
@@ -182,6 +184,7 @@ class StableDiffusion:
        self.is_pixart = model_config.is_pixart
        self.is_auraflow = model_config.is_auraflow
        self.is_flux = model_config.is_flux
        self.is_lumina2 = model_config.is_lumina2

        self.use_text_encoder_1 = model_config.use_text_encoder_1
        self.use_text_encoder_2 = model_config.use_text_encoder_2
@@ -189,7 +192,7 @@ class StableDiffusion:
        self.config_file = None

        self.is_flow_matching = False
        if self.is_flux or self.is_v3 or self.is_auraflow or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler):
        if self.is_flux or self.is_v3 or self.is_auraflow or self.is_lumina2 or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler):
            self.is_flow_matching = True

        self.quantize_device = self.device_torch
@@ -745,6 +748,97 @@ class StableDiffusion:
            text_encoder[1].eval()
            pipe.transformer = pipe.transformer.to(self.device_torch)
            flush()
        elif self.model_config.is_lumina2:
            print_acc("Loading Lumina2 model")
            # base_model_path = "black-forest-labs/FLUX.1-schnell"
            base_model_path = self.model_config.name_or_path_original
            print_acc("Loading transformer")
            subfolder = 'transformer'
            transformer_path = model_path
            if os.path.exists(transformer_path):
                subfolder = None
                transformer_path = os.path.join(transformer_path, 'transformer')
                # check if the path is a full checkpoint.
                te_folder_path = os.path.join(model_path, 'text_encoder')
                # if we have the te, this folder is a full checkpoint, use it as the base
                if os.path.exists(te_folder_path):
                    base_model_path = model_path

            transformer = Lumina2Transformer2DModel.from_pretrained(
                transformer_path,
                subfolder=subfolder,
                torch_dtype=dtype,
            )

            if self.model_config.split_model_over_gpus:
                raise ValueError("Splitting model over gpus is not supported for Lumina2 models")

            transformer.to(self.quantize_device, dtype=dtype)
            flush()

            if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
                raise ValueError("Assistant LoRA is not supported for Lumina2 models currently")

            if self.model_config.lora_path is not None:
                raise ValueError("Loading LoRA is not supported for Lumina2 models currently")

            flush()

            if self.model_config.quantize:
                # patch the state dict method
                patch_dequantization_on_save(transformer)
                quantization_type = qfloat8
                print_acc("Quantizing transformer")
                quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs)
                freeze(transformer)
                transformer.to(self.device_torch)
            else:
                transformer.to(self.device_torch, dtype=dtype)

            flush()

            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
            print_acc("Loading vae")
            vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
            flush()

            print_acc("Loading Gemma2")
            tokenizer = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
            text_encoder = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)

            text_encoder.to(self.device_torch, dtype=dtype)
            flush()

            if self.model_config.quantize_te:
                print_acc("Quantizing Gemma2")
                quantize(text_encoder, weights=qfloat8)
                freeze(text_encoder)
                flush()

            print_acc("making pipe")
            pipe: Lumina2Text2ImgPipeline = Lumina2Text2ImgPipeline(
                scheduler=scheduler,
                text_encoder=None,
                tokenizer=tokenizer,
                vae=vae,
                transformer=None,
            )
            pipe.text_encoder = text_encoder
            pipe.transformer = transformer

            print_acc("preparing")

            text_encoder = pipe.text_encoder
            tokenizer = pipe.tokenizer

            pipe.transformer = pipe.transformer.to(self.device_torch)

            flush()
            text_encoder.to(self.device_torch)
            text_encoder.requires_grad_(False)
            text_encoder.eval()
            pipe.transformer = pipe.transformer.to(self.device_torch)
            flush()
        else:
            if self.custom_pipeline is not None:
                pipln = self.custom_pipeline
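The quantize / freeze / qfloat8 calls above follow the same optimum-quanto flow already used for the Flux path, with patch_dequantization_on_save applied first so saving works on the quantized module (per the comment in the code). A minimal standalone sketch of that flow, using a small stand-in module rather than the full Lumina2 transformer:

import torch.nn as nn
from optimum.quanto import quantize, freeze, qfloat8

transformer = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))  # stand-in module
quantize(transformer, weights=qfloat8)  # swap Linear weights for qfloat8 tensors
freeze(transformer)                     # materialize the quantized weights, dropping the fp originals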
@@ -817,9 +911,10 @@ class StableDiffusion:
        # add hacks to unet to help training
        # pipe.unet = prepare_unet_for_training(pipe.unet)

        if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
        if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2:
            # pixart and sd3 dont use a unet
            self.unet = pipe.transformer
            self.unet_unwrapped = pipe.transformer
        else:
            self.unet: 'UNet2DConditionModel' = pipe.unet
        self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
@@ -832,7 +927,7 @@ class StableDiffusion:
        self.unet.eval()

        # load any loras we have
        if self.model_config.lora_path is not None and not self.is_flux:
        if self.model_config.lora_path is not None and not self.is_flux and not self.is_lumina2:
            pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
            pipe.fuse_lora()
            # unfortunately, not an easier way with peft
@@ -975,16 +1070,18 @@ class StableDiffusion:
            })
        else:
            arch = 'sd'
            if self.model_config.is_pixart:
            if self.is_pixart:
                arch = 'pixart'
            if self.model_config.is_flux:
            if self.is_flux:
                arch = 'flux'
            if self.is_lumina2:
                arch = 'lumina2'
            noise_scheduler = get_sampler(
                sampler,
                {
                    "prediction_type": self.prediction_type,
                },
                arch
                arch=arch
            )

        try:
@@ -1061,6 +1158,15 @@ class StableDiffusion:
                **extra_args
            )
            pipeline.watermark = None
        elif self.is_lumina2:
            pipeline = Lumina2Text2ImgPipeline(
                vae=self.vae,
                transformer=self.unet,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                scheduler=noise_scheduler,
                **extra_args
            )
        elif self.is_v3:
            pipeline = Pipe(
                vae=self.vae,
@@ -1366,6 +1472,22 @@ class StableDiffusion:
                callback_on_step_end=callback_on_step_end,
                **extra
            ).images[0]
        elif self.is_lumina2:
            pipeline: Lumina2Text2ImgPipeline = pipeline

            img = pipeline(
                prompt_embeds=conditional_embeds.text_embeds,
                prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64),
                negative_prompt_embeds=unconditional_embeds.text_embeds,
                negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64),
                height=gen_config.height,
                width=gen_config.width,
                num_inference_steps=gen_config.num_inference_steps,
                guidance_scale=gen_config.guidance_scale,
                latents=gen_config.latents,
                generator=generator,
                **extra
            ).images[0]
        elif self.is_pixart:
            # needs attention masks for some reason
            img = pipeline(
@@ -1924,6 +2046,20 @@ class StableDiffusion:

            if bypass_guidance_embedding:
                restore_flux_guidance(self.unet)
        elif self.is_lumina2:
            # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
            t = 1 - timestep / self.noise_scheduler.config.num_train_timesteps
            with self.accelerator.autocast():
                noise_pred = self.unet(
                    hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
                    timestep=t,
                    attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
                    encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
                    **kwargs,
                ).sample

            # lumina2 does this before stepping. Should we do it here?
            noise_pred = -noise_pred
        elif self.is_v3:
            noise_pred = self.unet(
                hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
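The remapping is worth a worked example: the scheduler counts timesteps from roughly 1000 (full noise) down to 0 (clean image), while Lumina2's transformer expects a flow-matching time with t=0 at the noise end and t=1 at the image end, so t = 1 - timestep / num_train_timesteps converts between the two. A short sketch of the arithmetic, with the sign flip mirroring what the Lumina2 pipeline does before scheduler.step() (per the comment above):

# num_train_timesteps = 1000
#   scheduler timestep 1000 (full noise)  -> t = 1 - 1000/1000 = 0.0  (Lumina2: noise end)
#   scheduler timestep  500               -> t = 0.5
#   scheduler timestep    0 (clean image) -> t = 1 -    0/1000 = 1.0  (Lumina2: image end)
timestep = 500.0
t = 1 - timestep / 1000
assert t == 0.5
# noise_pred = -noise_pred flips the predicted velocity back onto the scheduler's time axis.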
@@ -2168,6 +2304,23 @@ class StableDiffusion:
            pe.pooled_embeds = pooled_prompt_embeds
            return pe

        elif self.is_lumina2:
            (
                prompt_embeds,
                prompt_attention_mask,
                negative_prompt_embeds,
                negative_prompt_attention_mask,
            ) = self.pipeline.encode_prompt(
                prompt,
                do_classifier_free_guidance=False,
                num_images_per_prompt=1,
                device=self.device_torch,
                max_sequence_length=256,  # should it be 512?
            )
            return PromptEmbeds(
                prompt_embeds,
                attention_mask=prompt_attention_mask,
            )

        elif isinstance(self.text_encoder, T5EncoderModel):
            embeds, attention_mask = train_tools.encode_prompts_pixart(
@@ -2360,7 +2513,7 @@ class StableDiffusion:
            for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"):
                named_params[name] = param
        if unet:
            if self.is_flux:
            if self.is_flux or self.is_lumina2:
                for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"):
                    named_params[name] = param
            else:
@@ -2472,6 +2625,14 @@ class StableDiffusion:
                    save_directory=os.path.join(output_file, 'transformer'),
                    safe_serialization=True,
                )
            elif self.is_lumina2:
                # only save the unet
                transformer: Lumina2Transformer2DModel = unwrap_model(self.unet)
                transformer.save_pretrained(
                    save_directory=os.path.join(output_file, 'transformer'),
                    safe_serialization=True,
                )

            else:

                self.pipeline.save_pretrained(
@@ -2528,7 +2689,7 @@ class StableDiffusion:
        named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
        unet_lr = unet_lr if unet_lr is not None else default_lr
        params = []
        if self.is_pixart or self.is_auraflow or self.is_flux:
        if self.is_pixart or self.is_auraflow or self.is_flux or self.is_v3 or self.is_lumina2:
            for param in named_params.values():
                if param.requires_grad:
                    params.append(param)
@@ -2574,7 +2735,9 @@ class StableDiffusion:
    def save_device_state(self):
        # saves the current device state for all modules
        # this is useful for when we want to alter the state and restore it
        if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
        if self.is_lumina2:
            unet_has_grad = self.unet.x_embedder.weight.requires_grad
        elif self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
            unet_has_grad = self.unet.proj_out.weight.requires_grad
        else:
            unet_has_grad = self.unet.conv_in.weight.requires_grad
@@ -2607,6 +2770,8 @@ class StableDiffusion:
        else:
            if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel):
                te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
            elif isinstance(self.text_encoder, Gemma2Model):
                te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad
            else:
                te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad