mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Remove dropout from cached text embeddings even if the user specifies it, so blank prompts are not cached.
This commit is contained in:
@@ -344,6 +344,9 @@ class CaptionProcessingDTOMixin:
|
||||
prompt = clean_caption(prompt)
|
||||
if short_caption is not None:
|
||||
short_caption = clean_caption(short_caption)
|
||||
|
||||
if prompt.strip() == '' and self.dataset_config.default_caption is not None:
|
||||
prompt = self.dataset_config.default_caption
|
||||
else:
|
||||
prompt = ''
|
||||
if self.dataset_config.default_caption is not None:
|
||||
@@ -379,7 +382,7 @@ class CaptionProcessingDTOMixin:
|
||||
if raw_caption is None:
|
||||
raw_caption = ''
|
||||
# handle dropout
|
||||
if self.dataset_config.caption_dropout_rate > 0 and not short_caption:
|
||||
if self.dataset_config.caption_dropout_rate > 0 and not short_caption and not self.dataset_config.cache_text_embeddings:
|
||||
# get a random float form 0 to 1
|
||||
rand = random.random()
|
||||
if rand < self.dataset_config.caption_dropout_rate:
|
||||
@@ -394,7 +397,7 @@ class CaptionProcessingDTOMixin:
|
||||
token_list = [x for x in token_list if x]
|
||||
|
||||
# handle token dropout
|
||||
if self.dataset_config.token_dropout_rate > 0 and not short_caption:
|
||||
if self.dataset_config.token_dropout_rate > 0 and not short_caption and not self.dataset_config.cache_text_embeddings:
|
||||
new_token_list = []
|
||||
keep_tokens: int = self.dataset_config.keep_tokens
|
||||
for idx, token in enumerate(token_list):
|
||||
@@ -441,7 +444,8 @@ class CaptionProcessingDTOMixin:
|
||||
token_list = [x for x in token_list if x]
|
||||
random.shuffle(token_list)
|
||||
caption = ', '.join(token_list)
|
||||
|
||||
if caption == '':
|
||||
pass
|
||||
return caption
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user