mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Partial implementation for training auraflow.
This commit is contained in:
@@ -883,6 +883,26 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
def end_of_training_loop(self):
|
||||
pass
|
||||
|
||||
def predict_noise(
|
||||
self,
|
||||
noisy_latents: torch.Tensor,
|
||||
timesteps: Union[int, torch.Tensor] = 1,
|
||||
conditional_embeds: Union[PromptEmbeds, None] = None,
|
||||
unconditional_embeds: Union[PromptEmbeds, None] = None,
|
||||
**kwargs,
|
||||
):
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
return self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
unconditional_embeddings=unconditional_embeds,
|
||||
timestep=timesteps,
|
||||
guidance_scale=self.train_config.cfg_scale,
|
||||
detach_unconditional=False,
|
||||
rescale_cfg=self.train_config.cfg_rescale,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
|
||||
self.timer.start('preprocess_batch')
|
||||
batch = self.preprocess_batch(batch)
|
||||
@@ -1453,14 +1473,11 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
with self.timer('predict_unet'):
|
||||
if unconditional_embeds is not None:
|
||||
unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
|
||||
noise_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
unconditional_embeddings=unconditional_embeds,
|
||||
timestep=timesteps,
|
||||
guidance_scale=self.train_config.cfg_scale,
|
||||
detach_unconditional=False,
|
||||
rescale_cfg=self.train_config.cfg_rescale,
|
||||
noise_pred = self.predict_noise(
|
||||
noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
|
||||
timesteps=timesteps,
|
||||
conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
|
||||
unconditional_embeds=unconditional_embeds,
|
||||
**pred_kwargs
|
||||
)
|
||||
self.after_unet_predict()
|
||||
|
||||
Reference in New Issue
Block a user