Fixed an issue training lumina 2

This commit is contained in:
Jaret Burkett
2025-06-24 10:29:47 -06:00
parent f3eb1dff42
commit 03bc431279
2 changed files with 2 additions and 570 deletions

View File

@@ -50,8 +50,7 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline, \
FluxControlPipeline
from toolkit.models.lumina2 import Lumina2Transformer2DModel
FluxControlPipeline, Lumina2Transformer2DModel
import diffusers
from diffusers import \
AutoencoderKL, \
@@ -2179,7 +2178,7 @@ class StableDiffusion:
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
timestep=t,
attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
encoder_attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample