mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-11 16:30:40 +00:00
Add signal amplification
This commit is contained in:
@@ -573,7 +573,16 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
elif self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
|
||||
|
||||
elif self.train_config.do_signal_amplification:
|
||||
if not self.sd.is_flow_matching:
|
||||
raise ValueError("Signal amplification is only supported for flow matching models")
|
||||
with torch.no_grad():
|
||||
nas = 1.0 - (timesteps / 1000).to(noise.device, dtype=noise.dtype)
|
||||
while len(nas.shape) < len(noise.shape):
|
||||
nas = nas.unsqueeze(-1)
|
||||
aug = batch.latents * nas
|
||||
target = noise - (batch.latents + aug)
|
||||
target = target.detach()
|
||||
elif hasattr(self.sd, 'get_loss_target'):
|
||||
target = self.sd.get_loss_target(
|
||||
noise=noise,
|
||||
|
||||
@@ -401,6 +401,7 @@ class TrainConfig:
|
||||
# batch noise correction adds other images in the batch as noise to correct away from other images
|
||||
self.do_batch_noise_correction = kwargs.get('do_batch_noise_correction', False)
|
||||
self.batch_noise_correction_scale = kwargs.get('batch_noise_correction_scale', 0.1)
|
||||
self.do_signal_amplification = kwargs.get('do_signal_amplification', False)
|
||||
|
||||
self.signal_correction_noise_scale = kwargs.get('signal_correction_noise_scale', 1.0)
|
||||
self.random_noise_shift = kwargs.get('random_noise_shift', 0.0)
|
||||
|
||||
Reference in New Issue
Block a user