Set gradient checkpointing on unet enabled by default. Help out immensly with sdxl backprop spikes

This commit is contained in:
Jaret Burkett
2023-08-01 15:43:27 -06:00
parent f53fd08690
commit 2bf3e529ce
2 changed files with 4 additions and 1 deletions

View File

@@ -60,7 +60,7 @@ class TrainConfig:
self.noise_offset = kwargs.get('noise_offset', 0.0)
self.optimizer_params = kwargs.get('optimizer_params', {})
self.skip_first_sample = kwargs.get('skip_first_sample', False)
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', False)
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
class ModelConfig: