mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 22:49:48 +00:00
Added the ability to enable experimental blank stabilization during training, which zeroes out the latents of batch items whose prompt is blank.
This commit is contained in:
@@ -1336,6 +1336,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
latent_multiplier = normalizer
|
||||
|
||||
latents = latents * latent_multiplier
|
||||
|
||||
if self.train_config.do_blank_stabilization:
|
||||
# zero out latents with blank prompts
|
||||
blank_latent = torch.zeros_like(latents)
|
||||
for i, prompt in enumerate(conditioned_prompts):
|
||||
if prompt.strip() == '':
|
||||
latents[i] = blank_latent[i]
|
||||
|
||||
batch.latents = latents
|
||||
|
||||
# normalize latents to a mean of 0 and an std of 1
|
||||
|
||||
@@ -559,6 +559,9 @@ class TrainConfig:
|
||||
# for multi stage models, how often to switch the boundary
|
||||
self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1)
|
||||
|
||||
# experimental: when enabled, latents of samples whose prompt is blank are zeroed out during training
|
||||
self.do_blank_stabilization = kwargs.get('do_blank_stabilization', False)
|
||||
|
||||
|
||||
# String identifiers that make up the ModelArch literal type, one per line
# so additions show up as single-line diffs.
ModelArch = Literal[
    'sd1',
    'sd2',
    'sd3',
    'sdxl',
    'pixart',
    'pixart_sigma',
    'auraflow',
    'flux',
    'flex1',
    'flex2',
    'lumina2',
    'vega',
    'ssd',
    'wan21',
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user