diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index ae6da3de..65a3691d 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -346,8 +346,8 @@ class TrainConfig: self.standardize_images = kwargs.get('standardize_images', False) self.standardize_latents = kwargs.get('standardize_latents', False) - if self.train_turbo and not self.noise_scheduler.startswith("euler"): - raise ValueError(f"train_turbo is only supported with euler and wuler_a noise schedulers") + # if self.train_turbo and not self.noise_scheduler.startswith("euler"): + # raise ValueError(f"train_turbo is only supported with euler and wuler_a noise schedulers") self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False) self.do_cfg = kwargs.get('do_cfg', False) @@ -458,6 +458,7 @@ class ModelConfig: # only for flux for now self.quantize = kwargs.get("quantize", False) + self.quantize_te = kwargs.get("quantize_te", self.quantize) self.low_vram = kwargs.get("low_vram", False) self.attn_masking = kwargs.get("attn_masking", False) if self.attn_masking and not self.is_flux: @@ -827,7 +828,10 @@ class GenerateImageConfig: prompt += ' --gr ' + str(self.guidance_rescale) # get gen info - f.write(self.prompt) + try: + f.write(self.prompt) + except Exception as e: + print(f"Error writing prompt file. Prompt contains non-unicode characters. {e}") def _process_prompt_string(self): # we will try to support all sd-scripts where we can diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py index c8c190a5..6c5b90df 100644 --- a/toolkit/samplers/custom_flowmatch_sampler.py +++ b/toolkit/samplers/custom_flowmatch_sampler.py @@ -110,7 +110,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): return timesteps elif timestep_type == 'lognorm_blend': # disgtribute timestepd to the center/early and blend in linear - alpha = 0.8 + alpha = 0.75 lognormal = LogNormal(loc=0, scale=0.333) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index aaa30898..286d2410 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -696,7 +696,7 @@ class StableDiffusion: text_encoder_2.to(self.device_torch, dtype=dtype) flush() - if self.model_config.quantize: + if self.model_config.quantize_te: print("Quantizing T5") quantize(text_encoder_2, weights=qfloat8) freeze(text_encoder_2)