mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 16:30:40 +00:00
Adjust signal amplification target. Allow signal amplification strength in config.
This commit is contained in:
@@ -578,6 +578,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
raise ValueError("Signal amplification is only supported for flow matching models")
|
||||
with torch.no_grad():
|
||||
nas = 1.0 - (timesteps / 1000).to(noise.device, dtype=noise.dtype)
|
||||
nas = nas * self.train_config.signal_amplification_strength
|
||||
nas = nas.clamp(min=0.1)
|
||||
while len(nas.shape) < len(noise.shape):
|
||||
nas = nas.unsqueeze(-1)
|
||||
aug = batch.latents * nas
|
||||
|
||||
@@ -402,6 +402,7 @@ class TrainConfig:
|
||||
self.do_batch_noise_correction = kwargs.get('do_batch_noise_correction', False)
|
||||
self.batch_noise_correction_scale = kwargs.get('batch_noise_correction_scale', 0.1)
|
||||
self.do_signal_amplification = kwargs.get('do_signal_amplification', False)
|
||||
self.signal_amplification_strength = kwargs.get('signal_amplification_strength', 0.5)
|
||||
|
||||
self.signal_correction_noise_scale = kwargs.get('signal_correction_noise_scale', 1.0)
|
||||
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
|
||||
|
||||
Reference in New Issue
Block a user