Various bug fixes

This commit is contained in:
Jaret Burkett
2025-04-29 09:30:33 -06:00
parent 2b4c525489
commit 5890e67a46
3 changed files with 8 additions and 10 deletions

View File

@@ -756,7 +756,10 @@ class StableDiffusion:
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
self.print_and_status_update("Loading VAE")
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
if self.model_config.vae_path is not None:
vae = load_vae(self.model_config.vae_path, dtype)
else:
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
flush()
self.print_and_status_update("Loading T5")
@@ -2844,12 +2847,7 @@ class StableDiffusion:
def save_device_state(self):
# saves the current device state for all modules
# this is useful for when we want to alter the state and restore it
if self.is_lumina2:
unet_has_grad = self.unet.x_embedder.weight.requires_grad
elif self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
unet_has_grad = self.unet.proj_out.weight.requires_grad
else:
unet_has_grad = self.unet.conv_in.weight.requires_grad
unet_has_grad = False
self.device_state = {
**empty_preset,