def __init__(self, *args, **kwargs):
    """Initialise the scheduler and precompute per-timestep loss weights.

    All positional and keyword arguments are forwarded unchanged to the
    parent constructor. Two attributes are then attached to the instance:

    * ``linear_timesteps`` -- 1000 evenly spaced values running from 1000
      down to 0, created on the CPU.
    * ``linear_timesteps_weights`` -- ``sigma ** -2`` for every entry of
      ``self.sigmas``, clamped to at most 1e4 and rescaled to a mean of 1.

    NOTE(review): relies on the parent constructor having populated
    ``self.sigmas`` -- confirm against the base scheduler.
    """
    super().__init__(*args, **kwargs)

    with torch.no_grad():
        # Number of entries in the linear timestep table.
        num_timesteps = 1000

        # The weights are only consumed for linear timesteps for now.
        #
        # Alternative considered (currently disabled): cosine-map weighting,
        #   2 / (pi * (1 - 2 * sigma + 2 * sigma ** 2)),
        # which is higher in the middle of the schedule and lower at the ends.

        # Inverse-square sigma weighting: the weight grows as sigma shrinks.
        inverse_sq_sigma_weights = (self.sigmas ** -2.0).float()
        # Clip at 1e4 (1e6 is too high).
        inverse_sq_sigma_weights = torch.clamp(inverse_sq_sigma_weights, max=1e4)
        # Normalise so the weights average to 1.
        inverse_sq_sigma_weights = inverse_sq_sigma_weights / inverse_sq_sigma_weights.mean()

        # Linear timesteps from 1000 down to 0.
        self.linear_timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
        self.linear_timesteps_weights = inverse_sq_sigma_weights


def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
    """Return the precomputed loss weight for each requested timestep.

    Each value in ``timesteps`` is looked up by exact equality in
    ``self.timesteps``; the position of the match selects the entry of
    ``self.linear_timesteps_weights`` to return.

    Args:
        timesteps: Timestep values to look up. Any shape is accepted,
            including a 0-d tensor; the values are flattened first.

    Returns:
        A 1-D tensor holding one weight per requested timestep, in the
        order the timesteps were given.

    Raises:
        RuntimeError or ValueError (depending on the torch version): raised
            by ``Tensor.item()`` when a timestep does not match exactly one
            entry of ``self.timesteps``.
    """
    schedule = self.timesteps
    # Compare on the schedule's own device: the incoming batch may live on
    # an accelerator while the lookup table does not. Flattening also lets
    # a scalar (0-d) timestep through, which cannot be iterated directly.
    requested = timesteps.to(schedule.device).reshape(-1)

    # Position of every requested timestep inside the schedule.
    step_indices = [(schedule == t).nonzero().item() for t in requested]

    # Pick the matching weights and hand them back as a flat vector.
    return self.linear_timesteps_weights[step_indices].flatten()