Make the mean of the mask multiplier be 1.0 for a more balanced loss.

This commit is contained in:
Jaret Burkett
2025-02-06 10:57:06 +00:00
parent 0e75724b4d
commit ff3d54bb5b

View File

@@ -1096,6 +1096,8 @@ class SDTrainer(BaseSDTrainProcess):
# expand to match latents
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
# make avg 1.0
mask_multiplier = mask_multiplier / mask_multiplier.mean()
def get_adapter_multiplier():
if self.adapter and isinstance(self.adapter, T2IAdapter):