diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 541e9b03..bbb9bb05 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -5,6 +5,7 @@
 from collections import OrderedDict
 import os
 from typing import Union, List
+import numpy as np
 from diffusers import T2IAdapter
 # from lycoris.config import PRESET
 from torch.utils.data import DataLoader
@@ -538,6 +539,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 noise_offset=self.train_config.noise_offset
             ).to(self.device_torch, dtype=dtype)

+            noise_multiplier = self.train_config.noise_multiplier
+
+            noise = noise * noise_multiplier
+
+            img_multiplier = self.train_config.img_multiplier
+
+            latents = latents * img_multiplier
+
             noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)

             # remove grads for these
@@ -997,10 +1006,21 @@ class BaseSDTrainProcess(BaseTrainProcess):

             self.progress_bar.set_postfix_str(prog_bar_string)

+            # apply network normalizer if we are using it, not on regularization steps
+            if self.network is not None and self.network.is_normalizing and not is_reg_step:
+                with self.timer('apply_normalizer'):
+                    self.network.apply_stored_normalizer()
+
+            # if the batch is a DataLoaderBatchDTO, then we need to clean it up
+            if isinstance(batch, DataLoaderBatchDTO):
+                with self.timer('batch_cleanup'):
+                    batch.cleanup()
+
             # don't do on first step
             if self.step_num != self.start_step:
                 if is_sample_step:
                     self.progress_bar.pause()
+                    flush()
                     # print above the progress bar
                     if self.train_config.free_u:
                         self.sd.pipeline.disable_freeu()
@@ -1036,16 +1056,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # end of step
             self.step_num = step

-            # apply network normalizer if we are using it, not on regularization steps
-            if self.network is not None and self.network.is_normalizing and not is_reg_step:
-                with self.timer('apply_normalizer'):
-                    self.network.apply_stored_normalizer()
-
-            # if the batch is a DataLoaderBatchDTO, then we need to clean it up
-            if isinstance(batch, DataLoaderBatchDTO):
-                with self.timer('batch_cleanup'):
-                    batch.cleanup()
-
             # flush every 10 steps
             # if self.step_num % 10 == 0:
             #     flush()
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index e3149505..60bfe601 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -122,6 +122,8 @@ class TrainConfig:
         self.start_step = kwargs.get('start_step', None)
         self.free_u = kwargs.get('free_u', False)
         self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
+        self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
+        self.img_multiplier = kwargs.get('img_multiplier', 1.0)
         match_adapter_assist = kwargs.get('match_adapter_assist', False)
         self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)