From 5890e67a46f0931253b05cebec97599c0b240322 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 29 Apr 2025 09:30:33 -0600 Subject: [PATCH] Various bug fixes --- extensions_built_in/sd_trainer/SDTrainer.py | 4 ++-- toolkit/stable_diffusion_model.py | 12 +++++------- toolkit/util/vae.py | 2 +- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 5bf17441..8a8ba738 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -198,8 +198,8 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.diffusion_feature_extractor_path is not None: vae = None - if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer": - vae = self.sd.vae + # if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer": + # vae = self.sd.vae self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae) self.dfe.to(self.device_torch) if hasattr(self.dfe, 'vision_encoder') and self.train_config.gradient_checkpointing: diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 3780792f..4ada3896 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -756,7 +756,10 @@ class StableDiffusion: scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler") self.print_and_status_update("Loading VAE") - vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) + if self.model_config.vae_path is not None: + vae = load_vae(self.model_config.vae_path, dtype) + else: + vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype) flush() self.print_and_status_update("Loading T5") @@ -2844,12 +2847,7 @@ class StableDiffusion: def save_device_state(self): # saves the current device state for all modules # this is useful for when we want to alter the state and restore it - if self.is_lumina2: - unet_has_grad = self.unet.x_embedder.weight.requires_grad - elif self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: - unet_has_grad = self.unet.proj_out.weight.requires_grad - else: - unet_has_grad = self.unet.conv_in.weight.requires_grad + unet_has_grad = False self.device_state = { **empty_preset, diff --git a/toolkit/util/vae.py b/toolkit/util/vae.py index 2681c6db..9a7c4052 100644 --- a/toolkit/util/vae.py +++ b/toolkit/util/vae.py @@ -10,7 +10,7 @@ def load_vae(vae_path, dtype): except Exception as e: try: vae = AutoencoderKL.from_pretrained( - vae_path.vae_path, + vae_path, subfolder="vae", torch_dtype=dtype, )