Wokr on lumina2

This commit is contained in:
Jaret Burkett
2025-02-08 14:52:39 -07:00
parent d138f07365
commit 9a7266275d
3 changed files with 34 additions and 11 deletions

View File

@@ -28,10 +28,13 @@ from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rota
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
import torch
from torch.profiler import profile, record_function, ProfilerActivity
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
do_profile = False
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
def __init__(
@@ -472,7 +475,18 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
attention_mask: torch.Tensor,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
batch_size = hidden_states.size(0)
if do_profile:
prof = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
)
prof.start()
# 1. Condition, positional & patch embedding
temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
@@ -534,6 +548,13 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
output = torch.stack(output, dim=0)
if do_profile:
torch.cuda.synchronize() # Make sure all CUDA ops are done
prof.stop()
print("\n==== Profile Results ====")
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
return Transformer2DModelOutput(sample=output)