Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Added training for pixart-a
@@ -1,4 +1,5 @@
 import copy
+import math
 
 from diffusers import (
     DDPMScheduler,
@@ -25,7 +26,7 @@ SCHEDULER_LINEAR_END = 0.0120
 SCHEDULER_TIMESTEPS = 1000
 SCHEDLER_SCHEDULE = "scaled_linear"
 
-sdxl_sampler_config = {
+sd_config = {
     "_class_name": "EulerAncestralDiscreteScheduler",
     "_diffusers_version": "0.24.0.dev0",
     "beta_end": 0.012,
@@ -43,15 +44,44 @@ sdxl_sampler_config = {
     "trained_betas": None
 }
 
+pixart_config = {
+    "_class_name": "DPMSolverMultistepScheduler",
+    "_diffusers_version": "0.22.0.dev0",
+    "algorithm_type": "dpmsolver++",
+    "beta_end": 0.02,
+    "beta_schedule": "linear",
+    "beta_start": 0.0001,
+    "dynamic_thresholding_ratio": 0.995,
+    "euler_at_final": False,
+    # "lambda_min_clipped": -Infinity,
+    "lambda_min_clipped": -math.inf,
+    "lower_order_final": True,
+    "num_train_timesteps": 1000,
+    "prediction_type": "epsilon",
+    "sample_max_value": 1.0,
+    "solver_order": 2,
+    "solver_type": "midpoint",
+    "steps_offset": 0,
+    "thresholding": False,
+    "timestep_spacing": "linspace",
+    "trained_betas": None,
+    "use_karras_sigmas": False,
+    "use_lu_lambdas": False,
+    "variance_type": None
+}
+
 
 def get_sampler(
     sampler: str,
     kwargs: dict = None,
+    arch: str = "sd"
 ):
     sched_init_args = {}
     if kwargs is not None:
         sched_init_args.update(kwargs)
 
+    config_to_use = copy.deepcopy(sd_config) if arch == "sd" else copy.deepcopy(pixart_config)
+
     if sampler.startswith("k_"):
         sched_init_args["use_karras_sigmas"] = True
 
@@ -83,7 +113,7 @@ def get_sampler(
     elif sampler == "custom_lcm":
         scheduler_cls = CustomLCMScheduler
 
-    config = copy.deepcopy(sdxl_sampler_config)
+    config = copy.deepcopy(config_to_use)
     config.update(sched_init_args)
 
     scheduler = scheduler_cls.from_config(config)
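
For context, the pixart_config added above is a standard diffusers scheduler config, so it can drive a DPMSolverMultistepScheduler on its own. A minimal sketch (not part of the commit; the dict is abridged to a subset of the keys shown above):

    import math

    from diffusers import DPMSolverMultistepScheduler

    # Abridged copy of the pixart_config added in this commit.
    pixart_config = {
        "_class_name": "DPMSolverMultistepScheduler",
        "algorithm_type": "dpmsolver++",
        "beta_end": 0.02,
        "beta_schedule": "linear",
        "beta_start": 0.0001,
        "lambda_min_clipped": -math.inf,  # JSON's -Infinity has no dict literal, hence math.inf
        "num_train_timesteps": 1000,
        "prediction_type": "epsilon",
        "solver_order": 2,
        "solver_type": "midpoint",
    }

    # from_config tolerates bookkeeping keys such as "_class_name".
    scheduler = DPMSolverMultistepScheduler.from_config(pixart_config)
    scheduler.set_timesteps(num_inference_steps=20)
    print(scheduler.timesteps)  # 20 descending denoising timesteps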
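The behavioral fix is the config_to_use switch: before this commit, get_sampler always deep-copied sdxl_sampler_config, so PixArt would have received the SD/SDXL Euler-ancestral defaults. A hypothetical standalone sketch of that selection pattern (pick_config and the trimmed dicts are illustrative stand-ins, not names from the repo):

    import copy

    # Trimmed stand-ins for the module-level dicts in the diff above.
    sd_config = {"_class_name": "EulerAncestralDiscreteScheduler", "beta_end": 0.012}
    pixart_config = {"_class_name": "DPMSolverMultistepScheduler", "beta_end": 0.02}

    def pick_config(arch: str = "sd", overrides: dict = None) -> dict:
        # Deep-copy first so per-call overrides never mutate the shared defaults.
        config = copy.deepcopy(sd_config if arch == "sd" else pixart_config)
        if overrides is not None:
            config.update(overrides)
        return config

    print(pick_config("pixart", {"use_karras_sigmas": True})["beta_end"])  # 0.02

The deepcopy is load-bearing: config.update(sched_init_args) would otherwise write sampler-specific overrides such as use_karras_sigmas back into the module-level dict and leak them into every later call.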