Added training for pixart-a

This commit is contained in:
Jaret Burkett
2024-02-13 16:00:04 -07:00
parent 4ec4025cbb
commit 93b52932c1
10 changed files with 288 additions and 24 deletions

View File

@@ -1,4 +1,5 @@
import copy
import math
from diffusers import (
DDPMScheduler,
@@ -25,7 +26,7 @@ SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = "scaled_linear"
sdxl_sampler_config = {
sd_config = {
"_class_name": "EulerAncestralDiscreteScheduler",
"_diffusers_version": "0.24.0.dev0",
"beta_end": 0.012,
@@ -43,15 +44,44 @@ sdxl_sampler_config = {
"trained_betas": None
}
pixart_config = {
"_class_name": "DPMSolverMultistepScheduler",
"_diffusers_version": "0.22.0.dev0",
"algorithm_type": "dpmsolver++",
"beta_end": 0.02,
"beta_schedule": "linear",
"beta_start": 0.0001,
"dynamic_thresholding_ratio": 0.995,
"euler_at_final": False,
# "lambda_min_clipped": -Infinity,
"lambda_min_clipped": -math.inf,
"lower_order_final": True,
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"sample_max_value": 1.0,
"solver_order": 2,
"solver_type": "midpoint",
"steps_offset": 0,
"thresholding": False,
"timestep_spacing": "linspace",
"trained_betas": None,
"use_karras_sigmas": False,
"use_lu_lambdas": False,
"variance_type": None
}
def get_sampler(
sampler: str,
kwargs: dict = None,
arch: str = "sd"
):
sched_init_args = {}
if kwargs is not None:
sched_init_args.update(kwargs)
config_to_use = copy.deepcopy(sd_config) if arch == "sd" else copy.deepcopy(pixart_config)
if sampler.startswith("k_"):
sched_init_args["use_karras_sigmas"] = True
@@ -83,7 +113,7 @@ def get_sampler(
elif sampler == "custom_lcm":
scheduler_cls = CustomLCMScheduler
config = copy.deepcopy(sdxl_sampler_config)
config = copy.deepcopy(config_to_use)
config.update(sched_init_args)
scheduler = scheduler_cls.from_config(config)