mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Integrate dataset level trigger words and allow them to be cached. Default to global trigger if it is set.
This commit is contained in:
@@ -156,6 +156,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if raw_datasets is not None and len(raw_datasets) > 0:
|
||||
for raw_dataset in raw_datasets:
|
||||
dataset = DatasetConfig(**raw_dataset)
|
||||
# handle trigger word per dataset
|
||||
if dataset.trigger_word is None and self.trigger_word is not None:
|
||||
dataset.trigger_word = self.trigger_word
|
||||
is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
|
||||
if not is_caching:
|
||||
self.is_latents_cached = False
|
||||
|
||||
@@ -301,6 +301,7 @@ class CaptionProcessingDTOMixin:
|
||||
|
||||
dataset_config: DatasetConfig = kwargs.get('dataset_config', None)
|
||||
self.extra_values: List[float] = dataset_config.extra_values
|
||||
self.trigger_word = dataset_config.trigger_word
|
||||
|
||||
# todo allow for loading from sd-scripts style dict
|
||||
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]=None):
|
||||
@@ -364,6 +365,13 @@ class CaptionProcessingDTOMixin:
|
||||
add_if_not_present=False,
|
||||
short_caption=False
|
||||
):
|
||||
if trigger is None and self.trigger_word is not None:
|
||||
trigger = self.trigger_word
|
||||
|
||||
if trigger is not None and not self.is_reg:
|
||||
# force the trigger to be added when missing, since this is not a regularization image
|
||||
add_if_not_present = True
|
||||
|
||||
if short_caption:
|
||||
raw_caption = self.raw_caption_short
|
||||
else:
|
||||
@@ -408,7 +416,7 @@ class CaptionProcessingDTOMixin:
|
||||
|
||||
# join back together
|
||||
caption = ', '.join(token_list)
|
||||
# caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
|
||||
caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
|
||||
|
||||
if self.dataset_config.random_triggers:
|
||||
num_triggers = self.dataset_config.random_triggers_max
|
||||
@@ -1806,9 +1814,6 @@ class TextEmbeddingFileItemDTOMixin:
|
||||
# TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible.
|
||||
if self.caption is None:
|
||||
self.load_caption()
|
||||
# throw an error if '[trigger]' is in the caption, as we cannot inject it while caching
|
||||
if '[trigger]' in self.caption:
|
||||
raise Exception("Error: [trigger] in caption is not supported when caching text embeddings. Please remove it from the caption.")
|
||||
item = OrderedDict([
|
||||
("caption", self.caption),
|
||||
("text_embedding_space_version", self.text_embedding_space_version),
|
||||
|
||||
Reference in New Issue
Block a user