From d138f0736535ebffebec1515a1f038995bf51c69 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 8 Feb 2025 10:59:53 -0700
Subject: [PATCH 1/4] Initial lumina2 support

---
 extensions_built_in/sd_trainer/SDTrainer.py  |   2 +-
 jobs/process/BaseSDTrainProcess.py           |  21 +-
 toolkit/config_modules.py                    |   1 +
 toolkit/lora_special.py                      |  10 +-
 toolkit/models/lumina2.py                    | 539 +++++++++++++++++++
 toolkit/sampler.py                           |  25 +-
 toolkit/samplers/custom_flowmatch_sampler.py |   2 +-
 toolkit/stable_diffusion_model.py            | 184 ++++++-
 8 files changed, 769 insertions(+), 15 deletions(-)
 create mode 100644 toolkit/models/lumina2.py

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 09156909..a5a72a0d 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -916,7 +916,7 @@ class SDTrainer(BaseSDTrainProcess):
             # self.network.multiplier = 0.0
             self.sd.unet.eval()

-        if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux:
+        if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux and not self.sd.is_lumina2:
             # we need to remove the image embeds from the prompt except for flux
             embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach()
             end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 6f057455..e30ddae0 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -335,6 +335,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             o_dict['ss_base_model_version'] = 'sdxl_1.0'
         elif self.model_config.is_flux:
             o_dict['ss_base_model_version'] = 'flux.1'
+        elif self.model_config.is_lumina2:
+            o_dict['ss_base_model_version'] = 'lumina2'
         else:
             o_dict['ss_base_model_version'] = 'sd_1.5'

@@ -1387,12 +1389,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.load_training_state_from_metadata(latest_save_path)

         # get the noise scheduler
+        arch = 'sd'
+        if self.model_config.is_pixart:
+            arch = 'pixart'
+        if self.model_config.is_flux:
+            arch = 'flux'
+        if self.model_config.is_lumina2:
+            arch = 'lumina2'
         sampler = get_sampler(
             self.train_config.noise_scheduler,
             {
                 "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
             },
-            'sd' if not self.model_config.is_pixart else 'pixart'
+            arch=arch,
         )

         if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
@@ -1452,10 +1461,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 #     print_acc("sage attention is not installed. 
Using SDP instead") if self.train_config.gradient_checkpointing: - if self.sd.is_flux: + # if has method enable_gradient_checkpointing + if hasattr(unet, 'enable_gradient_checkpointing'): + unet.enable_gradient_checkpointing() + elif hasattr(unet, 'gradient_checkpointing'): unet.gradient_checkpointing = True else: - unet.enable_gradient_checkpointing() + print("Gradient checkpointing not supported on this model") if isinstance(text_encoder, list): for te in text_encoder: if hasattr(te, 'enable_gradient_checkpointing'): @@ -1547,6 +1559,7 @@ class BaseSDTrainProcess(BaseTrainProcess): is_pixart=self.model_config.is_pixart, is_auraflow=self.model_config.is_auraflow, is_flux=self.model_config.is_flux, + is_lumina2=self.model_config.is_lumina2, is_ssd=self.model_config.is_ssd, is_vega=self.model_config.is_vega, dropout=self.network_config.dropout, @@ -2165,6 +2178,8 @@ class BaseSDTrainProcess(BaseTrainProcess): tags.append("stable-diffusion-xl") if self.model_config.is_flux: tags.append("flux") + if self.model_config.is_lumina2: + tags.append("lumina2") if self.model_config.is_v3: tags.append("sd3") if self.network_config: diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 30762c72..60e76352 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -424,6 +424,7 @@ class ModelConfig: self.is_auraflow: bool = kwargs.get('is_auraflow', False) self.is_v3: bool = kwargs.get('is_v3', False) self.is_flux: bool = kwargs.get('is_flux', False) + self.is_lumina2: bool = kwargs.get('is_lumina2', False) if self.is_pixart_sigma: self.is_pixart = True self.use_flux_cfg = kwargs.get('use_flux_cfg', False) diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 6c53439a..4ebcab14 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -163,6 +163,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): is_pixart: bool = False, is_auraflow: bool = False, is_flux: bool = False, + is_lumina2: bool = False, use_bias: bool = False, is_lorm: bool = False, ignore_if_contains = None, @@ -223,6 +224,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): self.is_pixart = is_pixart self.is_auraflow = is_auraflow self.is_flux = is_flux + self.is_lumina2 = is_lumina2 self.network_type = network_type self.is_assistant_adapter = is_assistant_adapter if self.network_type.lower() == "dora": @@ -232,7 +234,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): self.peft_format = peft_format # always do peft for flux only for now - if self.is_flux or self.is_v3: + if self.is_flux or self.is_v3 or self.is_lumina2: self.peft_format = True if self.peft_format: @@ -326,6 +328,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): if self.transformer_only and self.is_flux and is_unet: if "transformer_blocks" not in lora_name: skip = True + if self.transformer_only and self.is_lumina2 and is_unet: + if "layers$$" not in lora_name and "noise_refiner$$" not in lora_name and "context_refiner$$" not in lora_name: + skip = True if self.transformer_only and self.is_v3 and is_unet: if "transformer_blocks" not in lora_name: skip = True @@ -431,6 +436,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): if is_flux: target_modules = ["FluxTransformer2DModel"] + + if is_lumina2: + target_modules = ["Lumina2Transformer2DModel"] if train_unet: self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules) diff --git a/toolkit/models/lumina2.py b/toolkit/models/lumina2.py new file mode 100644 index 00000000..0078ab62 
--- /dev/null +++ b/toolkit/models/lumina2.py @@ -0,0 +1,539 @@ +# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.utils import logging +from diffusers.models.attention import LuminaFeedForward +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): + def __init__( + self, + hidden_size: int = 4096, + cap_feat_dim: int = 2048, + frequency_embedding_size: int = 256, + norm_eps: float = 1e-5, + ) -> None: + super().__init__() + + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0 + ) + + self.timestep_embedder = TimestepEmbedding( + in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) + ) + + self.caption_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True) + ) + + def forward( + self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + timestep_proj = self.time_proj(timestep).type_as(hidden_states) + time_embed = self.timestep_embedder(timestep_proj) + caption_embed = self.caption_embedder(encoder_hidden_states) + return time_embed, caption_embed + + +class Lumina2AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors. 
+ """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply Query-Key Norm if needed + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, use_real=False) + key = apply_rotary_emb(key, image_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + # Apply proportional attention if true + if base_sequence_length is not None: + softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + else: + softmax_scale = attn.scale + + # perform Grouped-qurey Attention (GQA) + n_rep = attn.heads // kv_heads + if n_rep >= 1: + key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) + attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, scale=softmax_scale + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.type_as(query) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class Lumina2TransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + num_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + modulation: bool = True, + ) -> None: + super().__init__() + self.head_dim = dim // num_attention_heads + self.modulation = modulation + + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + qk_norm="rms_norm", + heads=num_attention_heads, + kv_heads=num_kv_heads, + eps=1e-5, + bias=False, + out_bias=False, + processor=Lumina2AttnProcessor2_0(), + ) + + self.feed_forward = LuminaFeedForward( + dim=dim, + inner_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + + if modulation: + self.norm1 = LuminaRMSNormZero( + embedding_dim=dim, + norm_eps=norm_eps, + norm_elementwise_affine=True, + ) + else: + self.norm1 = RMSNorm(dim, eps=norm_eps) 
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + image_rotary_emb: torch.Tensor, + temb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.modulation: + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) + hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) + else: + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + self.norm2(attn_output) + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) + hidden_states = hidden_states + self.ffn_norm2(mlp_output) + + return hidden_states + + +class Lumina2RotaryPosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + self.axes_lens = axes_lens + self.patch_size = patch_size + + self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta) + + def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]: + freqs_cis = [] + for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): + emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=torch.float64) + freqs_cis.append(emb) + return freqs_cis + + def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor: + result = [] + for i in range(len(self.axes_dim)): + freqs = self.freqs_cis[i].to(ids.device) + index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) + result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) + return torch.cat(result, dim=-1) + + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): + batch_size = len(hidden_states) + p_h = p_w = self.patch_size + device = hidden_states[0].device + + l_effective_cap_len = attention_mask.sum(dim=1).tolist() + # TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape + img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] + l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes] + + max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))) + max_img_len = max(l_effective_img_len) + + position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device) + + for i in range(batch_size): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + H, W = img_sizes[i] + H_tokens, W_tokens = H // p_h, W // p_w + assert H_tokens * W_tokens == img_len + + position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) + position_ids[i, cap_len : cap_len + img_len, 0] = cap_len + row_ids = ( + torch.arange(H_tokens, dtype=torch.int32, 
device=device).view(-1, 1).repeat(1, W_tokens).flatten() + ) + col_ids = ( + torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() + ) + position_ids[i, cap_len : cap_len + img_len, 1] = row_ids + position_ids[i, cap_len : cap_len + img_len, 2] = col_ids + + freqs_cis = self._get_freqs_cis(position_ids) + + cap_freqs_cis_shape = list(freqs_cis.shape) + cap_freqs_cis_shape[1] = attention_mask.shape[1] + cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) + + img_freqs_cis_shape = list(freqs_cis.shape) + img_freqs_cis_shape[1] = max_img_len + img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) + + for i in range(batch_size): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] + img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len] + + flat_hidden_states = [] + for i in range(batch_size): + img = hidden_states[i] + C, H, W = img.size() + img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) + flat_hidden_states.append(img) + hidden_states = flat_hidden_states + padded_img_embed = torch.zeros( + batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype + ) + padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device) + for i in range(batch_size): + padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i] + padded_img_mask[i, : l_effective_img_len[i]] = True + + return ( + padded_img_embed, + padded_img_mask, + img_sizes, + l_effective_cap_len, + l_effective_img_len, + freqs_cis, + cap_freqs_cis, + img_freqs_cis, + max_seq_len, + ) + + +class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + r""" + Lumina2NextDiT: Diffusion model with a Transformer backbone. + + Parameters: + sample_size (`int`): The width of the latent images. This is fixed during training since + it is used to learn a number of position embeddings. + patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2): + The size of each patch in the image. This parameter defines the resolution of patches fed into the model. + in_channels (`int`, *optional*, defaults to 4): + The number of input channels for the model. Typically, this matches the number of channels in the input + images. + hidden_size (`int`, *optional*, defaults to 4096): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + num_layers (`int`, *optional*, default to 32): + The number of layers in the model. This defines the depth of the neural network. + num_attention_heads (`int`, *optional*, defaults to 32): + The number of attention heads in each attention layer. This parameter specifies how many separate attention + mechanisms are used. + num_kv_heads (`int`, *optional*, defaults to 8): + The number of key-value heads in the attention mechanism, if different from the number of attention heads. + If None, it defaults to num_attention_heads. + multiple_of (`int`, *optional*, defaults to 256): + A factor that the hidden size should be a multiple of. This can help optimize certain hardware + configurations. + ffn_dim_multiplier (`float`, *optional*): + A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on + the model configuration. 
+ norm_eps (`float`, *optional*, defaults to 1e-5): + A small value added to the denominator for numerical stability in normalization layers. + scaling_factor (`float`, *optional*, defaults to 1.0): + A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the + overall scale of the model's operations. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["Lumina2TransformerBlock"] + _skip_layerwise_casting_patterns = ["x_embedder", "norm"] + + @register_to_config + def __init__( + self, + sample_size: int = 128, + patch_size: int = 2, + in_channels: int = 16, + out_channels: Optional[int] = None, + hidden_size: int = 2304, + num_layers: int = 26, + num_refiner_layers: int = 2, + num_attention_heads: int = 24, + num_kv_heads: int = 8, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + scaling_factor: float = 1.0, + axes_dim_rope: Tuple[int, int, int] = (32, 32, 32), + axes_lens: Tuple[int, int, int] = (300, 512, 512), + cap_feat_dim: int = 1024, + ) -> None: + super().__init__() + self.out_channels = out_channels or in_channels + + # 1. Positional, patch & conditional embeddings + self.rope_embedder = Lumina2RotaryPosEmbed( + theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size + ) + + self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size) + + self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( + hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps + ) + + # 2. Noise and context refinement blocks + self.noise_refiner = nn.ModuleList( + [ + Lumina2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_refiner_layers) + ] + ) + + self.context_refiner = nn.ModuleList( + [ + Lumina2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=False, + ) + for _ in range(num_refiner_layers) + ] + ) + + # 3. Transformer blocks + self.layers = nn.ModuleList( + [ + Lumina2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_layers) + ] + ) + + # 4. Output norm & projection + self.norm_out = LuminaLayerNormContinuous( + embedding_dim=hidden_size, + conditioning_embedding_dim=min(hidden_size, 1024), + elementwise_affine=False, + eps=1e-6, + bias=True, + out_dim=patch_size * patch_size * self.out_channels, + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + batch_size = hidden_states.size(0) + + # 1. Condition, positional & patch embedding + temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states) + + ( + hidden_states, + hidden_mask, + hidden_sizes, + encoder_hidden_len, + hidden_len, + joint_rotary_emb, + encoder_rotary_emb, + hidden_rotary_emb, + max_seq_len, + ) = self.rope_embedder(hidden_states, attention_mask) + + hidden_states = self.x_embedder(hidden_states) + + # 2. 
Context & noise refinement + for layer in self.context_refiner: + encoder_hidden_states = layer(encoder_hidden_states, attention_mask, encoder_rotary_emb) + + for layer in self.noise_refiner: + hidden_states = layer(hidden_states, hidden_mask, hidden_rotary_emb, temb) + + # 3. Attention mask preparation + mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) + padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) + for i in range(batch_size): + cap_len = encoder_hidden_len[i] + img_len = hidden_len[i] + mask[i, : cap_len + img_len] = True + padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len] + padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len] + hidden_states = padded_hidden_states + + # 4. Transformer blocks + for layer in self.layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(layer, hidden_states, mask, joint_rotary_emb, temb) + else: + hidden_states = layer(hidden_states, mask, joint_rotary_emb, temb) + + # 5. Output norm & projection & unpatchify + hidden_states = self.norm_out(hidden_states, temb) + + height_tokens = width_tokens = self.config.patch_size + output = [] + for i in range(len(hidden_sizes)): + height, width = hidden_sizes[i] + begin = encoder_hidden_len[i] + end = begin + (height // height_tokens) * (width // width_tokens) + output.append( + hidden_states[i][begin:end] + .view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels) + .permute(4, 0, 2, 1, 3) + .flatten(3, 4) + .flatten(1, 2) + ) + output = torch.stack(output, dim=0) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/toolkit/sampler.py b/toolkit/sampler.py index aae6e379..7a0e50df 100644 --- a/toolkit/sampler.py +++ b/toolkit/sampler.py @@ -88,6 +88,23 @@ flux_config = { "use_dynamic_shifting": True } +lumina2_config = { + "_class_name": "FlowMatchEulerDiscreteScheduler", + "_diffusers_version": "0.33.0.dev0", + "base_image_seq_len": 256, + "base_shift": 0.5, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 1.15, + "num_train_timesteps": 1000, + "shift": 6.0, + "shift_terminal": None, + "use_beta_sigmas": False, + "use_dynamic_shifting": False, + "use_exponential_sigmas": False, + "use_karras_sigmas": False +} + def get_sampler( sampler: str, @@ -132,7 +149,13 @@ def get_sampler( scheduler_cls = CustomLCMScheduler elif sampler == "flowmatch": scheduler_cls = CustomFlowMatchEulerDiscreteScheduler - config_to_use = copy.deepcopy(flux_config) + if arch == "flux": + config_to_use = copy.deepcopy(flux_config) + elif arch == "lumina2": + config_to_use = copy.deepcopy(lumina2_config) + else: + # use flux by default + config_to_use = copy.deepcopy(flux_config) else: raise ValueError(f"Sampler {sampler} not supported") diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py index a4c53db1..2a7a1cfd 100644 --- a/toolkit/samplers/custom_flowmatch_sampler.py +++ b/toolkit/samplers/custom_flowmatch_sampler.py @@ -124,7 +124,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): self.timesteps = timesteps.to(device=device) return timesteps - elif timestep_type == 'flux_shift': + elif timestep_type == 'flux_shift' or timestep_type == 'lumina2_shift': # matches inference dynamic shifting timesteps = np.linspace( self._sigma_to_t(self.sigma_max), 
self._sigma_to_t(self.sigma_min), num_timesteps diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 7bdc586b..47702f5e 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -49,7 +49,8 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \ StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \ StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \ - FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel + FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline +from toolkit.models.lumina2 import Lumina2Transformer2DModel import diffusers from diffusers import \ AutoencoderKL, \ @@ -67,6 +68,7 @@ from toolkit.accelerator import get_accelerator, unwrap_model from typing import TYPE_CHECKING from toolkit.print import print_acc from diffusers import FluxFillPipeline +from transformers import AutoModel, AutoTokenizer, Gemma2Model if TYPE_CHECKING: from toolkit.lora_special import LoRASpecialNetwork @@ -182,6 +184,7 @@ class StableDiffusion: self.is_pixart = model_config.is_pixart self.is_auraflow = model_config.is_auraflow self.is_flux = model_config.is_flux + self.is_lumina2 = model_config.is_lumina2 self.use_text_encoder_1 = model_config.use_text_encoder_1 self.use_text_encoder_2 = model_config.use_text_encoder_2 @@ -189,7 +192,7 @@ class StableDiffusion: self.config_file = None self.is_flow_matching = False - if self.is_flux or self.is_v3 or self.is_auraflow or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler): + if self.is_flux or self.is_v3 or self.is_auraflow or self.is_lumina2 or isinstance(self.noise_scheduler, CustomFlowMatchEulerDiscreteScheduler): self.is_flow_matching = True self.quantize_device = self.device_torch @@ -745,6 +748,97 @@ class StableDiffusion: text_encoder[1].eval() pipe.transformer = pipe.transformer.to(self.device_torch) flush() + elif self.model_config.is_lumina2: + print_acc("Loading Lumina2 model") + # base_model_path = "black-forest-labs/FLUX.1-schnell" + base_model_path = self.model_config.name_or_path_original + print_acc("Loading transformer") + subfolder = 'transformer' + transformer_path = model_path + if os.path.exists(transformer_path): + subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. 
+ te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + transformer = Lumina2Transformer2DModel.from_pretrained( + transformer_path, + subfolder=subfolder, + torch_dtype=dtype, + ) + + if self.model_config.split_model_over_gpus: + raise ValueError("Splitting model over gpus is not supported for Lumina2 models") + + transformer.to(self.quantize_device, dtype=dtype) + flush() + + if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: + raise ValueError("Assistant LoRA is not supported for Lumina2 models currently") + + if self.model_config.lora_path is not None: + raise ValueError("Loading LoRA is not supported for Lumina2 models currently") + + flush() + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = qfloat8 + print_acc("Quantizing transformer") + quantize(transformer, weights=quantization_type, **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") + print_acc("Loading vae") + vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) + flush() + + print_acc("Loading Gemma2") + tokenizer = AutoTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder = AutoModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype) + + text_encoder.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + print_acc("Quantizing Gemma2") + quantize(text_encoder, weights=qfloat8) + freeze(text_encoder) + flush() + + print_acc("making pipe") + pipe: Lumina2Text2ImgPipeline = Lumina2Text2ImgPipeline( + scheduler=scheduler, + text_encoder=None, + tokenizer=tokenizer, + vae=vae, + transformer=None, + ) + pipe.text_encoder = text_encoder + pipe.transformer = transformer + + print_acc("preparing") + + text_encoder = pipe.text_encoder + tokenizer = pipe.tokenizer + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + text_encoder.to(self.device_torch) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() else: if self.custom_pipeline is not None: pipln = self.custom_pipeline @@ -817,7 +911,7 @@ class StableDiffusion: # add hacks to unet to help training # pipe.unet = prepare_unet_for_training(pipe.unet) - if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: + if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2: # pixart and sd3 dont use a unet self.unet = pipe.transformer else: @@ -832,7 +926,7 @@ class StableDiffusion: self.unet.eval() # load any loras we have - if self.model_config.lora_path is not None and not self.is_flux: + if self.model_config.lora_path is not None and not self.is_flux and not self.is_lumina2: pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1") pipe.fuse_lora() # unfortunately, not an easier way with peft @@ -974,12 +1068,19 @@ class StableDiffusion: "prediction_type": self.prediction_type, }) else: + arch = 'sd' + if self.is_pixart: + arch = 'pixart' + if self.is_flux: + arch = 'flux' + if self.is_lumina2: + arch = 
'lumina2' noise_scheduler = get_sampler( sampler, { "prediction_type": self.prediction_type, }, - 'sd' if not self.is_pixart else 'pixart' + arch=arch ) try: @@ -1056,6 +1157,15 @@ class StableDiffusion: **extra_args ) pipeline.watermark = None + elif self.is_lumina2: + pipeline = Lumina2Text2ImgPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=noise_scheduler, + **extra_args + ) elif self.is_v3: pipeline = Pipe( vae=self.vae, @@ -1361,6 +1471,22 @@ class StableDiffusion: callback_on_step_end=callback_on_step_end, **extra ).images[0] + elif self.is_lumina2: + pipeline: Lumina2Text2ImgPipeline = pipeline + + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds, + prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64), + negative_prompt_embeds=unconditional_embeds.text_embeds, + negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] elif self.is_pixart: # needs attention masks for some reason img = pipeline( @@ -1919,6 +2045,19 @@ class StableDiffusion: if bypass_guidance_embedding: restore_flux_guidance(self.unet) + elif self.is_lumina2: + # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image + t = 1 - timestep / self.noise_scheduler.config.num_train_timesteps + noise_pred = self.unet( + hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), + timestep=t, + attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64), + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), + **kwargs, + ).sample + + # lumina2 does this before stepping. Should we do it here? + noise_pred = -noise_pred elif self.is_v3: noise_pred = self.unet( hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), @@ -2163,6 +2302,23 @@ class StableDiffusion: pe.pooled_embeds = pooled_prompt_embeds return pe + elif self.is_lumina2: + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.pipeline.encode_prompt( + prompt, + do_classifier_free_guidance=False, + num_images_per_prompt=1, + device=self.device_torch, + max_sequence_length=256, # should it be 512? 
+ ) + return PromptEmbeds( + prompt_embeds, + attention_mask=prompt_attention_mask, + ) elif isinstance(self.text_encoder, T5EncoderModel): embeds, attention_mask = train_tools.encode_prompts_pixart( @@ -2355,7 +2511,7 @@ class StableDiffusion: for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"): named_params[name] = param if unet: - if self.is_flux: + if self.is_flux or self.is_lumina2: for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"): named_params[name] = param else: @@ -2467,6 +2623,14 @@ class StableDiffusion: save_directory=os.path.join(output_file, 'transformer'), safe_serialization=True, ) + elif self.is_lumina2: + # only save the unet + transformer: Lumina2Transformer2DModel = unwrap_model(self.unet) + transformer.save_pretrained( + save_directory=os.path.join(output_file, 'transformer'), + safe_serialization=True, + ) + else: self.pipeline.save_pretrained( @@ -2523,7 +2687,7 @@ class StableDiffusion: named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True) unet_lr = unet_lr if unet_lr is not None else default_lr params = [] - if self.is_pixart or self.is_auraflow or self.is_flux: + if self.is_pixart or self.is_auraflow or self.is_flux or self.is_v3 or self.is_lumina2: for param in named_params.values(): if param.requires_grad: params.append(param) @@ -2569,7 +2733,9 @@ class StableDiffusion: def save_device_state(self): # saves the current device state for all modules # this is useful for when we want to alter the state and restore it - if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: + if self.is_lumina2: + unet_has_grad = self.unet.x_embedder.weight.requires_grad + elif self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: unet_has_grad = self.unet.proj_out.weight.requires_grad else: unet_has_grad = self.unet.conv_in.weight.requires_grad @@ -2602,6 +2768,8 @@ class StableDiffusion: else: if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel): te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + elif isinstance(self.text_encoder, Gemma2Model): + te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad else: te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad From 9a7266275d7e6322d3348e699581e9da58c85a94 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 8 Feb 2025 14:52:39 -0700 Subject: [PATCH 2/4] Wokr on lumina2 --- toolkit/lora_special.py | 4 ++-- toolkit/models/lumina2.py | 25 +++++++++++++++++++++++-- toolkit/stable_diffusion_model.py | 16 +++++++++------- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 4ebcab14..27317be9 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -63,7 +63,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module): torch.nn.Module.__init__(self) self.lora_name = lora_name self.orig_module_ref = weakref.ref(org_module) - self.scalar = torch.tensor(1.0) + self.scalar = torch.tensor(1.0, device=org_module.weight.device) # check if parent has bias. 
if not force use_bias to False if org_module.bias is None: use_bias = False @@ -275,7 +275,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): unet_prefix = self.LORA_PREFIX_UNET if self.peft_format: unet_prefix = self.PEFT_PREFIX_UNET - if is_pixart or is_v3 or is_auraflow or is_flux: + if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2: unet_prefix = f"lora_transformer" if self.peft_format: unet_prefix = "transformer" diff --git a/toolkit/models/lumina2.py b/toolkit/models/lumina2.py index 0078ab62..f26e90ca 100644 --- a/toolkit/models/lumina2.py +++ b/toolkit/models/lumina2.py @@ -28,10 +28,13 @@ from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rota from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm - +import torch +from torch.profiler import profile, record_function, ProfilerActivity logger = logging.get_logger(__name__) # pylint: disable=invalid-name +do_profile = False + class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): def __init__( @@ -472,7 +475,18 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): attention_mask: torch.Tensor, return_dict: bool = True, ) -> Union[torch.Tensor, Transformer2DModelOutput]: + batch_size = hidden_states.size(0) + + if do_profile: + prof = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + ) + + prof.start() # 1. Condition, positional & patch embedding temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states) @@ -534,6 +548,13 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ) output = torch.stack(output, dim=0) + if do_profile: + torch.cuda.synchronize() # Make sure all CUDA ops are done + prof.stop() + + print("\n==== Profile Results ====") + print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000)) + if not return_dict: return (output,) - return Transformer2DModelOutput(sample=output) + return Transformer2DModelOutput(sample=output) \ No newline at end of file diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 47702f5e..132890a1 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -914,6 +914,7 @@ class StableDiffusion: if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2: # pixart and sd3 dont use a unet self.unet = pipe.transformer + self.unet_unwrapped = pipe.transformer else: self.unet: 'UNet2DConditionModel' = pipe.unet self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) @@ -2048,13 +2049,14 @@ class StableDiffusion: elif self.is_lumina2: # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image t = 1 - timestep / self.noise_scheduler.config.num_train_timesteps - noise_pred = self.unet( - hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), - timestep=t, - attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64), - encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), - **kwargs, - ).sample + with self.accelerator.autocast(): + noise_pred = self.unet( + hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), + timestep=t, + 
attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64), + encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype), + **kwargs, + ).sample # lumina2 does this before stepping. Should we do it here? noise_pred = -noise_pred From 4de6a825fac513a761fcfbca3b65afff4cdde0f8 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 8 Feb 2025 15:16:35 -0700 Subject: [PATCH 3/4] Update lumina requirements --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2f2cdc3a..b8620b4c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torch==2.5.1 torchvision==0.20.1 safetensors -diffusers==0.32.2 +git+https://github.com/zhuole1025/diffusers@lumina2 transformers lycoris-lora==1.8.3 flatten_json From ed1deb71c49affbb7e8539e23ddd193e511326a9 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 8 Feb 2025 16:13:18 -0700 Subject: [PATCH 4/4] Added examples for training lumina2 --- .../examples/train_full_fine_tune_lumina.yaml | 99 +++++++++++++++++++ config/examples/train_lora_lumina.yaml | 96 ++++++++++++++++++ 2 files changed, 195 insertions(+) create mode 100644 config/examples/train_full_fine_tune_lumina.yaml create mode 100644 config/examples/train_lora_lumina.yaml diff --git a/config/examples/train_full_fine_tune_lumina.yaml b/config/examples/train_full_fine_tune_lumina.yaml new file mode 100644 index 00000000..51a61737 --- /dev/null +++ b/config/examples/train_full_fine_tune_lumina.yaml @@ -0,0 +1,99 @@ +--- +# This configuration requires 24GB of VRAM or more to operate +job: extension +config: + # this name will be the folder and filename name + name: "my_first_lumina_finetune_v1" + process: + - type: 'sd_trainer' + # root folder to save training sessions/samples/weights + training_folder: "output" + # uncomment to see performance stats in the terminal every N steps + # performance_log_every: 1000 + device: cuda:0 + # if a trigger word is specified, it will be added to captions of training data if it does not already exist + # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word + # trigger_word: "p3r5on" + save: + dtype: bf16 # precision to save + save_every: 250 # save every this many steps + max_step_saves_to_keep: 2 # how many intermittent saves to keep + save_format: 'diffusers' # 'diffusers' + datasets: + # datasets are a folder of images. captions need to be txt files with the same name as the image + # for instance image2.jpg and image2.txt. 
Only jpg, jpeg, and png are supported currently
+        # images will automatically be resized and bucketed into the resolution specified
+        # on windows, escape back slashes with another backslash so
+        # "C:\\path\\to\\images\\folder"
+        - folder_path: "/path/to/images/folder"
+          caption_ext: "txt"
+          caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
+          shuffle_tokens: false # shuffle caption order, split by commas
+          # cache_latents_to_disk: true # leave this true unless you know what you're doing
+          resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions
+      train:
+        batch_size: 1
+
+        # can be 'sigmoid', 'linear', or 'lumina2_shift'
+        timestep_type: 'lumina2_shift'
+
+        steps: 2000 # total number of steps to train; 500 - 4000 is a good range
+        gradient_accumulation: 1
+        train_unet: true
+        train_text_encoder: false # probably won't work with lumina2
+        gradient_checkpointing: true # need this on unless you have a ton of vram
+        noise_scheduler: "flowmatch" # for training only
+        optimizer: "adafactor"
+        lr: 3e-5
+
+        # Parameter swapping can reduce vram requirements. Set factor from 1.0 to 0.0.
+        # 0.1 is 10% of parameters active at each step. Only works with adafactor
+
+        # do_paramiter_swapping: true
+        # paramiter_swapping_factor: 0.9
+
+        # uncomment this to skip the pre training sample
+        # skip_first_sample: true
+        # uncomment to completely disable sampling
+        # disable_sampling: true
+
+        # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
+        # ema_config:
+        #   use_ema: true
+        #   ema_decay: 0.99
+
+        # will probably need this if gpu supports it for lumina2, other dtypes may not work correctly
+        dtype: bf16
+      model:
+        # huggingface model name or path
+        name_or_path: "Alpha-VLLM/Lumina-Image-2.0"
+        is_lumina2: true # lumina2 architecture
+        # you can quantize just the Gemma2 text encoder here to save vram
+        quantize_te: true
+      sample:
+        sampler: "flowmatch" # must match train.noise_scheduler
+        sample_every: 250 # sample every this many steps
+        width: 1024
+        height: 1024
+        prompts:
+          # you can add [trigger] to the prompts here and it will be replaced with the trigger word
+          # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
+          - "woman with red hair, playing chess at the park, bomb going off in the background"
+          - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
+          - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
+          - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
+          - "a bear building a log cabin in the snow covered mountains"
+          - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
+          - "hipster man with a beard, building a chair, in a wood shop"
+          - "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear."
+          - "a man holding a sign that says, 'this is a sign'"
+          - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
+        neg: ""
+        seed: 42
+        walk_seed: true
+        guidance_scale: 4.0
+        sample_steps: 25
+# you can add any additional meta info here. [name] is replaced with config name at top
+meta:
+  name: "[name]"
+  version: '1.0'
diff --git a/config/examples/train_lora_lumina.yaml b/config/examples/train_lora_lumina.yaml
new file mode 100644
index 00000000..e5d2d756
--- /dev/null
+++ b/config/examples/train_lora_lumina.yaml
@@ -0,0 +1,96 @@
+---
+# This configuration requires 20GB of VRAM or more to operate
+job: extension
+config:
+  # this name will be the folder and filename name
+  name: "my_first_lumina_lora_v1"
+  process:
+    - type: 'sd_trainer'
+      # root folder to save training sessions/samples/weights
+      training_folder: "output"
+      # uncomment to see performance stats in the terminal every N steps
+      # performance_log_every: 1000
+      device: cuda:0
+      # if a trigger word is specified, it will be added to captions of training data if it does not already exist
+      # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
+      # trigger_word: "p3r5on"
+      network:
+        type: "lora"
+        linear: 16
+        linear_alpha: 16
+      save:
+        dtype: bf16 # precision to save
+        save_every: 250 # save every this many steps
+        max_step_saves_to_keep: 2 # how many intermittent saves to keep
+        save_format: 'diffusers' # 'diffusers'
+      datasets:
+        # datasets are a folder of images. captions need to be txt files with the same name as the image
+        # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
+        # images will automatically be resized and bucketed into the resolution specified
+        # on windows, escape back slashes with another backslash so
+        # "C:\\path\\to\\images\\folder"
+        - folder_path: "/path/to/images/folder"
+          caption_ext: "txt"
+          caption_dropout_rate: 0.05 # will drop out the caption 5% of the time
+          shuffle_tokens: false # shuffle caption order, split by commas
+          # cache_latents_to_disk: true # leave this true unless you know what you're doing
+          resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions
+      train:
+        batch_size: 1
+
+        # can be 'sigmoid', 'linear', or 'lumina2_shift'
+        timestep_type: 'lumina2_shift'
+
+        steps: 2000 # total number of steps to train; 500 - 4000 is a good range
+        gradient_accumulation: 1
+        train_unet: true
+        train_text_encoder: false # probably won't work with lumina2
+        gradient_checkpointing: true # need this on unless you have a ton of vram
+        noise_scheduler: "flowmatch" # for training only
+        optimizer: "adamw8bit"
+        lr: 1e-4
+        # uncomment this to skip the pre training sample
+        # skip_first_sample: true
+        # uncomment to completely disable sampling
+        # disable_sampling: true
+
+        # ema will smooth out learning, but could slow it down. 
Recommended to leave on if you have the vram + ema_config: + use_ema: true + ema_decay: 0.99 + + # will probably need this if gpu supports it for lumina2, other dtypes may not work correctly + dtype: bf16 + model: + # huggingface model name or path + name_or_path: "Alpha-VLLM/Lumina-Image-2.0" + is_lumina2: true # lumina2 architecture + # you can quantize just the Gemma2 text encoder here to save vram + quantize_te: true + sample: + sampler: "flowmatch" # must match train.noise_scheduler + sample_every: 250 # sample every this many steps + width: 1024 + height: 1024 + prompts: + # you can add [trigger] to the prompts here and it will be replaced with the trigger word + # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\ + - "woman with red hair, playing chess at the park, bomb going off in the background" + - "a woman holding a coffee cup, in a beanie, sitting at a cafe" + - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini" + - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background" + - "a bear building a log cabin in the snow covered mountains" + - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker" + - "hipster man with a beard, building a chair, in a wood shop" + - "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear." + - "a man holding a sign that says, 'this is a sign'" + - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle" + neg: "" + seed: 42 + walk_seed: true + guidance_scale: 4.0 + sample_steps: 25 +# you can add any additional meta info here. [name] is replaced with config name at top +meta: + name: "[name]" + version: '1.0'
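A minimal way to sanity-check the two example configs added above is to load one with PyYAML and confirm the Lumina2-specific fields line up with the code paths added in the earlier patches. This is only an illustrative sketch, not part of the patch and not the toolkit's own config loader; it assumes PyYAML is installed and is run from the repo root.

import yaml

# Load the new LoRA example config added in PATCH 4/4.
with open("config/examples/train_lora_lumina.yaml") as f:
    cfg = yaml.safe_load(f)

process = cfg["config"]["process"][0]

# These are the values the Lumina2 branches above key off of.
assert process["model"]["is_lumina2"] is True  # routes loading to Lumina2Transformer2DModel + Gemma2 text encoder
assert process["train"]["timestep_type"] == "lumina2_shift"  # handled in custom_flowmatch_sampler.py
assert process["train"]["noise_scheduler"] == "flowmatch"  # together with is_lumina2, selects lumina2_config in toolkit/sampler.py
assert process["sample"]["sampler"] == process["train"]["noise_scheduler"]  # sampler must match train.noise_scheduler
print("lumina2 example config looks consistent")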