Merge remote-tracking branch 'origin/main' into accelerate-multi-gpu

This commit is contained in:
Jaret Burkett
2025-01-26 11:19:34 -07:00

View File

@@ -76,6 +76,18 @@ pixart_config = {
"variance_type": None
}
flux_config = {
"_class_name": "FlowMatchEulerDiscreteScheduler",
"_diffusers_version": "0.30.0.dev0",
"base_image_seq_len": 256,
"base_shift": 0.5,
"max_image_seq_len": 4096,
"max_shift": 1.15,
"num_train_timesteps": 1000,
"shift": 3.0,
"use_dynamic_shifting": True
}
def get_sampler(
sampler: str,
@@ -120,12 +132,7 @@ def get_sampler(
scheduler_cls = CustomLCMScheduler
elif sampler == "flowmatch":
scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
config_to_use = {
"_class_name": "FlowMatchEulerDiscreteScheduler",
"_diffusers_version": "0.29.0.dev0",
"num_train_timesteps": 1000,
"shift": 3.0
}
config_to_use = copy.deepcopy(flux_config)
else:
raise ValueError(f"Sampler {sampler} not supported")