Allow quantizing the te independently on flux. added lognorm_blend timestep schedule

This commit is contained in:
Jaret Burkett
2025-01-18 18:02:31 -07:00
parent 4723f23c0d
commit fadb2f3a76
3 changed files with 9 additions and 5 deletions

View File

@@ -346,8 +346,8 @@ class TrainConfig:
self.standardize_images = kwargs.get('standardize_images', False)
self.standardize_latents = kwargs.get('standardize_latents', False)
if self.train_turbo and not self.noise_scheduler.startswith("euler"):
raise ValueError(f"train_turbo is only supported with euler and wuler_a noise schedulers")
# if self.train_turbo and not self.noise_scheduler.startswith("euler"):
# raise ValueError(f"train_turbo is only supported with euler and wuler_a noise schedulers")
self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
self.do_cfg = kwargs.get('do_cfg', False)
@@ -458,6 +458,7 @@ class ModelConfig:
# only for flux for now
self.quantize = kwargs.get("quantize", False)
self.quantize_te = kwargs.get("quantize_te", self.quantize)
self.low_vram = kwargs.get("low_vram", False)
self.attn_masking = kwargs.get("attn_masking", False)
if self.attn_masking and not self.is_flux:
@@ -827,7 +828,10 @@ class GenerateImageConfig:
prompt += ' --gr ' + str(self.guidance_rescale)
# get gen info
f.write(self.prompt)
try:
f.write(self.prompt)
except Exception as e:
print(f"Error writing prompt file. Prompt contains non-unicode characters. {e}")
def _process_prompt_string(self):
# we will try to support all sd-scripts where we can

View File

@@ -110,7 +110,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
return timesteps
elif timestep_type == 'lognorm_blend':
# disgtribute timestepd to the center/early and blend in linear
alpha = 0.8
alpha = 0.75
lognormal = LogNormal(loc=0, scale=0.333)

View File

@@ -696,7 +696,7 @@ class StableDiffusion:
text_encoder_2.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize:
if self.model_config.quantize_te:
print("Quantizing T5")
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)