Mirror of https://github.com/ostris/ai-toolkit.git

Commit: Allow quantizing the TE independently on Flux; add a lognorm_blend timestep schedule.
@@ -346,8 +346,8 @@ class TrainConfig:
         self.standardize_images = kwargs.get('standardize_images', False)
         self.standardize_latents = kwargs.get('standardize_latents', False)

-        if self.train_turbo and not self.noise_scheduler.startswith("euler"):
-            raise ValueError(f"train_turbo is only supported with euler and euler_a noise schedulers")
+        # if self.train_turbo and not self.noise_scheduler.startswith("euler"):
+        #     raise ValueError(f"train_turbo is only supported with euler and euler_a noise schedulers")

         self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
         self.do_cfg = kwargs.get('do_cfg', False)
@@ -458,6 +458,7 @@ class ModelConfig:

         # only for flux for now
         self.quantize = kwargs.get("quantize", False)
+        self.quantize_te = kwargs.get("quantize_te", self.quantize)
         self.low_vram = kwargs.get("low_vram", False)
         self.attn_masking = kwargs.get("attn_masking", False)
         if self.attn_masking and not self.is_flux:
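Since quantize_te falls back to quantize when it is not set, existing configs behave exactly as before, while the two flags can now be toggled separately. A minimal sketch of that fallback, assuming plain kwargs; the helper name below is illustrative and not part of the toolkit:

# Illustrative sketch of the fallback added above, not the toolkit's API.
def resolve_quantization_flags(**kwargs):
    quantize_model = kwargs.get("quantize", False)
    # quantize_te inherits from quantize unless it is set explicitly
    quantize_te = kwargs.get("quantize_te", quantize_model)
    return quantize_model, quantize_te

print(resolve_quantization_flags(quantize=True))                     # (True, True)  legacy behavior
print(resolve_quantization_flags(quantize=True, quantize_te=False))  # (True, False) transformer only
print(resolve_quantization_flags(quantize=False, quantize_te=True))  # (False, True) text encoder only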
@@ -827,7 +828,10 @@ class GenerateImageConfig:
             prompt += ' --gr ' + str(self.guidance_rescale)

         # get gen info
-        f.write(self.prompt)
+        try:
+            f.write(self.prompt)
+        except Exception as e:
+            print(f"Error writing prompt file. Prompt contains non-unicode characters. {e}")

     def _process_prompt_string(self):
         # we will try to support all sd-scripts where we can
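Wrapping the write in try/except keeps sample generation from crashing when a prompt contains characters the prompt file's encoding cannot represent. A small sketch of the failure class being guarded against; the hunk does not show how the file is opened, so the ASCII encoding below is an assumption chosen to trigger the error:

# Hypothetical reproduction of the error the guard above catches.
prompt = "a photo of a cat \U0001F431 --gr 0.7"
try:
    with open("prompt.txt", "w", encoding="ascii") as f:  # assumed non-UTF-8 encoding
        f.write(prompt)  # raises UnicodeEncodeError on the emoji
except Exception as e:
    print(f"Error writing prompt file. Prompt contains non-unicode characters. {e}")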
@@ -110,7 +110,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
             return timesteps
         elif timestep_type == 'lognorm_blend':
             # distribute timesteps to the center/early and blend in linear
-            alpha = 0.8
+            alpha = 0.75

             lognormal = LogNormal(loc=0, scale=0.333)
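The hunk only records the change of alpha from 0.8 to 0.75; per the code comment in the diff, the schedule distributes timesteps toward the center/early of the range via a log-normal draw and blends the remainder in linearly. A sketch of that idea; the function below is illustrative, not the scheduler's actual implementation, and only LogNormal(loc=0, scale=0.333) and alpha=0.75 are taken from the diff:

import torch
from torch.distributions import LogNormal

def lognorm_blend_timesteps(n, alpha=0.75, num_train_timesteps=1000):
    # Draw a fraction `alpha` of the timesteps from a log-normal distribution
    # and fill the rest with a linear ramp, then sort high-to-low.
    n_lognorm = int(n * alpha)
    samples = LogNormal(loc=0, scale=0.333).sample((n_lognorm,))
    samples = (samples / samples.max()).clamp(0.0, 1.0)  # squash into [0, 1]
    linear = torch.linspace(0.0, 1.0, n - n_lognorm)      # linear blend portion
    t = torch.cat([samples, linear])
    t, _ = torch.sort(t, descending=True)
    return t * (num_train_timesteps - 1)

timesteps = lognorm_blend_timesteps(40)  # compare how values cluster versus a plain linear schedule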
@@ -696,7 +696,7 @@ class StableDiffusion:
             text_encoder_2.to(self.device_torch, dtype=dtype)
             flush()

-            if self.model_config.quantize:
+            if self.model_config.quantize_te:
                 print("Quantizing T5")
                 quantize(text_encoder_2, weights=qfloat8)
                 freeze(text_encoder_2)
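With the gate switched to quantize_te, T5 quantization follows its own flag rather than the transformer's. A minimal sketch of the resulting control flow, reusing the optimum-quanto calls shown in the diff; the standalone function and its parameters are illustrative, not the class's actual method, and it assumes the base quantize flag continues to gate the transformer as the commit message implies:

from optimum.quanto import freeze, qfloat8, quantize

def maybe_quantize(transformer, text_encoder_2, quantize_transformer, quantize_te):
    # The base `quantize` flag is assumed to still govern the transformer.
    if quantize_transformer:
        quantize(transformer, weights=qfloat8)
        freeze(transformer)
    # T5 is now governed by `quantize_te`, which defaults to `quantize`
    # but can be overridden independently.
    if quantize_te:
        print("Quantizing T5")
        quantize(text_encoder_2, weights=qfloat8)
        freeze(text_encoder_2)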