From b99d36dfdb4504ff85ef7d92374aa158747f841b Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 15 Aug 2024 12:21:38 -0600 Subject: [PATCH] fixed issue with batch sizes larger than 1 --- extensions_built_in/sd_trainer/SDTrainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 7d9f9471..609ff0dd 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -382,9 +382,10 @@ class SDTrainer(BaseSDTrainProcess): loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") # handle linear timesteps and only adjust the weight of the timesteps - if self.sd.is_flow_matching and self.train_config.linear_timesteps: + if self.sd.is_flow_matching and self.train_config.linear_timesteps: # calculate the weights for the timesteps timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype) + timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach() loss = loss * timestep_weight if self.train_config.do_prior_divergence and prior_pred is not None: