mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 17:51:41 +00:00
Fixed an issue training lumina 2
This commit is contained in:
@@ -50,8 +50,7 @@ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAda
|
||||
StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
|
||||
StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
|
||||
FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, Lumina2Text2ImgPipeline, \
|
||||
FluxControlPipeline
|
||||
from toolkit.models.lumina2 import Lumina2Transformer2DModel
|
||||
FluxControlPipeline, Lumina2Transformer2DModel
|
||||
import diffusers
|
||||
from diffusers import \
|
||||
AutoencoderKL, \
|
||||
@@ -2179,7 +2178,7 @@ class StableDiffusion:
|
||||
noise_pred = self.unet(
|
||||
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||
timestep=t,
|
||||
attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
|
||||
encoder_attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
|
||||
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
|
||||
**kwargs,
|
||||
).sample
|
||||
|
||||
Reference in New Issue
Block a user