mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-05 04:59:56 +00:00
Added keep tokens to keep so many tokens in a prompt when dropping
This commit is contained in:
@@ -460,6 +460,7 @@ class DatasetConfig:
|
||||
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
|
||||
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
|
||||
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
|
||||
self.keep_tokens: int = kwargs.get('keep_tokens', 0) # #of first tokens to always keep unless caption dropped
|
||||
self.flip_x: bool = kwargs.get('flip_x', False)
|
||||
self.flip_y: bool = kwargs.get('flip_y', False)
|
||||
self.augments: List[str] = kwargs.get('augments', [])
|
||||
|
||||
@@ -336,20 +336,27 @@ class CaptionProcessingDTOMixin:
|
||||
# remove empty strings
|
||||
token_list = [x for x in token_list if x]
|
||||
|
||||
if self.dataset_config.shuffle_tokens:
|
||||
random.shuffle(token_list)
|
||||
|
||||
# handle token dropout
|
||||
if self.dataset_config.token_dropout_rate > 0 and not short_caption:
|
||||
new_token_list = []
|
||||
for token in token_list:
|
||||
# get a random float form 0 to 1
|
||||
rand = random.random()
|
||||
if rand > self.dataset_config.token_dropout_rate:
|
||||
# keep the token
|
||||
keep_tokens: int = self.dataset_config.keep_tokens
|
||||
for idx, token in enumerate(token_list):
|
||||
if idx < keep_tokens:
|
||||
new_token_list.append(token)
|
||||
elif self.dataset_config.token_dropout_rate >= 1.0:
|
||||
# drop the token
|
||||
pass
|
||||
else:
|
||||
# get a random float form 0 to 1
|
||||
rand = random.random()
|
||||
if rand > self.dataset_config.token_dropout_rate:
|
||||
# keep the token
|
||||
new_token_list.append(token)
|
||||
token_list = new_token_list
|
||||
|
||||
if self.dataset_config.shuffle_tokens:
|
||||
random.shuffle(token_list)
|
||||
|
||||
# join back together
|
||||
caption = ', '.join(token_list)
|
||||
caption = inject_trigger_into_prompt(caption, trigger, to_replace_list, add_if_not_present)
|
||||
|
||||
Reference in New Issue
Block a user