Added support for caching text embeddings. This is initial support and will probably fail for some models; it still needs to be optimized.
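The commit message doesn't spell out the mechanism, but the usual pattern behind text-embedding caching is: encode each caption once, store the resulting tensor keyed by a hash of the prompt, and serve it from disk on later epochs so the text encoder never has to run during training. A minimal sketch of that pattern, with hypothetical names that are not this repo's API:

    # Minimal sketch of text-embedding caching (hypothetical, not this repo's code):
    # encode each caption once, key the result by a hash of the prompt, and reuse
    # the stored tensor on later epochs so the text encoder can be unloaded.
    import hashlib
    import os
    import torch

    class TextEmbeddingCache:
        def __init__(self, cache_dir: str, encode_fn):
            self.cache_dir = cache_dir
            self.encode_fn = encode_fn  # prompt -> torch.Tensor
            os.makedirs(cache_dir, exist_ok=True)

        def _path(self, prompt: str) -> str:
            key = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
            return os.path.join(self.cache_dir, f"{key}.pt")

        def get(self, prompt: str) -> torch.Tensor:
            path = self._path(prompt)
            if os.path.exists(path):
                return torch.load(path)  # cache hit: skip the text encoder
            emb = self.encode_fn(prompt)  # cache miss: encode once, persist
            torch.save(emb, path)
            return emb

The payoff is VRAM: once every caption has been encoded, the text encoder can be dropped entirely, which is what the load-path changes below do.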

Jaret Burkett
2025-08-07 10:27:55 -06:00
parent 4c4a10d439
commit bb6db3d635
16 changed files with 485 additions and 195 deletions


@@ -145,7 +145,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
 raw_datasets = preprocess_dataset_raw_config(raw_datasets)
 self.datasets = None
 self.datasets_reg = None
+self.dataset_configs: List[DatasetConfig] = []
 self.params = []
+# add dataset text embedding cache to their config
+if self.train_config.cache_text_embeddings:
+    for raw_dataset in raw_datasets:
+        raw_dataset['cache_text_embeddings'] = True
 if raw_datasets is not None and len(raw_datasets) > 0:
     for raw_dataset in raw_datasets:
         dataset = DatasetConfig(**raw_dataset)
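The propagation works because each DatasetConfig is built from its raw dict via **raw_dataset, so setting the key on the dict before construction turns the single train-level cache_text_embeddings switch into a per-dataset flag. A simplified stand-in (not the real DatasetConfig) showing the flow:

    # Simplified stand-in for DatasetConfig, to show how the flag lands:
    from dataclasses import dataclass

    @dataclass
    class DatasetConfigSketch:
        folder_path: str = ""
        cache_text_embeddings: bool = False  # new per-dataset flag

    raw = {"folder_path": "/data/images"}
    raw["cache_text_embeddings"] = True   # train-level flag pushed down
    cfg = DatasetConfigSketch(**raw)      # flag lands on the dataset config
    assert cfg.cache_text_embeddings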
@@ -160,6 +167,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if self.datasets is None:
             self.datasets = []
         self.datasets.append(dataset)
+        self.dataset_configs.append(dataset)
+self.is_caching_text_embeddings = any(
+    dataset.cache_text_embeddings for dataset in self.dataset_configs
+)
+# cannot train trigger word if caching text embeddings
+if self.is_caching_text_embeddings and self.trigger_word is not None:
+    raise ValueError("Cannot train trigger word if caching text embeddings. Please remove the trigger word or disable text embedding caching.")
 self.embed_config = None
 embedding_raw = self.get_conf('embedding', None)
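The guard exists because trigger words are substituted into captions at load time; an embedding precomputed from the raw caption would silently omit the substitution. Illustrative only, assuming a [trigger] placeholder convention in captions:

    # Why trigger words and cached embeddings can't mix (illustrative only):
    caption = "[trigger] riding a bike"
    trigger_word = "sks_dog"
    live_caption = caption.replace("[trigger]", trigger_word)  # what training would see
    cached_caption = caption                                   # what the cache encoded
    assert live_caption != cached_caption  # cached embeddings can't reflect the trigger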
@@ -206,7 +222,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
     train_embedding=self.embed_config is not None,
     train_decorator=self.decorator_config is not None,
     train_refiner=self.train_config.train_refiner,
-    unload_text_encoder=self.train_config.unload_text_encoder,
+    unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings,
     require_grads=False  # we ensure them later
 )
@@ -220,7 +236,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
     train_embedding=self.embed_config is not None,
     train_decorator=self.decorator_config is not None,
     train_refiner=self.train_config.train_refiner,
-    unload_text_encoder=self.train_config.unload_text_encoder,
+    unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings,
     require_grads=True  # We check for grads when getting params
 )
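Both load paths (the one above and this one) now force unload_text_encoder whenever any dataset caches its embeddings: the encoder's outputs are precomputed, so keeping it resident would only waste VRAM. A hedged sketch of what unloading typically means, not this repo's actual code:

    # Hypothetical sketch of unloading a text encoder to reclaim VRAM
    # once all embeddings are precomputed:
    import torch

    def unload_text_encoder_weights(model):
        if getattr(model, "text_encoder", None) is not None:
            model.text_encoder.to("cpu")  # move weights off the GPU
            model.text_encoder = None     # drop the reference entirely
            torch.cuda.empty_cache()      # release the freed VRAM to the driver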
@@ -235,7 +251,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
 self.snr_gos: Union[LearnableSNRGamma, None] = None
 self.ema: ExponentialMovingAverage = None
-validate_configs(self.train_config, self.model_config, self.save_config)
+validate_configs(self.train_config, self.model_config, self.save_config, self.dataset_configs)
 do_profiler = self.get_conf('torch_profiler', False)
 self.torch_profiler = None if not do_profiler else torch.profiler.profile(
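Passing self.dataset_configs into validate_configs lets validation cross-check dataset-level flags against the training config before anything is loaded. A hypothetical example of the kind of check this enables (the real function in this commit may check different conditions): cached embeddings are frozen tensors, so they cannot coexist with a trainable text encoder.

    # Hypothetical cross-config check; not necessarily what this commit validates.
    def validate_configs(train_config, model_config, save_config, dataset_configs=None):
        caching = bool(dataset_configs) and any(
            d.cache_text_embeddings for d in dataset_configs
        )
        # cached embeddings are frozen, so the text encoder cannot also be trained
        if caching and getattr(train_config, "train_text_encoder", False):
            raise ValueError(
                "cache_text_embeddings is incompatible with training the text encoder"
            )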