Added experimental modified sigma sqrt weight mapping for linear timestep scheduling for flowmatching

This commit is contained in:
Jaret Burkett
2024-08-12 17:03:09 -06:00
parent 599fafe01f
commit 9ee1ef2a0a
2 changed files with 47 additions and 1 deletions

View File

@@ -375,12 +375,18 @@ class SDTrainer(BaseSDTrainProcess):
loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
loss = loss_per_element
else:
# handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
if self.train_config.loss_type == "mae":
loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
else:
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
# handle linear timesteps and only adjust the weight of the timesteps
if self.sd.is_flow_matching and self.train_config.linear_timesteps:
# calculate the weights for the timesteps
timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype)
loss = loss * timestep_weight
if self.train_config.do_prior_divergence and prior_pred is not None:
loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0)