Fixed issue when doing inverted masked prior with flow-matching algos

Author: Jaret Burkett
Date:   2025-02-26 12:12:32 -07:00
Parent: f6e16e582a
Commit: 6ec23ed226


@@ -359,24 +359,20 @@ class SDTrainer(BaseSDTrainProcess):
         if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
             assert not self.train_config.train_turbo
             with torch.no_grad():
-                # we need to make the noise prediction be a masked blending of noise and prior_pred
-                stretched_mask_multiplier = value_map(
-                    mask_multiplier,
-                    batch.file_items[0].dataset_config.mask_min_value,
-                    1.0,
-                    0.0,
-                    1.0
-                )
-                prior_mask_multiplier = 1.0 - stretched_mask_multiplier
-                # target_mask_multiplier = mask_multiplier
-                # mask_multiplier = 1.0
-                target = noise
-                # target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
-                # set masked multiplier to 1.0 so we dont double apply it
-                # mask_multiplier = 1.0
+                prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype)
+                # resize to size of noise_pred
+                prior_mask = torch.nn.functional.interpolate(prior_mask, size=(noise_pred.shape[2], noise_pred.shape[3]), mode='bicubic')
+                # stack first channel to match channels of noise_pred
+                prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1)
+                prior_mask_multiplier = 1.0 - prior_mask
+                # scale so it is a mean of 1
+                prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean()
+                if self.sd.is_flow_matching:
+                    target = (noise - batch.latents).detach()
+                else:
+                    target = noise
         elif prior_pred is not None and not self.train_config.do_prior_divergence:
             assert not self.train_config.train_turbo
             # matching adapter prediction
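
For context, the new branch builds an inverted, mean-normalized mask multiplier and picks the regression target to match the model's parameterization: flow-matching models predict a velocity, so the target becomes (noise - latents) rather than the raw noise used for standard diffusion. Below is a minimal, self-contained sketch of that logic; the standalone function name, its signature, and the is_flow_matching flag are assumptions for illustration, not the trainer's API.

import torch

def inverted_mask_prior_target(noise, latents, mask_tensor, noise_pred, is_flow_matching):
    # Illustrative sketch only; names and signature are assumptions, not the trainer's API.
    with torch.no_grad():
        # Resize the mask to the spatial size of the prediction and repeat its first
        # channel so it matches the prediction's channel count.
        prior_mask = torch.nn.functional.interpolate(
            mask_tensor, size=(noise_pred.shape[2], noise_pred.shape[3]), mode='bicubic'
        )
        prior_mask = torch.cat([prior_mask[:, :1]] * noise_pred.shape[1], dim=1)

        # Invert the mask and rescale it to a mean of 1 so the overall loss
        # magnitude stays comparable; only the spatial weighting changes.
        prior_mask_multiplier = 1.0 - prior_mask
        prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean()

        # Flow-matching models regress a velocity (noise - latents); diffusion
        # models regress the noise itself. Using plain noise as the target for a
        # flow-matching model was the mismatch this commit fixes.
        if is_flow_matching:
            target = (noise - latents).detach()
        else:
            target = noise
    return target, prior_mask_multiplier

Normalizing the inverted mask to a mean of 1 means enabling the prior does not shrink or inflate the average loss weight, so training dynamics stay comparable with and without a mask.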