From 390e21bec644cc4cf07c933d77e3a619a485bd02 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 18 Sep 2025 03:29:18 -0600 Subject: [PATCH] Integrate dataset level trigger words and allow them to be cached. Default to global trigger if it is set. --- jobs/process/BaseSDTrainProcess.py | 3 +++ toolkit/dataloader_mixins.py | 13 +++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 20a85e55..e9b3d260 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -156,6 +156,9 @@ class BaseSDTrainProcess(BaseTrainProcess): if raw_datasets is not None and len(raw_datasets) > 0: for raw_dataset in raw_datasets: dataset = DatasetConfig(**raw_dataset) + # handle trigger word per dataset + if dataset.trigger_word is None and self.trigger_word is not None: + dataset.trigger_word = self.trigger_word is_caching = dataset.cache_latents or dataset.cache_latents_to_disk if not is_caching: self.is_latents_cached = False diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 6d231f49..29b2aced 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -301,6 +301,7 @@ class CaptionProcessingDTOMixin: dataset_config: DatasetConfig = kwargs.get('dataset_config', None) self.extra_values: List[float] = dataset_config.extra_values + self.trigger_word = dataset_config.trigger_word # todo allow for loading from sd-scripts style dict def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]=None): @@ -364,6 +365,13 @@ class CaptionProcessingDTOMixin: add_if_not_present=False, short_caption=False ): + if trigger is None and self.trigger_word is not None: + trigger = self.trigger_word + + if trigger is not None and not self.is_reg: + # add if not present if not regularization + add_if_not_present = True + if short_caption: raw_caption = self.raw_caption_short else: @@ -408,7 +416,7 @@ class CaptionProcessingDTOMixin: # join back together caption = ', '.join(token_list) - # caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present) + caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present) if self.dataset_config.random_triggers: num_triggers = self.dataset_config.random_triggers_max @@ -1806,9 +1814,6 @@ class TextEmbeddingFileItemDTOMixin: # TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible. if self.caption is None: self.load_caption() - # throw error is [trigger] in caption as we cannot inject it while caching - if '[trigger]' in self.caption: - raise Exception("Error: [trigger] in caption is not supported when caching text embeddings. Please remove it from the caption.") item = OrderedDict([ ("caption", self.caption), ("text_embedding_space_version", self.text_embedding_space_version),