Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-30 19:21:39 +00:00)
Integrate dataset level trigger words and allow them to be cached. Default to global trigger if it is set.
@@ -156,6 +156,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if raw_datasets is not None and len(raw_datasets) > 0:
             for raw_dataset in raw_datasets:
                 dataset = DatasetConfig(**raw_dataset)
+                # handle trigger word per dataset
+                if dataset.trigger_word is None and self.trigger_word is not None:
+                    dataset.trigger_word = self.trigger_word
                 is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
                 if not is_caching:
                     self.is_latents_cached = False
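For context, a sketch of the kind of raw dataset entries this loop turns into DatasetConfig objects. The field names mirror those referenced in the hunk (trigger_word, cache_latents, cache_latents_to_disk) plus the usual folder_path; the paths and values are made up for illustration, not taken from the repo's example configs.

# Hypothetical dataset list, for illustration only.
raw_datasets = [
    {
        "folder_path": "/path/to/character_images",
        "trigger_word": "ohwx_person",        # dataset-level trigger, kept as-is
        "cache_latents_to_disk": True,
    },
    {
        "folder_path": "/path/to/style_images",
        # no trigger_word here: after this change it inherits the process-wide
        # self.trigger_word, when one is configured
        "cache_latents": True,
    },
]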
@@ -301,6 +301,7 @@ class CaptionProcessingDTOMixin:
 
         dataset_config: DatasetConfig = kwargs.get('dataset_config', None)
         self.extra_values: List[float] = dataset_config.extra_values
+        self.trigger_word = dataset_config.trigger_word
 
     # todo allow for loading from sd-scripts style dict
     def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]=None):
@@ -364,6 +365,13 @@ class CaptionProcessingDTOMixin:
             add_if_not_present=False,
             short_caption=False
     ):
+        if trigger is None and self.trigger_word is not None:
+            trigger = self.trigger_word
+
+        if trigger is not None and not self.is_reg:
+            # add if not present if not regularization
+            add_if_not_present = True
+
         if short_caption:
             raw_caption = self.raw_caption_short
         else:
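Taken together with the BaseSDTrainProcess hunk, the precedence is: an explicit trigger argument, then the dataset-level trigger_word, then the global trigger. A condensed, hypothetical helper (not in the repo) just to spell out that order:

def resolve_trigger(explicit_trigger, dataset_trigger, global_trigger):
    # Hypothetical helper, for illustration only: it condenses the two
    # fallbacks this commit adds in separate places.
    if dataset_trigger is None:
        dataset_trigger = global_trigger      # BaseSDTrainProcess hunk above
    if explicit_trigger is None:
        explicit_trigger = dataset_trigger    # this CaptionProcessingDTOMixin hunk
    return explicit_trigger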
@@ -408,7 +416,7 @@ class CaptionProcessingDTOMixin:
 
         # join back together
         caption = ', '.join(token_list)
-        # caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
+        caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
 
         if self.dataset_config.random_triggers:
             num_triggers = self.dataset_config.random_triggers_max
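The change above re-enables the inject_trigger_into_prompt call so the resolved trigger actually lands in the caption. A rough, hypothetical stand-in for what that call is relied on to do here (the real implementation lives in the toolkit and may differ), assuming "[trigger]" is the placeholder token, as the TextEmbeddingFileItemDTOMixin hunk below suggests:

def inject_trigger_into_prompt_sketch(prompt, trigger=None, to_replace_list=None, add_if_not_present=False):
    # Illustrative sketch only, NOT the repo's inject_trigger_into_prompt:
    # swap the "[trigger]" placeholder (and any aliases in to_replace_list)
    # for the trigger word, optionally prepending it when absent.
    if trigger is None:
        return prompt
    prompt = prompt.replace('[trigger]', trigger)
    for alias in (to_replace_list or []):
        prompt = prompt.replace(alias, trigger)
    if add_if_not_present and trigger not in prompt:
        prompt = f"{trigger}, {prompt}"
    return prompt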
@@ -1806,9 +1814,6 @@ class TextEmbeddingFileItemDTOMixin:
         # TODO: we need a way to cache all the other features like trigger words, DOP, etc. For now, we need to throw an error if not compatible.
         if self.caption is None:
             self.load_caption()
-        # throw error is [trigger] in caption as we cannot inject it while caching
-        if '[trigger]' in self.caption:
-            raise Exception("Error: [trigger] in caption is not supported when caching text embeddings. Please remove it from the caption.")
         item = OrderedDict([
             ("caption", self.caption),
             ("text_embedding_space_version", self.text_embedding_space_version),
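Removing the guard matches the commit message's "allow them to be cached": with the injection re-enabled in the caption path, a caption presumably reaches this point with the trigger already substituted. A hypothetical illustration using the sketch from earlier (the caption and trigger values are made up):

caption = "[trigger], a photo on a beach"
cached_caption = inject_trigger_into_prompt_sketch(caption, trigger="ohwx_person", add_if_not_present=True)
# -> "ohwx_person, a photo on a beach": the placeholder is gone before the
#    caption is stored in the OrderedDict cache item.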