fixed issue with batch sizes larger than 1

This commit is contained in:
Jaret Burkett
2024-08-15 12:21:38 -06:00
parent 9001e5c933
commit b99d36dfdb

View File

@@ -382,9 +382,10 @@ class SDTrainer(BaseSDTrainProcess):
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
# handle linear timesteps and only adjust the weight of the timesteps
if self.sd.is_flow_matching and self.train_config.linear_timesteps:
if self.sd.is_flow_matching and self.train_config.linear_timesteps:
# calculate the weights for the timesteps
timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype)
timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
loss = loss * timestep_weight
if self.train_config.do_prior_divergence and prior_pred is not None: