A lot of pixart sigma training tweaks

This commit is contained in:
Jaret Burkett
2024-07-28 11:23:18 -06:00
parent 80aa2dbb80
commit 0bc4d555c7
8 changed files with 118 additions and 29 deletions

View File

@@ -169,15 +169,6 @@ class StableDiffusion:
if self.is_loaded:
return
dtype = get_torch_dtype(self.dtype)
# sch = KDPM2DiscreteScheduler
if self.noise_scheduler is None:
scheduler = get_sampler(
'ddpm', {
"prediction_type": self.prediction_type,
},
'sd' if not self.is_pixart else 'pixart'
)
self.noise_scheduler = scheduler
# move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why
# self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
@@ -190,9 +181,10 @@ class StableDiffusion:
from toolkit.civitai import get_model_path_from_url
model_path = get_model_path_from_url(self.model_config.name_or_path)
load_args = {
'scheduler': self.noise_scheduler,
}
load_args = {}
if self.noise_scheduler:
load_args['scheduler'] = self.noise_scheduler
if self.model_config.vae_path is not None:
load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
@@ -290,6 +282,7 @@ class StableDiffusion:
device=self.device_torch,
torch_dtype=self.torch_dtype,
text_encoder_3=text_encoder3,
**load_args
)
flush()
@@ -387,6 +380,8 @@ class StableDiffusion:
tokenizer = pipe.tokenizer
pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
if self.noise_scheduler is None:
self.noise_scheduler = pipe.scheduler
elif self.model_config.is_auraflow: