mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Fixed some breaking changes with diffusers gradient checkpointing.
This commit is contained in:
@@ -1452,10 +1452,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# print_acc("sage attention is not installed. Using SDP instead")
|
# print_acc("sage attention is not installed. Using SDP instead")
|
||||||
|
|
||||||
if self.train_config.gradient_checkpointing:
|
if self.train_config.gradient_checkpointing:
|
||||||
if self.sd.is_flux:
|
if hasattr(unet, 'enable_gradient_checkpointing'):
|
||||||
unet.gradient_checkpointing = True
|
|
||||||
else:
|
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
|
elif hasattr(unet, 'gradient_checkpointing'):
|
||||||
|
unet.gradient_checkpointing = True
|
||||||
if isinstance(text_encoder, list):
|
if isinstance(text_encoder, list):
|
||||||
for te in text_encoder:
|
for te in text_encoder:
|
||||||
if hasattr(te, 'enable_gradient_checkpointing'):
|
if hasattr(te, 'enable_gradient_checkpointing'):
|
||||||
|
|||||||
Reference in New Issue
Block a user