Mirror of https://github.com/ostris/ai-toolkit.git
Remove dropout from cached text embeddings even if the user specifies it, so blank prompts are not cached.
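
In short: when cache_text_embeddings is enabled, each caption's embedding is computed once and reused for the rest of the run, so a caption that dropout happened to blank at cache time would stay blank for every epoch. A minimal sketch of the gate this commit adds, assuming only the field names visible in the diff below (the helper itself is a simplified stand-in, not ai-toolkit's API):

    import random

    def maybe_drop_caption(prompt: str, caption_dropout_rate: float,
                           cache_text_embeddings: bool) -> str:
        # Dropout is only safe when embeddings are recomputed every step;
        # with a cache, a dropped (blank) prompt would be stored permanently.
        if caption_dropout_rate > 0 and not cache_text_embeddings:
            if random.random() < caption_dropout_rate:
                return ''
        return prompt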
@@ -344,6 +344,9 @@ class CaptionProcessingDTOMixin:
                     prompt = clean_caption(prompt)
                 if short_caption is not None:
                     short_caption = clean_caption(short_caption)
+
+                if prompt.strip() == '' and self.dataset_config.default_caption is not None:
+                    prompt = self.dataset_config.default_caption
             else:
                 prompt = ''
                 if self.dataset_config.default_caption is not None:
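
The hunk above also adds a fallback so that a blank caption picks up the dataset's default_caption. A stand-alone sketch of that behavior, assuming only the default_caption field from the diff (the helper is hypothetical):

    def resolve_prompt(prompt: str, default_caption: str | None) -> str:
        # Fall back to the dataset-level default when the file's caption is blank.
        if prompt.strip() == '' and default_caption is not None:
            return default_caption
        return prompt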
@@ -379,7 +382,7 @@ class CaptionProcessingDTOMixin:
         if raw_caption is None:
             raw_caption = ''
         # handle dropout
-        if self.dataset_config.caption_dropout_rate > 0 and not short_caption:
+        if self.dataset_config.caption_dropout_rate > 0 and not short_caption and not self.dataset_config.cache_text_embeddings:
             # get a random float form 0 to 1
             rand = random.random()
             if rand < self.dataset_config.caption_dropout_rate:
@@ -394,7 +397,7 @@ class CaptionProcessingDTOMixin:
         token_list = [x for x in token_list if x]

         # handle token dropout
-        if self.dataset_config.token_dropout_rate > 0 and not short_caption:
+        if self.dataset_config.token_dropout_rate > 0 and not short_caption and not self.dataset_config.cache_text_embeddings:
             new_token_list = []
             keep_tokens: int = self.dataset_config.keep_tokens
             for idx, token in enumerate(token_list):
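
Token dropout gets the same cache_text_embeddings guard. A sketch of what the gated loop plausibly does, assuming keep_tokens protects the first N tokens (that reading is an assumption; only the names come from the diff):

    import random

    def maybe_drop_tokens(token_list: list[str], token_dropout_rate: float,
                          keep_tokens: int, cache_text_embeddings: bool) -> list[str]:
        # Same reasoning as caption dropout: never randomize cached embeddings.
        if token_dropout_rate <= 0 or cache_text_embeddings:
            return token_list
        new_token_list = []
        for idx, token in enumerate(token_list):
            # assumption: the first `keep_tokens` tokens are always kept
            if idx < keep_tokens or random.random() >= token_dropout_rate:
                new_token_list.append(token)
        return new_token_list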
@@ -441,7 +444,8 @@ class CaptionProcessingDTOMixin:
         token_list = [x for x in token_list if x]
         random.shuffle(token_list)
         caption = ', '.join(token_list)
-
+        if caption == '':
+            pass
         return caption