Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Various bug fixes
@@ -198,8 +198,8 @@ class SDTrainer(BaseSDTrainProcess):
         if self.train_config.diffusion_feature_extractor_path is not None:
             vae = None
-            if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
-                vae = self.sd.vae
+            # if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer":
+            #     vae = self.sd.vae
             self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae)
             self.dfe.to(self.device_torch)
             if hasattr(self.dfe, 'vision_encoder') and self.train_config.gradient_checkpointing:
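The pair of lines being toggled is a VAE-selection guard for the diffusion feature extractor: pass the model's VAE through unless the arch is flux running a standard autoencoder. A minimal sketch of what that guard computes, pulled out into a standalone helper of my own naming (pick_dfe_vae), not the repo's inline form:

from typing import Any, Optional

def pick_dfe_vae(arch: str, sd_vae: Any) -> Optional[Any]:
    # Hand the model's VAE to the feature extractor unless the arch is
    # flux with a standard (non-AutoencoderPixelMixer) autoencoder.
    if arch not in ["flux"] or sd_vae.__class__.__name__ == "AutoencoderPixelMixer":
        return sd_vae
    return None

# Illustrative call site, mirroring the hunk:
#   self.dfe = load_dfe(path, vae=pick_dfe_vae(self.model_config.arch, self.sd.vae))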
@@ -756,7 +756,10 @@ class StableDiffusion:

             scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
             self.print_and_status_update("Loading VAE")
-            vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
+            if self.model_config.vae_path is not None:
+                vae = load_vae(self.model_config.vae_path, dtype)
+            else:
+                vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
             flush()

             self.print_and_status_update("Loading T5")
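The new branch lets a configured vae_path override the VAE bundled with the base checkpoint. A minimal sketch of that fallback order, assuming diffusers' AutoencoderKL and a resolve_vae helper of my own naming (the real code routes the override through load_vae, patched below):

import torch
from diffusers import AutoencoderKL

def resolve_vae(base_model_path: str, vae_path, dtype: torch.dtype) -> AutoencoderKL:
    # Prefer an explicitly configured VAE path; otherwise fall back to
    # the "vae" subfolder inside the base checkpoint.
    if vae_path is not None:
        return AutoencoderKL.from_pretrained(vae_path, torch_dtype=dtype)
    return AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)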
@@ -2844,12 +2847,7 @@ class StableDiffusion:
     def save_device_state(self):
         # saves the current device state for all modules
         # this is useful for when we want to alter the state and restore it
-        if self.is_lumina2:
-            unet_has_grad = self.unet.x_embedder.weight.requires_grad
-        elif self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
-            unet_has_grad = self.unet.proj_out.weight.requires_grad
-        else:
-            unet_has_grad = self.unet.conv_in.weight.requires_grad
+        unet_has_grad = False

         self.device_state = {
             **empty_preset,
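save_device_state exists so the trainer can move or freeze modules and restore them afterwards; after this commit the unet's gradient flag is simply recorded as False rather than probed per architecture (x_embedder for Lumina2, proj_out for PixArt/SD3/AuraFlow/Flux, conv_in otherwise). A minimal sketch of the snapshot/restore idea under those assumptions, with illustrative names rather than the repo's API:

import torch

class DeviceStateSketch:
    def __init__(self) -> None:
        self.unet = torch.nn.Linear(4, 4)  # stand-in for the real unet/transformer
        self.device_state = None

    def save_device_state(self) -> None:
        # Snapshot enough to restore the module later; the gradient flag
        # is pinned to False, matching the hunk above.
        self.device_state = {
            "unet": {
                "training": self.unet.training,
                "device": next(self.unet.parameters()).device,
                "requires_grad": False,
            }
        }

    def restore_device_state(self) -> None:
        state = self.device_state["unet"]
        self.unet.train(state["training"])
        self.unet.to(state["device"])
        for p in self.unet.parameters():
            p.requires_grad_(state["requires_grad"])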
@@ -10,7 +10,7 @@ def load_vae(vae_path, dtype):
     except Exception as e:
         try:
             vae = AutoencoderKL.from_pretrained(
-                vae_path.vae_path,
+                vae_path,
                 subfolder="vae",
                 torch_dtype=dtype,
             )
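The one-line change fixes an apparent attribute bug: vae_path is, judging by the function signature and the fix itself, a plain path string, so vae_path.vae_path would raise AttributeError as soon as the fallback ran. A sketch of the corrected function; the first attempt inside the outer try is my assumption, since the hunk only shows the fallback:

import torch
from diffusers import AutoencoderKL

def load_vae(vae_path: str, dtype: torch.dtype) -> AutoencoderKL:
    try:
        # Assumed first attempt: the path points directly at a standalone VAE.
        vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=dtype)
    except Exception:
        # Fallback from the hunk: treat the path as a full checkpoint and
        # read its "vae" subfolder. Previously this passed vae_path.vae_path.
        vae = AutoencoderKL.from_pretrained(
            vae_path,
            subfolder="vae",
            torch_dtype=dtype,
        )
    return vae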