From 10aa7e9d5e1133c751b6169b4ff3c36f9904bdfc Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Mon, 10 Feb 2025 09:35:31 -0700 Subject: [PATCH] Fixed some breaking changes with diffusers gradient checkpointing. --- jobs/process/BaseSDTrainProcess.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 6f057455..89ceac4f 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1452,10 +1452,10 @@ class BaseSDTrainProcess(BaseTrainProcess): # print_acc("sage attention is not installed. Using SDP instead") if self.train_config.gradient_checkpointing: - if self.sd.is_flux: - unet.gradient_checkpointing = True - else: + if hasattr(unet, 'enable_gradient_checkpointing'): unet.enable_gradient_checkpointing() + elif hasattr(unet, 'gradient_checkpointing'): + unet.gradient_checkpointing = True if isinstance(text_encoder, list): for te in text_encoder: if hasattr(te, 'enable_gradient_checkpointing'):