Mirror of https://github.com/ostris/ai-toolkit.git
Added support for caching text embeddings. This is initial support and will probably fail for some models; it still needs to be optimized.
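For intuition, here is a minimal sketch of what prompt-embedding caching involves (the class and function names below are illustrative, not ai-toolkit's actual API): encode each caption once, store the result keyed by a hash of the prompt, and serve later lookups from disk so the text encoder never has to run again.

import hashlib
import os
import torch

class TextEmbeddingCache:
    # Illustrative prompt-embedding cache; not ai-toolkit's real API.
    def __init__(self, cache_dir, encode_fn):
        # encode_fn: maps a caption string to an embedding tensor
        self.cache_dir = cache_dir
        self.encode_fn = encode_fn
        os.makedirs(cache_dir, exist_ok=True)

    def _path(self, caption):
        key = hashlib.sha256(caption.encode('utf-8')).hexdigest()
        return os.path.join(self.cache_dir, key + '.pt')

    def get(self, caption):
        path = self._path(caption)
        if os.path.exists(path):
            return torch.load(path)      # hit: skip the text encoder entirely
        emb = self.encode_fn(caption)    # miss: encode once...
        torch.save(emb.cpu(), path)      # ...and persist for later steps/epochs
        return emb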
@@ -145,7 +145,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
         raw_datasets = preprocess_dataset_raw_config(raw_datasets)
         self.datasets = None
         self.datasets_reg = None
+        self.dataset_configs: List[DatasetConfig] = []
         self.params = []
+
+        # add dataset text embedding cache to their config
+        if self.train_config.cache_text_embeddings:
+            for raw_dataset in raw_datasets:
+                raw_dataset['cache_text_embeddings'] = True
+
         if raw_datasets is not None and len(raw_datasets) > 0:
             for raw_dataset in raw_datasets:
                 dataset = DatasetConfig(**raw_dataset)
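The hunk above fans the single train-level switch out to every dataset dict before the DatasetConfig objects are constructed, so the per-dataset field becomes the source of truth downstream. A small sketch of the effect (the dataset dict keys are assumed):

# Sketch of the flag propagation performed above; dict shapes are assumed.
raw_datasets = [
    {'folder_path': '/data/set_a'},
    {'folder_path': '/data/set_b'},
]
cache_text_embeddings = True  # the train-level flag this commit adds

if cache_text_embeddings:
    for raw_dataset in raw_datasets:
        raw_dataset['cache_text_embeddings'] = True

# every DatasetConfig built from these dicts now carries the flag
assert all(d['cache_text_embeddings'] for d in raw_datasets)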
@@ -160,6 +167,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 if self.datasets is None:
                     self.datasets = []
                 self.datasets.append(dataset)
+                self.dataset_configs.append(dataset)
+
+        self.is_caching_text_embeddings = any(
+            dataset.cache_text_embeddings for dataset in self.dataset_configs
+        )
+
+        # cannot train trigger word if caching text embeddings
+        if self.is_caching_text_embeddings and self.trigger_word is not None:
+            raise ValueError("Cannot train trigger word if caching text embeddings. Please remove the trigger word or disable text embedding caching.")

         self.embed_config = None
         embedding_raw = self.get_conf('embedding', None)
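The guard above exists because trigger-word training rewrites captions at runtime, substituting the trigger token into each prompt, while cached embeddings are computed from the captions as stored. A toy illustration of the mismatch (the [trigger] placeholder follows how ai-toolkit captions typically reference the trigger word; treat the details as illustrative):

# Why caching and trigger words conflict: the substitution happens after
# the cache would already have been filled from the raw caption.
caption = 'a photo of [trigger]'
trigger_word = 'sks_dog'

cached_key = caption                                      # what the cache saw
needed_key = caption.replace('[trigger]', trigger_word)   # what training needs

assert cached_key != needed_key  # the cached embedding would be stale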
@@ -206,7 +222,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_embedding=self.embed_config is not None,
             train_decorator=self.decorator_config is not None,
             train_refiner=self.train_config.train_refiner,
-            unload_text_encoder=self.train_config.unload_text_encoder,
+            unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings,
             require_grads=False # we ensure them later
         )
@@ -220,7 +236,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             train_embedding=self.embed_config is not None,
             train_decorator=self.decorator_config is not None,
             train_refiner=self.train_config.train_refiner,
-            unload_text_encoder=self.train_config.unload_text_encoder,
+            unload_text_encoder=self.train_config.unload_text_encoder or self.is_caching_text_embeddings,
             require_grads=True # We check for grads when getting params
         )
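The two hunks above make the same one-line change in the two model-load paths: once embeddings are cached, the text encoder is dead weight during training, so it is unloaded even if the user did not ask for it. A hedged sketch of what the unload buys, assuming it means moving the module off-GPU:

import torch

def maybe_unload_text_encoder(text_encoder, unload_flag, caching_embeddings):
    # mirrors the diff: caching text embeddings forces the unload path
    if unload_flag or caching_embeddings:
        text_encoder.to('cpu')  # release the GPU memory held by the weights
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # hand freed blocks back to the allocator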
@@ -235,7 +251,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.snr_gos: Union[LearnableSNRGamma, None] = None
         self.ema: ExponentialMovingAverage = None

-        validate_configs(self.train_config, self.model_config, self.save_config)
+        validate_configs(self.train_config, self.model_config, self.save_config, self.dataset_configs)

         do_profiler = self.get_conf('torch_profiler', False)
         self.torch_profiler = None if not do_profiler else torch.profiler.profile(
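Passing the dataset configs into validate_configs lets cross-cutting checks live in one place instead of being scattered through __init__. A hypothetical example of the kind of check this enables (shuffle_tokens is a per-dataset option in ai-toolkit, but this exact check is illustrative, not the library's actual code):

def check_dataset_configs(dataset_configs):
    # a prompt-keyed cache cannot serve captions that are re-shuffled per step
    for d in dataset_configs:
        if d.cache_text_embeddings and getattr(d, 'shuffle_tokens', False):
            raise ValueError(
                'cache_text_embeddings is incompatible with shuffle_tokens'
            )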