diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index ffe8aff6..066969a1 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -227,37 +227,30 @@ class SDTrainer(BaseSDTrainProcess):
         if self.train_config.correct_pred_norm and not is_reg:
             with torch.no_grad():
+                # this only works if doing a prior pred
+                if prior_pred is not None:
+                    prior_mean = prior_pred.mean([2,3], keepdim=True)
+                    prior_std = prior_pred.std([2,3], keepdim=True)
+                    noise_mean = noise_pred.mean([2,3], keepdim=True)
+                    noise_std = noise_pred.std([2,3], keepdim=True)
-                # adjust the noise target in the opposite direction of the noise pred mean and std offset
-                # this will apply additional force the model to correct itself to match the norm of the noise
-                noise_pred_mean, noise_pred_std = get_mean_std(noise_pred)
-                noise_mean, noise_std = get_mean_std(noise)
+                    mean_adjust = prior_mean - noise_mean
+                    std_adjust = prior_std - noise_std
-                # apply the inverse offset of the mean and std to the noise
-                noise_additional_mean = noise_mean - noise_pred_mean
-                noise_additional_std = noise_std - noise_pred_std
+                    mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier
+                    std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier
-                # adjust for multiplier
-                noise_additional_mean = noise_additional_mean * self.train_config.correct_pred_norm_multiplier
-                noise_additional_std = noise_additional_std * self.train_config.correct_pred_norm_multiplier
+                    target_mean = noise_mean + mean_adjust
+                    target_std = noise_std + std_adjust
-                noise_target_std = noise_std + noise_additional_std
-                noise_target_mean = noise_mean + noise_additional_mean
+                    eps = 1e-5
-
-                noise_pred_target_std = noise_pred_std - noise_additional_std
-                noise_pred_target_mean = noise_pred_mean - noise_additional_mean
-                noise_pred_target_std = noise_pred_target_std.detach()
-                noise_pred_target_mean = noise_pred_target_mean.detach()
-
-                # match the noise to the target
-                noise = (noise - noise_mean) / noise_std
-                noise = noise * noise_target_std + noise_target_mean
-                noise = noise.detach()
-
-                # meatch the noise pred to the target
-                # noise_pred = (noise_pred - noise_pred_mean) / noise_pred_std
-                # noise_pred = noise_pred * noise_pred_target_std + noise_pred_target_mean
+                    # adjust the noise target to match the current knowledge of the model
+                    # noise_mean, noise_std = get_mean_std(noise)
+                    # match the noise to the prior
+                    noise = (noise - noise_mean) / (noise_std + eps)
+                    noise = noise * (target_std + eps) + target_mean
+                    noise = noise.detach()
 
         if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
             assert not self.train_config.train_turbo
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index b8b788df..19861c4d 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -306,6 +306,7 @@ class ModelConfig:
         self.refiner_name_or_path = kwargs.get('refiner_name_or_path', None)
         self._original_refiner_name_or_path = self.refiner_name_or_path
         self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)
+        self.lora_path = kwargs.get('lora_path', None)
 
         # only for SDXL models for now
         self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index f485366c..e10239ca 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -281,6 +281,12 @@ class StableDiffusion:
             self.unet.requires_grad_(False)
             self.unet.eval()
 
+            # load any loras we have
+            if self.model_config.lora_path is not None:
+                pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
+                pipe.fuse_lora()
+                self.unet.fuse_lora()
+
             self.tokenizer = tokenizer
             self.text_encoder = text_encoder
             self.pipeline = pipe