diff --git a/toolkit/sampler.py b/toolkit/sampler.py index f9b0311b..aae6e379 100644 --- a/toolkit/sampler.py +++ b/toolkit/sampler.py @@ -76,6 +76,18 @@ pixart_config = { "variance_type": None } +flux_config = { + "_class_name": "FlowMatchEulerDiscreteScheduler", + "_diffusers_version": "0.30.0.dev0", + "base_image_seq_len": 256, + "base_shift": 0.5, + "max_image_seq_len": 4096, + "max_shift": 1.15, + "num_train_timesteps": 1000, + "shift": 3.0, + "use_dynamic_shifting": True +} + def get_sampler( sampler: str, @@ -120,12 +132,7 @@ def get_sampler( scheduler_cls = CustomLCMScheduler elif sampler == "flowmatch": scheduler_cls = CustomFlowMatchEulerDiscreteScheduler - config_to_use = { - "_class_name": "FlowMatchEulerDiscreteScheduler", - "_diffusers_version": "0.29.0.dev0", - "num_train_timesteps": 1000, - "shift": 3.0 - } + config_to_use = copy.deepcopy(flux_config) else: raise ValueError(f"Sampler {sampler} not supported")