Fixed some breaking changes with diffusers gradient checkpointing.

This commit is contained in:
Jaret Burkett
2025-02-10 09:35:31 -07:00
parent c6d8eedb94
commit 10aa7e9d5e

View File

@@ -1452,10 +1452,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
# print_acc("sage attention is not installed. Using SDP instead")
if self.train_config.gradient_checkpointing:
if self.sd.is_flux:
unet.gradient_checkpointing = True
else:
if hasattr(unet, 'enable_gradient_checkpointing'):
unet.enable_gradient_checkpointing()
elif hasattr(unet, 'gradient_checkpointing'):
unet.gradient_checkpointing = True
if isinstance(text_encoder, list):
for te in text_encoder:
if hasattr(te, 'enable_gradient_checkpointing'):