Partial implementation for training AuraFlow.

Jaret Burkett
2024-07-12 12:11:38 -06:00
parent c062b7716c
commit e4558dff4b
9 changed files with 386 additions and 19 deletions


@@ -883,6 +883,26 @@ class SDTrainer(BaseSDTrainProcess):
     def end_of_training_loop(self):
         pass
 
+    def predict_noise(
+            self,
+            noisy_latents: torch.Tensor,
+            timesteps: Union[int, torch.Tensor] = 1,
+            conditional_embeds: Union[PromptEmbeds, None] = None,
+            unconditional_embeds: Union[PromptEmbeds, None] = None,
+            **kwargs,
+    ):
+        dtype = get_torch_dtype(self.train_config.dtype)
+        return self.sd.predict_noise(
+            latents=noisy_latents.to(self.device_torch, dtype=dtype),
+            conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
+            unconditional_embeddings=unconditional_embeds,
+            timestep=timesteps,
+            guidance_scale=self.train_config.cfg_scale,
+            detach_unconditional=False,
+            rescale_cfg=self.train_config.cfg_rescale,
+            **kwargs
+        )
+
     def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
         self.timer.start('preprocess_batch')
         batch = self.preprocess_batch(batch)
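
The net effect of this hunk is a template-method hook: noise prediction moves out of the training loop and behind an overridable instance method, so a model-specific trainer (e.g., one for AuraFlow) can supply its own forward pass without touching hook_train_loop. A minimal, self-contained sketch of that pattern, assuming simplified signatures; BaseTrainer and AuraFlowTrainer are illustrative names, not classes from this commit:

import torch

class BaseTrainer:
    # Illustrative stand-in for SDTrainer: the loop calls self.predict_noise
    # rather than the diffusion backend directly.
    def predict_noise(self, noisy_latents: torch.Tensor, timesteps=1, **kwargs) -> torch.Tensor:
        # Default path; the real method delegates to self.sd.predict_noise.
        return noisy_latents

    def train_step(self, noisy_latents: torch.Tensor) -> torch.Tensor:
        # The loop stays model-agnostic: it only knows the hook's signature.
        return self.predict_noise(noisy_latents, timesteps=1)

class AuraFlowTrainer(BaseTrainer):
    # A subclass overrides the hook with model-specific prediction logic
    # while the shared training loop stays unchanged.
    def predict_noise(self, noisy_latents: torch.Tensor, timesteps=1, **kwargs) -> torch.Tensor:
        return super().predict_noise(noisy_latents, timesteps, **kwargs)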
@@ -1453,14 +1473,11 @@ class SDTrainer(BaseSDTrainProcess):
         with self.timer('predict_unet'):
             if unconditional_embeds is not None:
                 unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
-            noise_pred = self.sd.predict_noise(
-                latents=noisy_latents.to(self.device_torch, dtype=dtype),
-                conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
-                unconditional_embeddings=unconditional_embeds,
-                timestep=timesteps,
-                guidance_scale=self.train_config.cfg_scale,
-                detach_unconditional=False,
-                rescale_cfg=self.train_config.cfg_rescale,
+            noise_pred = self.predict_noise(
+                noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
+                timesteps=timesteps,
+                conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
+                unconditional_embeds=unconditional_embeds,
                 **pred_kwargs
             )
         self.after_unet_predict()
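
Both hunks rely on the embeddings object exposing a tensor-style .to(device, dtype) method. A rough sketch of that contract, assuming a single wrapped text-embedding tensor; PromptEmbedsSketch is a hypothetical stand-in, and the real PromptEmbeds in this repo may carry additional fields:

import torch

class PromptEmbedsSketch:
    # Hypothetical stand-in for PromptEmbeds; only the .to() contract used
    # at the call sites above is shown.
    def __init__(self, text_embeds: torch.Tensor):
        self.text_embeds = text_embeds

    def to(self, device, dtype=None):
        # Mirror torch.Tensor.to: move/cast the wrapped embeddings and
        # return self so calls can be chained.
        self.text_embeds = self.text_embeds.to(device, dtype=dtype)
        return self

embeds = PromptEmbedsSketch(torch.randn(1, 77, 2048))
embeds = embeds.to("cpu", dtype=torch.float32)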