Fixed an issue when using the inverted mask prior with flow-matching algorithms

Jaret Burkett
2025-02-26 12:12:32 -07:00
parent f6e16e582a
commit 6ec23ed226

@@ -359,24 +359,20 @@ class SDTrainer(BaseSDTrainProcess):
                 if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
                     assert not self.train_config.train_turbo
                     with torch.no_grad():
-                        # we need to make the noise prediction be a masked blending of noise and prior_pred
-                        stretched_mask_multiplier = value_map(
-                            mask_multiplier,
-                            batch.file_items[0].dataset_config.mask_min_value,
-                            1.0,
-                            0.0,
-                            1.0
-                        )
+                        prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype)
+                        # resize to size of noise_pred
+                        prior_mask = torch.nn.functional.interpolate(prior_mask, size=(noise_pred.shape[2], noise_pred.shape[3]), mode='bicubic')
+                        # stack first channel to match channels of noise_pred
+                        prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1)
 
-                        prior_mask_multiplier = 1.0 - stretched_mask_multiplier
-                        # target_mask_multiplier = mask_multiplier
-                        # mask_multiplier = 1.0
-                        target = noise
-                        # target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
-                        # set masked multiplier to 1.0 so we dont double apply it
-                        # mask_multiplier = 1.0
+                        prior_mask_multiplier = 1.0 - prior_mask
+                        # scale so it is a mean of 1
+                        prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean()
+                        if self.sd.is_flow_matching:
+                            target = (noise - batch.latents).detach()
+                        else:
+                            target = noise
                 elif prior_pred is not None and not self.train_config.do_prior_divergence:
                     assert not self.train_config.train_turbo
                     # matching adapter prediction
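
For reference, the sketch below illustrates the logic the new lines implement: flow-matching models are trained to regress the velocity (noise minus the clean latents), whereas epsilon-prediction diffusion models regress the noise itself, so the prior target has to be built differently in each case, and the inverted mask is rescaled to a mean of 1 so the weighted loss keeps a comparable magnitude. This is a minimal, self-contained sketch, not ai-toolkit's actual API; the helper inverted_prior_target_and_weight and its argument names are hypothetical.

# Illustrative sketch only: shows why the regression target differs between
# flow-matching and epsilon-prediction diffusion, and how an inverted prior
# mask can be mean-normalized. Names here are hypothetical, not ai-toolkit's API.
import torch
import torch.nn.functional as F


def inverted_prior_target_and_weight(
    noise: torch.Tensor,    # sampled gaussian noise, shape (B, C, H, W)
    latents: torch.Tensor,  # clean latents, same shape as noise
    mask: torch.Tensor,     # single-channel mask in [0, 1], shape (B, 1, h, w)
    is_flow_matching: bool,
):
    with torch.no_grad():
        # resize the mask to the prediction's spatial size and repeat it
        # across channels so it can weight the per-element loss
        mask = F.interpolate(mask, size=noise.shape[2:], mode="bicubic")
        mask = mask.clamp(0.0, 1.0).repeat(1, noise.shape[1], 1, 1)

        # inverted prior weighting: emphasize the *unmasked* region,
        # rescaled to a mean of 1 so the overall loss magnitude is unchanged
        prior_weight = 1.0 - mask
        prior_weight = prior_weight / prior_weight.mean().clamp(min=1e-8)

        if is_flow_matching:
            # flow-matching models regress the velocity (noise - x0),
            # so the prior target must be built from that quantity
            target = (noise - latents).detach()
        else:
            # epsilon-prediction diffusion regresses the noise itself
            target = noise
    return target, prior_weight


if __name__ == "__main__":
    b, c, h, w = 2, 4, 32, 32
    noise = torch.randn(b, c, h, w)
    latents = torch.randn(b, c, h, w)
    mask = torch.rand(b, 1, 16, 16)
    target, weight = inverted_prior_target_and_weight(noise, latents, mask, is_flow_matching=True)
    print(target.shape, weight.mean().item())  # weight mean is ~1.0

Dividing the inverted mask by its mean keeps the weighted loss on roughly the same scale regardless of how much of the image the mask covers, which is why the multiplier is normalized before it is applied.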