Added new timestep weighing strategy

This commit is contained in:
Jaret Burkett
2025-06-04 01:16:02 -06:00
parent adc31ec77d
commit 22cdfadab6
8 changed files with 1348 additions and 9 deletions

View File

@@ -501,13 +501,22 @@ class SDTrainer(BaseSDTrainProcess):
loss = wavelet_loss(pred, batch.latents, noise)
else:
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
do_weighted_timesteps = False
if self.sd.is_flow_matching:
if self.train_config.linear_timesteps or self.train_config.linear_timesteps2:
do_weighted_timesteps = True
if self.train_config.timestep_type == "weighted":
# use the noise scheduler to get the weights for the timesteps
do_weighted_timesteps = True
# handle linear timesteps and only adjust the weight of the timesteps
if self.sd.is_flow_matching and (self.train_config.linear_timesteps or self.train_config.linear_timesteps2):
if do_weighted_timesteps:
# calculate the weights for the timesteps
timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(
timesteps,
v2=self.train_config.linear_timesteps2
v2=self.train_config.linear_timesteps2,
timestep_type=self.train_config.timestep_type
).to(loss.device, dtype=loss.dtype)
timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
loss = loss * timestep_weight