mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fixed issue when doing inverted masked prior with flowmatching algos
This commit is contained in:
@@ -359,24 +359,20 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
|
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
|
||||||
assert not self.train_config.train_turbo
|
assert not self.train_config.train_turbo
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# we need to make the noise prediction be a masked blending of noise and prior_pred
|
prior_mask = batch.mask_tensor.to(self.device_torch, dtype=dtype)
|
||||||
stretched_mask_multiplier = value_map(
|
# resize to size of noise_pred
|
||||||
mask_multiplier,
|
prior_mask = torch.nn.functional.interpolate(prior_mask, size=(noise_pred.shape[2], noise_pred.shape[3]), mode='bicubic')
|
||||||
batch.file_items[0].dataset_config.mask_min_value,
|
# stack first channel to match channels of noise_pred
|
||||||
1.0,
|
prior_mask = torch.cat([prior_mask[:1]] * noise_pred.shape[1], dim=1)
|
||||||
0.0,
|
|
||||||
1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
prior_mask_multiplier = 1.0 - stretched_mask_multiplier
|
prior_mask_multiplier = 1.0 - prior_mask
|
||||||
|
|
||||||
|
# scale so it is a mean of 1
|
||||||
# target_mask_multiplier = mask_multiplier
|
prior_mask_multiplier = prior_mask_multiplier / prior_mask_multiplier.mean()
|
||||||
# mask_multiplier = 1.0
|
if self.sd.is_flow_matching:
|
||||||
target = noise
|
target = (noise - batch.latents).detach()
|
||||||
# target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
|
else:
|
||||||
# set masked multiplier to 1.0 so we dont double apply it
|
target = noise
|
||||||
# mask_multiplier = 1.0
|
|
||||||
elif prior_pred is not None and not self.train_config.do_prior_divergence:
|
elif prior_pred is not None and not self.train_config.do_prior_divergence:
|
||||||
assert not self.train_config.train_turbo
|
assert not self.train_config.train_turbo
|
||||||
# matching adapter prediction
|
# matching adapter prediction
|
||||||
|
|||||||
Reference in New Issue
Block a user