Wokr on lumina2

This commit is contained in:
Jaret Burkett
2025-02-08 14:52:39 -07:00
parent d138f07365
commit 9a7266275d
3 changed files with 34 additions and 11 deletions

View File

@@ -63,7 +63,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
torch.nn.Module.__init__(self)
self.lora_name = lora_name
self.orig_module_ref = weakref.ref(org_module)
self.scalar = torch.tensor(1.0)
self.scalar = torch.tensor(1.0, device=org_module.weight.device)
# check if parent has bias. if not force use_bias to False
if org_module.bias is None:
use_bias = False
@@ -275,7 +275,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
unet_prefix = self.LORA_PREFIX_UNET
if self.peft_format:
unet_prefix = self.PEFT_PREFIX_UNET
if is_pixart or is_v3 or is_auraflow or is_flux:
if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2:
unet_prefix = f"lora_transformer"
if self.peft_format:
unet_prefix = "transformer"

View File

@@ -28,10 +28,13 @@ from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rota
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
import torch
from torch.profiler import profile, record_function, ProfilerActivity
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
do_profile = False
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
def __init__(
@@ -472,7 +475,18 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
attention_mask: torch.Tensor,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
batch_size = hidden_states.size(0)
if do_profile:
prof = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
)
prof.start()
# 1. Condition, positional & patch embedding
temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
@@ -534,6 +548,13 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
output = torch.stack(output, dim=0)
if do_profile:
torch.cuda.synchronize() # Make sure all CUDA ops are done
prof.stop()
print("\n==== Profile Results ====")
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
return Transformer2DModelOutput(sample=output)

View File

@@ -914,6 +914,7 @@ class StableDiffusion:
if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2:
# pixart and sd3 dont use a unet
self.unet = pipe.transformer
self.unet_unwrapped = pipe.transformer
else:
self.unet: 'UNet2DConditionModel' = pipe.unet
self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
@@ -2048,13 +2049,14 @@ class StableDiffusion:
elif self.is_lumina2:
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
t = 1 - timestep / self.noise_scheduler.config.num_train_timesteps
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
timestep=t,
attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample
with self.accelerator.autocast():
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
timestep=t,
attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample
# lumina2 does this before stepping. Should we do it here?
noise_pred = -noise_pred