Added ability to activate experimental blank stabilization during training to zero out latents with blank prompts.

This commit is contained in:
Jaret Burkett
2026-02-04 13:00:03 -07:00
parent 42acb0d4be
commit 5c37db04f9
2 changed files with 11 additions and 0 deletions

View File

@@ -1336,6 +1336,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
latent_multiplier = normalizer
latents = latents * latent_multiplier
if self.train_config.do_blank_stabilization:
# zero out latents with blank prompts
blank_latent = torch.zeros_like(latents)
for i, prompt in enumerate(conditioned_prompts):
if prompt.strip() == '':
latents[i] = blank_latent[i]
batch.latents = latents
# normalize latents to a mean of 0 and an std of 1

View File

@@ -559,6 +559,9 @@ class TrainConfig:
# for multi stage models, how often to switch the boundary
self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1)
# stabilizes empty prompts to be zeroed predictions
self.do_blank_stabilization = kwargs.get('do_blank_stabilization', False)
ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']