From cd37ccfc2e46c317fac845c2de71e61ca166a440 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 11 Apr 2025 10:45:39 -0600 Subject: [PATCH] Use gradient checkpointing on DFE models if set --- extensions_built_in/sd_trainer/SDTrainer.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index ecdbfbf5..4840ed84 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -198,11 +198,21 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.diffusion_feature_extractor_path is not None: vae = None - if self.model_config.arch != "flux" or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer": + if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer": vae = self.sd.vae self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae) self.dfe.to(self.device_torch) - self.dfe.eval() + if hasattr(self.dfe, 'vision_encoder') and self.train_config.gradient_checkpointing: + # must be set to train for gradient checkpointing to work + self.dfe.vision_encoder.train() + self.dfe.vision_encoder.gradient_checkpointing = True + else: + self.dfe.eval() + + # enable gradient checkpointing on the vae + if vae is not None and self.train_config.gradient_checkpointing: + vae.enable_gradient_checkpointing() + vae.train() def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):