Remove dropout from cached text embeddings even if the user specifies it, so blank prompts are not cached.

Author: Jaret Burkett
Date: 2025-09-26 11:50:53 -06:00
parent e04f55c553
commit be990630b9


@@ -344,6 +344,9 @@ class CaptionProcessingDTOMixin:
prompt = clean_caption(prompt)
if short_caption is not None:
short_caption = clean_caption(short_caption)
if prompt.strip() == '' and self.dataset_config.default_caption is not None:
prompt = self.dataset_config.default_caption
else:
prompt = ''
if self.dataset_config.default_caption is not None:
@@ -379,7 +382,7 @@ class CaptionProcessingDTOMixin:
         if raw_caption is None:
             raw_caption = ''
         # handle dropout
-        if self.dataset_config.caption_dropout_rate > 0 and not short_caption:
+        if self.dataset_config.caption_dropout_rate > 0 and not short_caption and not self.dataset_config.cache_text_embeddings:
             # get a random float form 0 to 1
             rand = random.random()
             if rand < self.dataset_config.caption_dropout_rate:
@@ -394,7 +397,7 @@ class CaptionProcessingDTOMixin:
         token_list = [x for x in token_list if x]
         # handle token dropout
-        if self.dataset_config.token_dropout_rate > 0 and not short_caption:
+        if self.dataset_config.token_dropout_rate > 0 and not short_caption and not self.dataset_config.cache_text_embeddings:
             new_token_list = []
             keep_tokens: int = self.dataset_config.keep_tokens
             for idx, token in enumerate(token_list):
@@ -441,7 +444,8 @@ class CaptionProcessingDTOMixin:
token_list = [x for x in token_list if x]
random.shuffle(token_list)
caption = ', '.join(token_list)
if caption == '':
pass
return caption
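
Taken together, the hunks above gate both caption-level and token-level dropout behind `cache_text_embeddings`. The sketch below illustrates the resulting behavior in isolation; the `apply_caption_dropout` wrapper, the `is_short_caption` flag, and the `SimpleNamespace` config are illustrative stand-ins rather than the mixin's actual structure, while the `dataset_config` fields are the ones referenced in the diff.

```python
import random
from types import SimpleNamespace

def apply_caption_dropout(caption: str, dataset_config, is_short_caption: bool = False) -> str:
    """Sketch of the dropout path after this commit: both caption-level and
    token-level dropout are skipped when text embeddings are cached, so a
    blank prompt can never end up in the embedding cache."""
    # Caption dropout: occasionally blank the whole caption,
    # but never when embeddings are cached.
    if (dataset_config.caption_dropout_rate > 0
            and not is_short_caption
            and not dataset_config.cache_text_embeddings):
        if random.random() < dataset_config.caption_dropout_rate:
            return ''

    # Token dropout: randomly drop comma-separated tokens past keep_tokens,
    # again only when embeddings are not cached.
    token_list = [t.strip() for t in caption.split(',') if t.strip()]
    if (dataset_config.token_dropout_rate > 0
            and not is_short_caption
            and not dataset_config.cache_text_embeddings):
        token_list = [
            t for idx, t in enumerate(token_list)
            if idx < dataset_config.keep_tokens
            or random.random() >= dataset_config.token_dropout_rate
        ]

    return ', '.join(token_list)

# With caching enabled, the caption always comes back unchanged.
config = SimpleNamespace(
    caption_dropout_rate=0.5,
    token_dropout_rate=0.5,
    keep_tokens=1,
    cache_text_embeddings=True,
)
print(apply_caption_dropout('a photo of a dog, park, sunny day', config))
```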