mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Work on lumina2
This commit is contained in:
@@ -914,6 +914,7 @@ class StableDiffusion:
|
||||
if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2:
|
||||
# pixart and sd3 dont use a unet
|
||||
self.unet = pipe.transformer
|
||||
self.unet_unwrapped = pipe.transformer
|
||||
else:
|
||||
self.unet: 'UNet2DConditionModel' = pipe.unet
|
||||
self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
|
||||
@@ -2048,13 +2049,14 @@ class StableDiffusion:
|
||||
elif self.is_lumina2:
|
||||
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
|
||||
t = 1 - timestep / self.noise_scheduler.config.num_train_timesteps
|
||||
noise_pred = self.unet(
|
||||
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||
timestep=t,
|
||||
attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
|
||||
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
|
||||
**kwargs,
|
||||
).sample
|
||||
with self.accelerator.autocast():
|
||||
noise_pred = self.unet(
|
||||
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||
timestep=t,
|
||||
attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
|
||||
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
|
||||
**kwargs,
|
||||
).sample
|
||||
|
||||
# lumina2 does this before stepping. Should we do it here?
|
||||
noise_pred = -noise_pred
|
||||
|
||||
Reference in New Issue
Block a user