diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index da2baacf..f3f750bd 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -145,7 +145,11 @@ class SDTrainer(BaseSDTrainProcess):
         loss = loss.mean([1, 2, 3])
 
-        if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
+
+        if self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
+            # add snr_gamma
+            loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
+        elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
             # add min_snr_gamma
             loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 53266604..2f1c9176 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -559,12 +559,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # for content / structure, it is best to favor earlier timesteps
             # for style, it is best to favor later timesteps
-            timesteps = torch.rand((batch_size,), device=latents.device)
+            orig_timesteps = torch.rand((batch_size,), device=latents.device)
 
-            if self.train_config.content_or_style == 'style':
-                timesteps = timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
-            elif self.train_config.content_or_style == 'content':
-                timesteps = (1 - timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
+            if self.train_config.content_or_style == 'content':
+                timesteps = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
+            elif self.train_config.content_or_style == 'style':
+                timesteps = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']
 
             timesteps = value_map(
                 timesteps,
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 71315b18..267a406a 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -168,6 +168,7 @@ class TrainConfig:
         self.train_unet = kwargs.get('train_unet', True)
         self.train_text_encoder = kwargs.get('train_text_encoder', True)
         self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
+        self.snr_gamma = kwargs.get('snr_gamma', None)
         self.noise_offset = kwargs.get('noise_offset', 0.0)
         self.skip_first_sample = kwargs.get('skip_first_sample', False)
         self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index 99cf85f9..fe7a9964 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -688,13 +688,17 @@
 def apply_snr_weight(
         loss,
         timesteps,
         noise_scheduler: Union['DDPMScheduler'],
-        gamma
+        gamma,
+        fixed=False,
 ):
-    # will get it form noise scheduler if exist or will calculate it if not
+    # will get it from noise scheduler if exist or will calculate it if not
     all_snr = get_all_snr(noise_scheduler, loss.device)
     snr = torch.stack([all_snr[t] for t in timesteps])
     gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
-    snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)  # from paper
+    if fixed:
+        snr_weight = gamma_over_snr.float().to(loss.device)  # directly using gamma over snr
+    else:
+        snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
     loss = loss * snr_weight
     return loss
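A note for readers of this patch (not part of the diff itself): the behavioral difference between the two options lives in apply_snr_weight. The min_snr_gamma path clamps the per-timestep loss weight at 1, so low-noise (high-SNR) timesteps are down-weighted but never up-weighted, while the new snr_gamma path (fixed=True) applies the raw gamma/SNR ratio, which can scale noisy (low-SNR) timesteps well above 1. A minimal standalone sketch of the two modes, with illustrative SNR values:

    import torch

    def snr_weights(snr: torch.Tensor, gamma: float, fixed: bool = False) -> torch.Tensor:
        gamma_over_snr = gamma / snr
        if fixed:
            # snr_gamma path: use gamma/SNR directly, unclamped
            return gamma_over_snr.float()
        # min_snr_gamma path: weight = min(gamma/SNR, 1)
        return torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float()

    snr = torch.tensor([0.5, 1.0, 5.0, 20.0])       # low SNR = noisier timesteps
    print(snr_weights(snr, gamma=5.0))              # tensor([1.0000, 1.0000, 1.0000, 0.2500])
    print(snr_weights(snr, gamma=5.0, fixed=True))  # tensor([10.0000, 5.0000, 1.0000, 0.2500])

The BaseSDTrainProcess change swaps the content/style branches to match the comments above them: cubing a uniform [0, 1) draw skews it toward 0, i.e. toward earlier timesteps, which the comments recommend for content. A quick illustration, again outside the diff:

    import torch

    torch.manual_seed(0)
    orig_timesteps = torch.rand((10_000,))
    num_train_timesteps = 1000

    content = orig_timesteps ** 3 * num_train_timesteps      # skewed toward early steps
    style = (1 - orig_timesteps ** 3) * num_train_timesteps  # skewed toward late steps
    print(content.mean().item())  # ~250: favors earlier timesteps
    print(style.mean().item())    # ~750: favors later timesteps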