mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
115 lines
3.4 KiB
Python
115 lines
3.4 KiB
Python
import copy
|
|
|
|
from diffusers import (
|
|
DDPMScheduler,
|
|
EulerAncestralDiscreteScheduler,
|
|
DPMSolverMultistepScheduler,
|
|
DPMSolverSinglestepScheduler,
|
|
LMSDiscreteScheduler,
|
|
PNDMScheduler,
|
|
DDIMScheduler,
|
|
EulerDiscreteScheduler,
|
|
HeunDiscreteScheduler,
|
|
KDPM2DiscreteScheduler,
|
|
KDPM2AncestralDiscreteScheduler,
|
|
LCMScheduler
|
|
)
|
|
|
|
from k_diffusion.external import CompVisDenoiser
|
|
|
|
from toolkit.samplers.custom_lcm_scheduler import CustomLCMScheduler
|
|
|
|
# scheduler:
# Default noise-schedule parameters shared by the base scheduler config below.
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDULER_SCHEDULE = "scaled_linear"
# Misspelled historical name, kept as an alias so existing imports keep working.
SCHEDLER_SCHEDULE = SCHEDULER_SCHEDULE

# Base scheduler config (appears to mirror an SDXL scheduler_config.json --
# NOTE(review): confirm source). get_sampler deep-copies this dict and applies
# sampler-specific overrides before passing it to `<SchedulerClass>.from_config`.
# The beta/timestep entries reference the constants above so the two cannot drift.
sdxl_sampler_config = {
    "_class_name": "EulerAncestralDiscreteScheduler",
    "_diffusers_version": "0.24.0.dev0",
    "beta_end": SCHEDULER_LINEAR_END,
    "beta_schedule": SCHEDULER_SCHEDULE,
    "beta_start": SCHEDULER_LINEAR_START,
    "clip_sample": False,
    "interpolation_type": "linear",
    "num_train_timesteps": SCHEDULER_TIMESTEPS,
    "prediction_type": "epsilon",
    "sample_max_value": 1.0,
    "set_alpha_to_one": False,
    "skip_prk_steps": False,
    "steps_offset": 1,
    "timestep_spacing": "trailing",
    "trained_betas": None,
}
|
|
|
|
|
|
def get_sampler(
        sampler: str,
):
    """Build a diffusers noise scheduler from a short sampler name.

    The scheduler is created with ``from_config`` on a deep copy of
    ``sdxl_sampler_config`` (so the shared module-level dict is never
    mutated), with any sampler-specific overrides applied on top.

    Supported names (case-sensitive):
        ``ddim``, ``ddpm``, ``pndm``, ``lms``/``k_lms``, ``euler``/``k_euler``,
        ``euler_a``, ``dpmsolver``/``dpmsolver++``/``k_dpmsolver``/``k_dpmsolver++``,
        ``dpmsingle``, ``heun``, ``dpm_2``, ``dpm_2_a``, ``lcm``, ``custom_lcm``.
    A ``k_`` prefix additionally sets ``use_karras_sigmas=True``.

    Args:
        sampler: Name of the sampler to build.

    Returns:
        The scheduler instance returned by ``scheduler_cls.from_config``.

    Raises:
        ValueError: If ``sampler`` is not one of the supported names.
    """
    sched_init_args = {}

    if sampler.startswith("k_"):
        # k-diffusion style names select the Karras sigma schedule.
        sched_init_args["use_karras_sigmas"] = True

    if sampler == "ddim":
        scheduler_cls = DDIMScheduler
    elif sampler == "ddpm":  # TODO: unclear whether ddpm is supported for sampling
        scheduler_cls = DDPMScheduler
    elif sampler == "pndm":
        scheduler_cls = PNDMScheduler
    elif sampler in ("lms", "k_lms"):
        scheduler_cls = LMSDiscreteScheduler
    elif sampler in ("euler", "k_euler"):
        scheduler_cls = EulerDiscreteScheduler
    elif sampler == "euler_a":
        scheduler_cls = EulerAncestralDiscreteScheduler
    elif sampler in ("dpmsolver", "dpmsolver++", "k_dpmsolver", "k_dpmsolver++"):
        scheduler_cls = DPMSolverMultistepScheduler
        # The algorithm type is the name without its k_ prefix:
        # "dpmsolver" or "dpmsolver++".
        sched_init_args["algorithm_type"] = sampler.replace("k_", "")
    elif sampler == "dpmsingle":
        scheduler_cls = DPMSolverSinglestepScheduler
    elif sampler == "heun":
        scheduler_cls = HeunDiscreteScheduler
    elif sampler == "dpm_2":
        scheduler_cls = KDPM2DiscreteScheduler
    elif sampler == "dpm_2_a":
        scheduler_cls = KDPM2AncestralDiscreteScheduler
    elif sampler == "lcm":
        scheduler_cls = LCMScheduler
    elif sampler == "custom_lcm":
        scheduler_cls = CustomLCMScheduler
    else:
        # Fail loudly here rather than with an UnboundLocalError on
        # scheduler_cls further down.
        raise ValueError(f"Unknown sampler: {sampler!r}")

    # Deep copy so overrides never leak into the shared base config.
    config = copy.deepcopy(sdxl_sampler_config)
    config.update(sched_init_args)

    return scheduler_cls.from_config(config)
|
|
|
|
|
|
# testing
if __name__ == "__main__":
    # Manual smoke test only: requires a CUDA device, network access to the
    # Hugging Face Hub, and the k_diffusion package. Not executed on import.
    from diffusers import DiffusionPipeline
    import torch

    inference_steps = 20
    seed = 42  # arbitrary fixed seed so runs are reproducible

    # Community k-diffusion text2img pipeline (fetched from the Hub).
    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        custom_pipeline="sd_text2img_k_diffusion",
    )
    pipe = pipe.to("cuda")

    prompt = "an astronaut riding a horse on mars"
    # Presumably the name of a k_diffusion.sampling function -- confirm
    # against the community pipeline's set_scheduler.
    pipe.set_scheduler("sample_heun")
    generator = torch.Generator(device="cuda").manual_seed(seed)
    image = pipe(prompt, generator=generator, num_inference_steps=inference_steps).images[0]

    image.save("./astronaut_heun_k_diffusion.png")
|