Work on lumina2

This commit is contained in:
Jaret Burkett
2025-02-08 14:52:39 -07:00
parent d138f07365
commit 9a7266275d
3 changed files with 34 additions and 11 deletions

View File

@@ -914,6 +914,7 @@ class StableDiffusion:
if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2:
# pixart, sd3, auraflow, flux, and lumina2 use a transformer instead of a unet
self.unet = pipe.transformer
self.unet_unwrapped = pipe.transformer
else:
self.unet: 'UNet2DConditionModel' = pipe.unet
self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
@@ -2048,13 +2049,14 @@ class StableDiffusion:
elif self.is_lumina2:
# reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
t = 1 - timestep / self.noise_scheduler.config.num_train_timesteps
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
timestep=t,
attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample
with self.accelerator.autocast():
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
timestep=t,
attention_mask=text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64),
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample
# lumina2 does this before stepping. Should we do it here?
noise_pred = -noise_pred