mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-07 19:49:49 +00:00
Switched to trailing timestep spacing to make timesteps for consistant across schedulers. Honed in on targeted guidance. It is finally perfect. (I think)
This commit is contained in:
@@ -26,8 +26,8 @@ SCHEDULER_TIMESTEPS = 1000
|
||||
SCHEDLER_SCHEDULE = "scaled_linear"
|
||||
|
||||
sdxl_sampler_config = {
|
||||
"_class_name": "EulerDiscreteScheduler",
|
||||
"_diffusers_version": "0.19.0.dev0",
|
||||
"_class_name": "EulerAncestralDiscreteScheduler",
|
||||
"_diffusers_version": "0.24.0.dev0",
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
@@ -37,11 +37,10 @@ sdxl_sampler_config = {
|
||||
"prediction_type": "epsilon",
|
||||
"sample_max_value": 1.0,
|
||||
"set_alpha_to_one": False,
|
||||
"skip_prk_steps": True,
|
||||
"skip_prk_steps": False,
|
||||
"steps_offset": 1,
|
||||
"timestep_spacing": "leading",
|
||||
"trained_betas": None,
|
||||
"use_karras_sigmas": False
|
||||
"timestep_spacing": "trailing",
|
||||
"trained_betas": None
|
||||
}
|
||||
|
||||
|
||||
@@ -86,7 +85,6 @@ def get_sampler(
|
||||
|
||||
scheduler = scheduler_cls.from_config(config)
|
||||
|
||||
|
||||
return scheduler
|
||||
|
||||
|
||||
|
||||
@@ -674,14 +674,8 @@ class StableDiffusion:
|
||||
|
||||
for idx in range(original_samples.shape[0]):
|
||||
|
||||
if scheduler_class_name not in index_noise_schedulers:
|
||||
# convert to idx
|
||||
noise_timesteps = [(self.noise_scheduler.timesteps == t).nonzero().item() for t in timesteps_chunks[idx]]
|
||||
noise_timesteps = torch.tensor(noise_timesteps, device=self.device_torch)
|
||||
else:
|
||||
noise_timesteps = timesteps_chunks[idx]
|
||||
|
||||
# the add noise for ddpm solver is broken, do it ourselves
|
||||
noise_timesteps = timesteps_chunks[idx]
|
||||
if scheduler_class_name == 'DPMSolverMultistepScheduler':
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.noise_scheduler.sigmas.to(device=original_samples_chunks[idx].device, dtype=original_samples_chunks[idx].dtype)
|
||||
|
||||
Reference in New Issue
Block a user