Added support for caching text embeddings. This is just initial support and will probably fail for some models. Still needs to be ompimized

This commit is contained in:
Jaret Burkett
2025-08-07 10:27:55 -06:00
parent 4c4a10d439
commit bb6db3d635
16 changed files with 485 additions and 195 deletions

View File

@@ -482,6 +482,8 @@ class TrainConfig:
# will cache a blank prompt or the trigger word, and unload the text encoder to cpu
# will make training faster and use less vram
self.unload_text_encoder = kwargs.get('unload_text_encoder', False)
# will toggle all datasets to cache text embeddings
self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False)
# for swapping which parameters are trained during training
self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False)
# 0.1 is 10% of the parameters active at a time lower is less vram, higher is more
@@ -1189,6 +1191,7 @@ def validate_configs(
train_config: TrainConfig,
model_config: ModelConfig,
save_config: SaveConfig,
dataset_configs: List[DatasetConfig]
):
if model_config.is_flux:
if save_config.save_format != 'diffusers':
@@ -1200,3 +1203,18 @@ def validate_configs(
if train_config.bypass_guidance_embedding and train_config.do_guidance_loss:
raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. "
"Please set bypass_guidance_embedding to False or do_guidance_loss to False.")
# see if any datasets are caching text embeddings
is_caching_text_embeddings = any(dataset.cache_text_embeddings for dataset in dataset_configs)
if is_caching_text_embeddings:
# check if they are doing differential output preservation
if train_config.diff_output_preservation:
raise ValueError("Cannot use differential output preservation with caching text embeddings. Please set diff_output_preservation to False.")
# make sure they are all cached
for dataset in dataset_configs:
if not dataset.cache_text_embeddings:
raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.")