From ff3d54bb5b46c38523f2480d1408add5d6c60c9b Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 6 Feb 2025 10:57:06 +0000 Subject: [PATCH] Make the mean of the mask multiplier be 1.0 for a more balanced loss. --- extensions_built_in/sd_trainer/SDTrainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 602bbeaf..09156909 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1096,6 +1096,8 @@ class SDTrainer(BaseSDTrainProcess): # expand to match latents mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1) mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach() + # make avg 1.0 + mask_multiplier = mask_multiplier / mask_multiplier.mean() def get_adapter_multiplier(): if self.adapter and isinstance(self.adapter, T2IAdapter):