diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 6f057455..89ceac4f 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1452,10 +1452,10 @@ class BaseSDTrainProcess(BaseTrainProcess): # print_acc("sage attention is not installed. Using SDP instead") if self.train_config.gradient_checkpointing: - if self.sd.is_flux: - unet.gradient_checkpointing = True - else: + if hasattr(unet, 'enable_gradient_checkpointing'): unet.enable_gradient_checkpointing() + elif hasattr(unet, 'gradient_checkpointing'): + unet.gradient_checkpointing = True if isinstance(text_encoder, list): for te in text_encoder: if hasattr(te, 'enable_gradient_checkpointing'):