From 0e9fc42816d26e2d5b31859cc7c257800178ed93 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 21 Oct 2023 06:50:17 -0600 Subject: [PATCH] Allow for inverted masked prior --- extensions_built_in/sd_trainer/SDTrainer.py | 30 +++++++++++++-------- toolkit/config_modules.py | 5 ++++ 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 02d01ec1..ca608b6e 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -64,14 +64,20 @@ class SDTrainer(BaseSDTrainProcess): timesteps: torch.Tensor, batch: 'DataLoaderBatchDTO', mask_multiplier: Union[torch.Tensor, float] = 1.0, - control_pred: Union[torch.Tensor, None] = None, + prior_pred: Union[torch.Tensor, None] = None, **kwargs ): loss_target = self.train_config.loss_target - # add latents and unaug latents - if control_pred is not None: + + if self.train_config.inverted_mask_prior: + # we need to make the noise prediction be a masked blending of noise and prior_pred + prior_multiplier = 1.0 - mask_multiplier + target = (noise * mask_multiplier) + (prior_pred * prior_multiplier) + # set masked multiplier to 1.0 so we dont double apply it + mask_multiplier = 1.0 + elif prior_pred is not None: # matching adapter prediction - target = control_pred + target = prior_pred elif self.sd.prediction_type == 'v_prediction': # v-parameterization training target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) @@ -280,15 +286,15 @@ class SDTrainer(BaseSDTrainProcess): pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals - control_pred = None - if has_adapter_img and self.assistant_adapter and match_adapter_assist: - with self.timer('predict_with_adapter'): + prior_pred = None + if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.train_config.inverted_mask_prior: + with self.timer('prior 
predict'): # do a prediction here so we can match its output with network multiplier set to 0.0 with torch.no_grad(): # dont use network on this network.multiplier = 0.0 self.sd.unet.eval() - control_pred = self.sd.predict_noise( + prior_pred = self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(), timestep=timesteps, @@ -296,12 +302,14 @@ class SDTrainer(BaseSDTrainProcess): **pred_kwargs # adapter residuals in here ) self.sd.unet.train() - control_pred = control_pred.detach() + prior_pred = prior_pred.detach() # remove the residuals as we wont use them on prediction when matching control - del pred_kwargs['down_block_additional_residuals'] + if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs: + del pred_kwargs['down_block_additional_residuals'] # restore network network.multiplier = network_weight_list + if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter): with self.timer('encode_adapter'): with torch.no_grad(): @@ -326,7 +334,7 @@ class SDTrainer(BaseSDTrainProcess): timesteps=timesteps, batch=batch, mask_multiplier=mask_multiplier, - control_pred=control_pred, + prior_pred=prior_pred, ) # check if nan if torch.isnan(loss): diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 10f335b0..fe0dd59a 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -129,6 +129,11 @@ class TrainConfig: self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0) self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented, differential_noise + # When a mask is passed in a dataset, and this is true, + # we will predict noise without the LoRA network and use the prediction as a target for + # unmasked region. 
It is basically unmasked regularization. + self.inverted_mask_prior = kwargs.get('inverted_mask_prior', False) + # legacy if match_adapter_assist and self.match_adapter_chance == 0.0: self.match_adapter_chance = 1.0