mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Added support for caching text embeddings. This is just initial support and will probably fail for some models. Still needs to be optimized.
This commit is contained in:
@@ -482,6 +482,8 @@ class TrainConfig:
|
||||
# will cache a blank prompt or the trigger word, and unload the text encoder to cpu
|
||||
# will make training faster and use less vram
|
||||
self.unload_text_encoder = kwargs.get('unload_text_encoder', False)
|
||||
# will toggle all datasets to cache text embeddings
|
||||
self.cache_text_embeddings: bool = kwargs.get('cache_text_embeddings', False)
|
||||
# for swapping which parameters are trained during training
|
||||
self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False)
|
||||
# 0.1 is 10% of the parameters active at a time lower is less vram, higher is more
|
||||
@@ -1189,6 +1191,7 @@ def validate_configs(
|
||||
train_config: TrainConfig,
|
||||
model_config: ModelConfig,
|
||||
save_config: SaveConfig,
|
||||
dataset_configs: List[DatasetConfig]
|
||||
):
|
||||
if model_config.is_flux:
|
||||
if save_config.save_format != 'diffusers':
|
||||
@@ -1200,3 +1203,18 @@ def validate_configs(
|
||||
if train_config.bypass_guidance_embedding and train_config.do_guidance_loss:
|
||||
raise ValueError("Cannot bypass guidance embedding and do guidance loss at the same time. "
|
||||
"Please set bypass_guidance_embedding to False or do_guidance_loss to False.")
|
||||
|
||||
# see if any datasets are caching text embeddings
|
||||
is_caching_text_embeddings = any(dataset.cache_text_embeddings for dataset in dataset_configs)
|
||||
if is_caching_text_embeddings:
|
||||
|
||||
# check if they are doing differential output preservation
|
||||
if train_config.diff_output_preservation:
|
||||
raise ValueError("Cannot use differential output preservation with caching text embeddings. Please set diff_output_preservation to False.")
|
||||
|
||||
# make sure they are all cached
|
||||
for dataset in dataset_configs:
|
||||
if not dataset.cache_text_embeddings:
|
||||
raise ValueError("All datasets must have cache_text_embeddings set to True when caching text embeddings is enabled.")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user