Work on lumina2
@@ -63,7 +63,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
         torch.nn.Module.__init__(self)
         self.lora_name = lora_name
         self.orig_module_ref = weakref.ref(org_module)
-        self.scalar = torch.tensor(1.0)
+        self.scalar = torch.tensor(1.0, device=org_module.weight.device)
         # check if parent has bias. if not force use_bias to False
         if org_module.bias is None:
             use_bias = False
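
Note on the hunk above: creating the scale factor directly on the wrapped module's weight device keeps all of the adapter's state on one device, which avoids cross-device handling every step (0-dim CPU tensors are tolerated in elementwise math against CUDA tensors, but not in every op). A minimal sketch of the pattern, with a hypothetical toy module rather than ai-toolkit's real LoRAModule:

import torch
import torch.nn as nn

# Hypothetical minimal LoRA-style adapter (not ai-toolkit's API): all state,
# including the scalar, is created on the wrapped weight's device.
class TinyLoRA(nn.Module):
    def __init__(self, org_module: nn.Linear, rank: int = 4):
        super().__init__()
        dev = org_module.weight.device
        self.down = nn.Linear(org_module.in_features, rank, bias=False, device=dev)
        self.up = nn.Linear(rank, org_module.out_features, bias=False, device=dev)
        # mirrors the commit: scalar placed on org_module.weight.device
        self.scalar = torch.tensor(1.0, device=dev)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(x)) * self.scalar

lora = TinyLoRA(nn.Linear(16, 16))
print(lora.scalar.device)  # matches the wrapped module's device
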
@@ -275,7 +275,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         unet_prefix = self.LORA_PREFIX_UNET
         if self.peft_format:
             unet_prefix = self.PEFT_PREFIX_UNET
-        if is_pixart or is_v3 or is_auraflow or is_flux:
+        if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2:
             unet_prefix = f"lora_transformer"
             if self.peft_format:
                 unet_prefix = "transformer"
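
The hunk above routes Lumina2 through the same transformer prefix as the other DiT-style models. A rough illustration of how the chosen prefix shapes saved LoRA key names (the exact key-munging is an assumption for illustration, not ai-toolkit's literal code):

# Assumed key formats, for illustration only:
module_name = "layers.0.attn.to_q"

# legacy kohya-style naming: prefix plus underscored module path
legacy_key = f"lora_transformer_{module_name.replace('.', '_')}"

# peft-style naming: dotted path under a plain "transformer" prefix
peft_key = f"transformer.{module_name}"

print(legacy_key)  # lora_transformer_layers_0_attn_to_q
print(peft_key)    # transformer.layers.0.attn.to_q
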
@@ -28,10 +28,13 @@ from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rota
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
 from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
 
+import torch
+from torch.profiler import profile, record_function, ProfilerActivity
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+do_profile = False
+
 
 class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
     def __init__(
@@ -472,7 +475,18 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         attention_mask: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+
         batch_size = hidden_states.size(0)
 
+        if do_profile:
+            prof = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+            )
+
+            prof.start()
+
         # 1. Condition, positional & patch embedding
         temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
@@ -534,6 +548,13 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             )
         output = torch.stack(output, dim=0)
 
+        if do_profile:
+            torch.cuda.synchronize()  # Make sure all CUDA ops are done
+            prof.stop()
+
+            print("\n==== Profile Results ====")
+            print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
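
The three hunks above thread a manual prof.start()/prof.stop() pair through the Lumina2 forward pass, gated by the module-level do_profile flag. A standalone sketch of the same measurement using torch.profiler's context-manager form, which handles start/stop and is exception-safe (toy model and sizes are hypothetical):

import torch
from torch.profiler import profile, ProfilerActivity

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(64, 64).to(device)
x = torch.randn(8, 64, device=device)

# only request CUDA activity when a GPU is actually present
activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    with torch.no_grad():
        model(x)
    if device == "cuda":
        torch.cuda.synchronize()  # flush pending kernels so their times land in the profile

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
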
@@ -914,6 +914,7 @@ class StableDiffusion:
         if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2:
             # pixart and sd3 dont use a unet
             self.unet = pipe.transformer
+            self.unet_unwrapped = pipe.transformer
         else:
             self.unet: 'UNet2DConditionModel' = pipe.unet
         self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
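
The added self.unet_unwrapped line keeps a second handle to the bare transformer. A plausible reading (an assumption, the diff itself does not say) is that once the training framework wraps self.unet, the unwrapped reference still exposes the raw module's attributes. A toy sketch of that aliasing with hypothetical names:

import torch.nn as nn

# Toy sketch (hypothetical holder class): both attributes start out pointing
# at the same module; later wrapping replaces only self.unet.
class Holder:
    def __init__(self, transformer: nn.Module):
        self.unet = transformer            # may be swapped for a wrapped version
        self.unet_unwrapped = transformer  # stable handle to the bare module

h = Holder(nn.Linear(4, 4))
h.unet = nn.DataParallel(h.unet)  # stand-in for a framework wrapper
print(type(h.unet).__name__, type(h.unet_unwrapped).__name__)
# DataParallel Linear
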
@@ -2048,13 +2049,14 @@ class StableDiffusion:
         elif self.is_lumina2:
             # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
             t = 1 - timestep / self.noise_scheduler.config.num_train_timesteps
-            noise_pred = self.unet(
-                hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
-                timestep=t,
-                attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
-                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
-                **kwargs,
-            ).sample
+            with self.accelerator.autocast():
+                noise_pred = self.unet(
+                    hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
+                    timestep=t,
+                    attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
+                    encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
+                    **kwargs,
+                ).sample
 
             # lumina2 does this before stepping. Should we do it here?
             noise_pred = -noise_pred
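
A worked example of the timestep remap in this hunk: diffusers schedulers count timesteps down from num_train_timesteps (pure noise) toward 0 (clean image), while Lumina2 expects t=0 at the noise end and t=1 at the image end, hence t = 1 - timestep / num_train_timesteps. The trailing negation flips the prediction to match the scheduler's expected direction, per the comment in the diff.

import torch

# Scheduler timesteps run 1000 -> 0; Lumina2 wants 0 -> 1 over the same path.
num_train_timesteps = 1000
timestep = torch.tensor([1000.0, 750.0, 500.0, 0.0])

t = 1 - timestep / num_train_timesteps
print(t)  # tensor([0.0000, 0.2500, 0.5000, 1.0000])

# convention flip applied before the scheduler step, mirroring the diff's
# "noise_pred = -noise_pred"
noise_pred = torch.randn(2, 4)
noise_pred = -noise_pred
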