diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 91f62a5a..d0663e91 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1336,6 +1336,14 @@ class BaseSDTrainProcess(BaseTrainProcess): latent_multiplier = normalizer latents = latents * latent_multiplier + + if self.train_config.do_blank_stabilization: + # zero out latents with blank prompts + blank_latent = torch.zeros_like(latents) + for i, prompt in enumerate(conditioned_prompts): + if prompt.strip() == '': + latents[i] = blank_latent[i] + batch.latents = latents # normalize latents to a mean of 0 and an std of 1 diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 974e0cb6..d93f471b 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -559,6 +559,9 @@ class TrainConfig: # for multi stage models, how often to switch the boundary self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1) + # stabilizes empty prompts to be zeroed predictions + self.do_blank_stabilization = kwargs.get('do_blank_stabilization', False) + ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']