diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 3d00735..5fa7937 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -460,6 +460,7 @@ class DatasetConfig:
         self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
         self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
         self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
+        self.keep_tokens: int = kwargs.get('keep_tokens', 0)  # number of first tokens to always keep unless caption dropped
         self.flip_x: bool = kwargs.get('flip_x', False)
         self.flip_y: bool = kwargs.get('flip_y', False)
         self.augments: List[str] = kwargs.get('augments', [])
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 9a32814..20ffc3a 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -336,20 +336,27 @@ class CaptionProcessingDTOMixin:
         # remove empty strings
         token_list = [x for x in token_list if x]
 
-        if self.dataset_config.shuffle_tokens:
-            random.shuffle(token_list)
-
         # handle token dropout
         if self.dataset_config.token_dropout_rate > 0 and not short_caption:
             new_token_list = []
-            for token in token_list:
-                # get a random float form 0 to 1
-                rand = random.random()
-                if rand > self.dataset_config.token_dropout_rate:
-                    # keep the token
+            keep_tokens: int = self.dataset_config.keep_tokens
+            for idx, token in enumerate(token_list):
+                if idx < keep_tokens:
                     new_token_list.append(token)
+                elif self.dataset_config.token_dropout_rate >= 1.0:
+                    # drop the token
+                    pass
+                else:
+                    # get a random float from 0 to 1
+                    rand = random.random()
+                    if rand > self.dataset_config.token_dropout_rate:
+                        # keep the token
+                        new_token_list.append(token)
             token_list = new_token_list
 
+        if self.dataset_config.shuffle_tokens:
+            random.shuffle(token_list)
+
         # join back together
         caption = ', '.join(token_list)
         caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)