mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Work on lumina2
This commit is contained in:
@@ -28,10 +28,13 @@ from diffusers.models.embeddings import TimestepEmbedding, Timesteps, apply_rota
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
|
||||
|
||||
import torch
|
||||
from torch.profiler import profile, record_function, ProfilerActivity
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
do_profile = False
|
||||
|
||||
|
||||
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
|
||||
def __init__(
|
||||
@@ -472,7 +475,18 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
attention_mask: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
if do_profile:
|
||||
prof = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
)
|
||||
|
||||
prof.start()
|
||||
|
||||
# 1. Condition, positional & patch embedding
|
||||
temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
|
||||
@@ -534,6 +548,13 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
)
|
||||
output = torch.stack(output, dim=0)
|
||||
|
||||
if do_profile:
|
||||
torch.cuda.synchronize() # Make sure all CUDA ops are done
|
||||
prof.stop()
|
||||
|
||||
print("\n==== Profile Results ====")
|
||||
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000))
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
Reference in New Issue
Block a user