From 2bf3e529ce59bfa9d3a6c45c458d65bde5bdaa2f Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 1 Aug 2023 15:43:27 -0600
Subject: [PATCH] Enable gradient checkpointing on unet by default. Helps
 immensely with SDXL backprop spikes

---
 config/examples/train_slider.example.yml | 3 +++
 toolkit/config_modules.py                | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/config/examples/train_slider.example.yml b/config/examples/train_slider.example.yml
index c339e5aa..84b9eca6 100644
--- a/config/examples/train_slider.example.yml
+++ b/config/examples/train_slider.example.yml
@@ -34,6 +34,8 @@ config:
       steps: 500
       # I have had good results with 4e-4 to 1e-4 at 500 steps
       lr: 1e-4
+      # enables gradient checkpointing, saves VRAM, leave it on
+      gradient_checkpointing: true
       # train the unet. I recommend leaving this true
       train_unet: true
       # train the text encoder. I don't recommend this unless you have a special use case
@@ -66,6 +68,7 @@ config:
       name_or_path: "runwayml/stable-diffusion-v1-5"
       is_v2: false # for v2 models
       is_v_pred: false # for v-prediction models (most v2 models)
+      is_xl: false # for SDXL models
 
 # saving config
 save:
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index ef8f3c95..2f618512 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -60,7 +60,7 @@ class TrainConfig:
         self.noise_offset = kwargs.get('noise_offset', 0.0)
         self.optimizer_params = kwargs.get('optimizer_params', {})
         self.skip_first_sample = kwargs.get('skip_first_sample', False)
-        self.gradient_checkpointing = kwargs.get('gradient_checkpointing', False)
+        self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
 
 
 class ModelConfig:
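
Note on the default change: the flag set in TrainConfig is what the trainer consults
when deciding whether to turn checkpointing on for the unet. A minimal sketch of how
that wiring typically looks, assuming a diffusers-style UNet2DConditionModel; the
setup_unet helper below is hypothetical illustration, not code from this patch:

    # Sketch of how the new TrainConfig default would be consumed.
    # TrainConfig mirrors this patch; setup_unet is an assumed helper.
    from diffusers import UNet2DConditionModel

    class TrainConfig:
        def __init__(self, **kwargs):
            # gradient checkpointing now defaults to True (this patch)
            self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)

    def setup_unet(config: TrainConfig, unet: UNet2DConditionModel):
        if config.gradient_checkpointing:
            # recompute activations during backprop instead of storing them,
            # trading extra compute for a much smaller VRAM spike (notably on SDXL)
            unet.enable_gradient_checkpointing()
        return unet

With the default flipped to true, existing configs that never set
gradient_checkpointing pick up checkpointing automatically; set it to false
explicitly to opt out.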