mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
A lot of pixart sigma training tweaks
This commit is contained in:
@@ -169,15 +169,6 @@ class StableDiffusion:
|
||||
if self.is_loaded:
|
||||
return
|
||||
dtype = get_torch_dtype(self.dtype)
|
||||
# sch = KDPM2DiscreteScheduler
|
||||
if self.noise_scheduler is None:
|
||||
scheduler = get_sampler(
|
||||
'ddpm', {
|
||||
"prediction_type": self.prediction_type,
|
||||
},
|
||||
'sd' if not self.is_pixart else 'pixart'
|
||||
)
|
||||
self.noise_scheduler = scheduler
|
||||
|
||||
# move the betas alphas and alphas_cumprod to device. Sometimed they get stuck on cpu, not sure why
|
||||
# self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
|
||||
@@ -190,9 +181,10 @@ class StableDiffusion:
|
||||
from toolkit.civitai import get_model_path_from_url
|
||||
model_path = get_model_path_from_url(self.model_config.name_or_path)
|
||||
|
||||
load_args = {
|
||||
'scheduler': self.noise_scheduler,
|
||||
}
|
||||
load_args = {}
|
||||
if self.noise_scheduler:
|
||||
load_args['scheduler'] = self.noise_scheduler
|
||||
|
||||
if self.model_config.vae_path is not None:
|
||||
load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
|
||||
if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
|
||||
@@ -290,6 +282,7 @@ class StableDiffusion:
|
||||
device=self.device_torch,
|
||||
torch_dtype=self.torch_dtype,
|
||||
text_encoder_3=text_encoder3,
|
||||
**load_args
|
||||
)
|
||||
|
||||
flush()
|
||||
@@ -387,6 +380,8 @@ class StableDiffusion:
|
||||
tokenizer = pipe.tokenizer
|
||||
|
||||
pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
|
||||
if self.noise_scheduler is None:
|
||||
self.noise_scheduler = pipe.scheduler
|
||||
|
||||
|
||||
elif self.model_config.is_auraflow:
|
||||
|
||||
Reference in New Issue
Block a user