mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added a new experimental linear weighing technique
This commit is contained in:
@@ -390,9 +390,12 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
||||
|
||||
# handle linear timesteps and only adjust the weight of the timesteps
|
||||
if self.sd.is_flow_matching and self.train_config.linear_timesteps:
|
||||
if self.sd.is_flow_matching and (self.train_config.linear_timesteps or self.train_config.linear_timesteps2):
|
||||
# calculate the weights for the timesteps
|
||||
timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype)
|
||||
timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(
|
||||
timesteps,
|
||||
v2=self.train_config.linear_timesteps2
|
||||
).to(loss.device, dtype=loss.dtype)
|
||||
timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
|
||||
loss = loss * timestep_weight
|
||||
|
||||
|
||||
Reference in New Issue
Block a user