actually got gradient checkpointing working, again, again, maybe

This commit is contained in:
Jaret Burkett
2023-09-09 11:27:42 -06:00
parent 4ed03a8d92
commit 408c50ead1
5 changed files with 102 additions and 70 deletions

View File

@@ -72,6 +72,7 @@ class SDTrainer(BaseSDTrainProcess):
# 9.18 gb
noise = noise.to(self.device_torch, dtype=dtype).detach()
if self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)