diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 602bbeaf..09156909 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1096,6 +1096,8 @@ class SDTrainer(BaseSDTrainProcess): # expand to match latents mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1) mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach() + # make avg 1.0 + mask_multiplier = mask_multiplier / mask_multiplier.mean() def get_adapter_multiplier(): if self.adapter and isinstance(self.adapter, T2IAdapter):