Added keep_tokens option to always keep the first N tokens of a prompt when applying token dropout

This commit is contained in:
Jaret Burkett
2024-03-18 13:18:25 -06:00
parent 89f4bcad2e
commit 9c1cc9641e
2 changed files with 16 additions and 8 deletions

View File

@@ -460,6 +460,7 @@ class DatasetConfig:
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
self.keep_tokens: int = kwargs.get('keep_tokens', 0) # #of first tokens to always keep unless caption dropped
self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_y', False)
self.augments: List[str] = kwargs.get('augments', [])

View File

@@ -336,20 +336,27 @@ class CaptionProcessingDTOMixin:
# remove empty strings
token_list = [x for x in token_list if x]
if self.dataset_config.shuffle_tokens:
random.shuffle(token_list)
# handle token dropout
if self.dataset_config.token_dropout_rate > 0 and not short_caption:
new_token_list = []
for token in token_list:
# get a random float from 0 to 1
rand = random.random()
if rand > self.dataset_config.token_dropout_rate:
# keep the token
keep_tokens: int = self.dataset_config.keep_tokens
for idx, token in enumerate(token_list):
if idx < keep_tokens:
new_token_list.append(token)
elif self.dataset_config.token_dropout_rate >= 1.0:
# drop the token
pass
else:
# get a random float from 0 to 1
rand = random.random()
if rand > self.dataset_config.token_dropout_rate:
# keep the token
new_token_list.append(token)
token_list = new_token_list
if self.dataset_config.shuffle_tokens:
random.shuffle(token_list)
# join back together
caption = ', '.join(token_list)
caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)