mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 09:44:02 +00:00
Use gradient checkpointing on DFE models if set
This commit is contained in:
@@ -198,11 +198,21 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
if self.train_config.diffusion_feature_extractor_path is not None:
|
if self.train_config.diffusion_feature_extractor_path is not None:
|
||||||
vae = None
|
vae = None
|
||||||
if self.model_config.arch != "flux" or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
|
if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
|
||||||
vae = self.sd.vae
|
vae = self.sd.vae
|
||||||
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
|
self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
|
||||||
self.dfe.to(self.device_torch)
|
self.dfe.to(self.device_torch)
|
||||||
self.dfe.eval()
|
if hasattr(self.dfe, 'vision_encoder') and self.train_config.gradient_checkpointing:
|
||||||
|
# must be set to train for gradient checkpointing to work
|
||||||
|
self.dfe.vision_encoder.train()
|
||||||
|
self.dfe.vision_encoder.gradient_checkpointing = True
|
||||||
|
else:
|
||||||
|
self.dfe.eval()
|
||||||
|
|
||||||
|
# enable gradient checkpointing on the vae
|
||||||
|
if vae is not None and self.train_config.gradient_checkpointing:
|
||||||
|
vae.enable_gradient_checkpointing()
|
||||||
|
vae.train()
|
||||||
|
|
||||||
|
|
||||||
def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
|
def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
|
||||||
|
|||||||
Reference in New Issue
Block a user