diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index 4ebcab14..27317be9 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -63,7 +63,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
         torch.nn.Module.__init__(self)
         self.lora_name = lora_name
         self.orig_module_ref = weakref.ref(org_module)
-        self.scalar = torch.tensor(1.0)
+        self.scalar = torch.tensor(1.0, device=org_module.weight.device)
         # check if parent has bias. if not force use_bias to False
         if org_module.bias is None:
             use_bias = False
@@ -275,7 +275,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         unet_prefix = self.LORA_PREFIX_UNET
         if self.peft_format:
             unet_prefix = self.PEFT_PREFIX_UNET
-        if is_pixart or is_v3 or is_auraflow or is_flux:
+        if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2:
             unet_prefix = f"lora_transformer"
             if self.peft_format:
                 unet_prefix = "transformer"
diff --git a/toolkit/models/lumina2.py b/toolkit/models/lumina2.py
index 0078ab62..f26e90ca 100644
--- a/toolkit/models/lumina2.py
+++ b/toolkit/models/lumina2.py
@@ -28,10 +28,13 @@ from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rota
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
-
+import torch
+from torch.profiler import profile, record_function, ProfilerActivity
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+do_profile = False
+
 
 class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
     def __init__(
@@ -472,7 +475,18 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         attention_mask: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+        batch_size = hidden_states.size(0)
+
+        if do_profile:
+            prof = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+            )
+
+            prof.start()
 
         # 1. Condition, positional & patch embedding
         temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
 
@@ -534,6 +548,13 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         )
         output = torch.stack(output, dim=0)
 
+        if do_profile:
+            torch.cuda.synchronize()  # Make sure all CUDA ops are done
+            prof.stop()
+
+            print("\n==== Profile Results ====")
+            print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
+
         if not return_dict:
             return (output,)
-        return Transformer2DModelOutput(sample=output)
+        return Transformer2DModelOutput(sample=output)
\ No newline at end of file
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 47702f5e..132890a1 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -914,6 +914,7 @@ class StableDiffusion:
         if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2:
             # pixart and sd3 dont use a unet
             self.unet = pipe.transformer
+            self.unet_unwrapped = pipe.transformer
         else:
             self.unet: 'UNet2DConditionModel' = pipe.unet
         self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
@@ -2048,13 +2049,14 @@ class StableDiffusion:
         elif self.is_lumina2:
             # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
             t = 1 - timestep / self.noise_scheduler.config.num_train_timesteps
-            noise_pred = self.unet(
-                hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
-                timestep=t,
-                attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
-                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
-                **kwargs,
-            ).sample
+            with self.accelerator.autocast():
+                noise_pred = self.unet(
+                    hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
+                    timestep=t,
+                    attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
+                    encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
+                    **kwargs,
+                ).sample
             # lumina2 does this before stepping. Should we do it here?
             noise_pred = -noise_pred
 
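
Note on the profiling hook in `toolkit/models/lumina2.py`: the module-level `do_profile` flag guards a standard `torch.profiler` start/stop pair around the transformer forward pass, with a `torch.cuda.synchronize()` before `stop()` so queued kernels are included in the timings. A minimal standalone sketch of the same pattern follows; the `run_model` workload and the `row_limit` of 20 are placeholders, not values from the patch:

```python
import torch

do_profile = True  # module-level toggle, mirroring the patch


def run_model() -> torch.Tensor:
    # hypothetical stand-in for the transformer forward pass
    x = torch.randn(4, 512, 512)
    return x @ x.transpose(-1, -2)


if do_profile:
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    prof = torch.profiler.profile(activities=activities)
    prof.start()

out = run_model()

if do_profile:
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for queued kernels so CUDA timings are complete
    prof.stop()
    print("\n==== Profile Results ====")
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
```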
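
Note on the `toolkit/stable_diffusion_model.py` change: wrapping the Lumina2 forward in `self.accelerator.autocast()` defers dtype handling to Accelerate's configured mixed-precision policy rather than relying only on the explicit `.to(..., self.torch_dtype)` casts. A small sketch of the same context-manager pattern, assuming `accelerate` is installed; the `mixed_precision="bf16"` setting here is a hypothetical config, not one taken from the toolkit:

```python
import torch
from accelerate import Accelerator

# hypothetical setup; the toolkit configures its Accelerator elsewhere
accelerator = Accelerator(mixed_precision="bf16")

model = torch.nn.Linear(16, 16).to(accelerator.device)
x = torch.randn(2, 16, device=accelerator.device)

# inside the context, ops run under the configured autocast dtype;
# outside it, the explicit parameter/input dtypes apply unchanged
with accelerator.autocast():
    y = model(x)

print(y.dtype)  # torch.bfloat16 under bf16 mixed precision on supported hardware
```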
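
For the timestep handling in that same hunk: the scheduler counts down from `num_train_timesteps` (pure noise) toward 0 (clean image), while Lumina2 expects t to run from 0 (noise) to 1 (image), hence the flip and the final sign change on the prediction. A worked example of just that arithmetic, with `num_train_timesteps = 1000` assumed:

```python
import torch

num_train_timesteps = 1000  # assumed scheduler config value

# scheduler-style timesteps: near 1000 ~ mostly noise, near 0 ~ mostly image
timestep = torch.tensor([1000.0, 750.0, 0.0])

# flip into Lumina2's convention: t=0 is noise, t=1 is the image
t = 1 - timestep / num_train_timesteps
print(t)  # tensor([0.0000, 0.2500, 1.0000])

# the raw prediction is negated before the scheduler step,
# matching `noise_pred = -noise_pred` in the patch
model_output = torch.randn(1, 4, 8, 8)  # dummy prediction
noise_pred = -model_output
```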