mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Allow quantizing the text encoder (TE) independently on Flux. Added lognorm_blend timestep schedule.
This commit is contained in:
@@ -346,8 +346,8 @@ class TrainConfig:
|
|||||||
self.standardize_images = kwargs.get('standardize_images', False)
|
self.standardize_images = kwargs.get('standardize_images', False)
|
||||||
self.standardize_latents = kwargs.get('standardize_latents', False)
|
self.standardize_latents = kwargs.get('standardize_latents', False)
|
||||||
|
|
||||||
if self.train_turbo and not self.noise_scheduler.startswith("euler"):
|
# if self.train_turbo and not self.noise_scheduler.startswith("euler"):
|
||||||
raise ValueError(f"train_turbo is only supported with euler and wuler_a noise schedulers")
|
# raise ValueError(f"train_turbo is only supported with euler and wuler_a noise schedulers")
|
||||||
|
|
||||||
self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
|
self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
|
||||||
self.do_cfg = kwargs.get('do_cfg', False)
|
self.do_cfg = kwargs.get('do_cfg', False)
|
||||||
@@ -458,6 +458,7 @@ class ModelConfig:
|
|||||||
|
|
||||||
# only for flux for now
|
# only for flux for now
|
||||||
self.quantize = kwargs.get("quantize", False)
|
self.quantize = kwargs.get("quantize", False)
|
||||||
|
self.quantize_te = kwargs.get("quantize_te", self.quantize)
|
||||||
self.low_vram = kwargs.get("low_vram", False)
|
self.low_vram = kwargs.get("low_vram", False)
|
||||||
self.attn_masking = kwargs.get("attn_masking", False)
|
self.attn_masking = kwargs.get("attn_masking", False)
|
||||||
if self.attn_masking and not self.is_flux:
|
if self.attn_masking and not self.is_flux:
|
||||||
@@ -827,7 +828,10 @@ class GenerateImageConfig:
|
|||||||
prompt += ' --gr ' + str(self.guidance_rescale)
|
prompt += ' --gr ' + str(self.guidance_rescale)
|
||||||
|
|
||||||
# get gen info
|
# get gen info
|
||||||
f.write(self.prompt)
|
try:
|
||||||
|
f.write(self.prompt)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error writing prompt file. Prompt contains non-unicode characters. {e}")
|
||||||
|
|
||||||
def _process_prompt_string(self):
|
def _process_prompt_string(self):
|
||||||
# we will try to support all sd-scripts where we can
|
# we will try to support all sd-scripts where we can
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
|||||||
return timesteps
|
return timesteps
|
||||||
elif timestep_type == 'lognorm_blend':
|
elif timestep_type == 'lognorm_blend':
|
||||||
# distribute timesteps to the center/early and blend in linear
|
# distribute timesteps to the center/early and blend in linear
|
||||||
alpha = 0.8
|
alpha = 0.75
|
||||||
|
|
||||||
lognormal = LogNormal(loc=0, scale=0.333)
|
lognormal = LogNormal(loc=0, scale=0.333)
|
||||||
|
|
||||||
|
|||||||
@@ -696,7 +696,7 @@ class StableDiffusion:
|
|||||||
text_encoder_2.to(self.device_torch, dtype=dtype)
|
text_encoder_2.to(self.device_torch, dtype=dtype)
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
if self.model_config.quantize:
|
if self.model_config.quantize_te:
|
||||||
print("Quantizing T5")
|
print("Quantizing T5")
|
||||||
quantize(text_encoder_2, weights=qfloat8)
|
quantize(text_encoder_2, weights=qfloat8)
|
||||||
freeze(text_encoder_2)
|
freeze(text_encoder_2)
|
||||||
|
|||||||
Reference in New Issue
Block a user