Allow for inverted masked prior

This commit is contained in:
Jaret Burkett
2023-10-21 06:50:17 -06:00
parent d46112a354
commit 0e9fc42816
2 changed files with 24 additions and 11 deletions

View File

@@ -64,14 +64,20 @@ class SDTrainer(BaseSDTrainProcess):
timesteps: torch.Tensor,
batch: 'DataLoaderBatchDTO',
mask_multiplier: Union[torch.Tensor, float] = 1.0,
control_pred: Union[torch.Tensor, None] = None,
prior_pred: Union[torch.Tensor, None] = None,
**kwargs
):
loss_target = self.train_config.loss_target
# add latents and unaug latents
if control_pred is not None:
if self.train_config.inverted_mask_prior:
# we need to make the noise prediction be a masked blending of noise and prior_pred
prior_multiplier = 1.0 - mask_multiplier
target = (noise * mask_multiplier) + (prior_pred * prior_multiplier)
# set masked multiplier to 1.0 so we dont double apply it
mask_multiplier = 1.0
elif prior_pred is not None:
# matching adapter prediction
target = control_pred
target = prior_pred
elif self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
@@ -280,15 +286,15 @@ class SDTrainer(BaseSDTrainProcess):
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
control_pred = None
if has_adapter_img and self.assistant_adapter and match_adapter_assist:
with self.timer('predict_with_adapter'):
prior_pred = None
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.train_config.inverted_mask_prior:
with self.timer('prior predict'):
# do a prediction here so we can match its output with network multiplier set to 0.0
with torch.no_grad():
# dont use network on this
network.multiplier = 0.0
self.sd.unet.eval()
control_pred = self.sd.predict_noise(
prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
@@ -296,12 +302,14 @@ class SDTrainer(BaseSDTrainProcess):
**pred_kwargs # adapter residuals in here
)
self.sd.unet.train()
control_pred = control_pred.detach()
prior_pred = prior_pred.detach()
# remove the residuals as we wont use them on prediction when matching control
del pred_kwargs['down_block_additional_residuals']
if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
del pred_kwargs['down_block_additional_residuals']
# restore network
network.multiplier = network_weight_list
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter'):
with torch.no_grad():
@@ -326,7 +334,7 @@ class SDTrainer(BaseSDTrainProcess):
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
control_pred=control_pred,
prior_pred=prior_pred,
)
# check if nan
if torch.isnan(loss):

View File

@@ -129,6 +129,11 @@ class TrainConfig:
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented, differential_noise
# When a mask is passed in a dataset, and this is true,
# we will predict noise without a the LoRa network and use the prediction as a target for
# unmasked reign. It is unmasked regularization basically
self.inverted_mask_prior = kwargs.get('inverted_mask_prior', False)
# legacy
if match_adapter_assist and self.match_adapter_chance == 0.0:
self.match_adapter_chance = 1.0