From 6ec23ed226bd72d491fecb9a2df2ce6956a6e9f9 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 26 Feb 2025 12:12:32 -0700 Subject: [PATCH] Fixed issue when doing inverted masked prior with flowmatching algos --- extensions_built_in/sd_trainer/SDTrainer.py | 30 +++++++++------------ 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 5e91c344..e22f225c 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -359,24 +359,20 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask: assert not self.train_config.train_turbo with torch.no_grad(): - # we need to make the noise prediction be a masked blending of noise and prior_pred - stretched_mask_multiplier = value_map( - mask_multiplier, - batch.file_items[0].dataset_config.mask_min_value, - 1.0, - 0.0, - 1.0 - ) + prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype) + # resize to size of noise_pred + prior_mask = torch.nn.functional.interpolate(prior_mask, size=(noise_pred.shape[2], noise_pred.shape[3]), mode='bicubic') + # stack first channel to match channels of noise_pred + prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1) - prior_mask_multiplier = 1.0 - stretched_mask_multiplier - - - # target_mask_multiplier = mask_multiplier - # mask_multiplier = 1.0 - target = noise - # target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier) - # set masked multiplier to 1.0 so we dont double apply it - # mask_multiplier = 1.0 + prior_mask_multiplier = 1.0 - prior_mask + + # scale so it is a mean of 1 + prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean() + if self.sd.is_flow_matching: + target = (noise - batch.latents).detach() + else: + target = noise elif prior_pred is not None and not self.train_config.do_prior_divergence: assert not self.train_config.train_turbo # matching adapter prediction