mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Make the mean of the mask multiplier be 1.0 for a more balanced loss.
This commit is contained in:
@@ -1096,6 +1096,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# expand to match latents
|
||||
mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
|
||||
mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()
|
||||
# make avg 1.0
|
||||
mask_multiplier = mask_multiplier / mask_multiplier.mean()
|
||||
|
||||
def get_adapter_multiplier():
|
||||
if self.adapter and isinstance(self.adapter, T2IAdapter):
|
||||
|
||||
Reference in New Issue
Block a user