Integrate dataset-level trigger words and allow them to be cached. Default to the global trigger word if it is set.

This commit is contained in:
Jaret Burkett
2025-09-18 03:29:18 -06:00
parent 3cdf50cbfc
commit 390e21bec6
2 changed files with 12 additions and 4 deletions

View File

@@ -156,6 +156,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
if raw_datasets is not None and len(raw_datasets) > 0:
for raw_dataset in raw_datasets:
dataset = DatasetConfig(**raw_dataset)
# handle trigger word per dataset
if dataset.trigger_word is None and self.trigger_word is not None:
dataset.trigger_word = self.trigger_word
is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
if not is_caching:
self.is_latents_cached = False

View File

@@ -301,6 +301,7 @@ class CaptionProcessingDTOMixin:
dataset_config: DatasetConfig = kwargs.get('dataset_config', None)
self.extra_values: List[float] = dataset_config.extra_values
self.trigger_word = dataset_config.trigger_word
# todo allow for loading from sd-scripts style dict
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]=None):
@@ -364,6 +365,13 @@ class CaptionProcessingDTOMixin:
add_if_not_present=False,
short_caption=False
):
if trigger is None and self.trigger_word is not None:
trigger = self.trigger_word
if trigger is not None and not self.is_reg:
# add if not present if not regularization
add_if_not_present = True
if short_caption:
raw_caption = self.raw_caption_short
else:
@@ -408,7 +416,7 @@ class CaptionProcessingDTOMixin:
# join back together
caption = ', '.join(token_list)
# caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
if self.dataset_config.random_triggers:
num_triggers = self.dataset_config.random_triggers_max
@@ -1806,9 +1814,6 @@ class TextEmbeddingFileItemDTOMixin:
# TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible.
if self.caption is None:
self.load_caption()
# throw error is [trigger] in caption as we cannot inject it while caching
if '[trigger]' in self.caption:
raise Exception("Error: [trigger] in caption is not supported when caching text embeddings. Please remove it from the caption.")
item = OrderedDict([
("caption", self.caption),
("text_embedding_space_version", self.text_embedding_space_version),