Allow full control of caption extensions

This commit is contained in:
Jaret Burkett
2025-04-10 07:42:04 -06:00
parent 96ba2fd129
commit d8bdc03256
2 changed files with 22 additions and 20 deletions

View File

@@ -714,7 +714,10 @@ class DatasetConfig:
random_triggers = [line for line in random_triggers if line.strip() != '']
self.random_triggers: List[str] = random_triggers
self.random_triggers_max: int = kwargs.get('random_triggers_max', 1)
self.caption_ext: str = kwargs.get('caption_ext', None)
self.caption_ext: str = kwargs.get('caption_ext', '.txt')
# if caption_ext doesn't start with a dot, add it
if self.caption_ext and not self.caption_ext.startswith('.'):
self.caption_ext = '.' + self.caption_ext
self.random_scale: bool = kwargs.get('random_scale', False)
self.random_crop: bool = kwargs.get('random_crop', False)
self.resolution: int = kwargs.get('resolution', 512)

View File

@@ -62,7 +62,6 @@ transforms_dict = {
'RandomEqualize': transforms.RandomEqualize(p=0.2),
}
# Caption-file extensions, stored WITHOUT a leading dot; the lookup loops visible
# in this view build each candidate path as path_no_ext + '.' + ext.
caption_ext_list = ['txt', 'json', 'caption']
# File extensions stored WITH the leading dot (note the opposite convention from
# caption_ext_list). Presumably used to recognize image files; usage is not
# visible in this view.
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
@@ -91,15 +90,16 @@ def standardize_images(images):
return standardized_images
def clean_caption(caption):
# Normalize a caption string and return it.
# NOTE(review): this view looks like a rendered diff with indentation and +/-
# markers stripped; the live statements below appear to be the pre-change body
# and the commented-out copy after them the post-change body. Confirm against
# the actual file before relying on either reading.
# replace newlines with a comma separator
caption = caption.replace('\n', ', ')
# replace carriage returns as well, so line endings from all operating systems are covered
caption = caption.replace('\r', ', ')
caption_split = caption.split(',')
# strip whitespace around each piece and drop pieces that end up empty
caption_split = [p.strip() for p in caption_split if p.strip()]
# join back together
caption = ', '.join(caption_split)
# this doesn't make any sense anymore in a world that is not based on comma-separated tokens
# # remove any newlines
# caption = caption.replace('\n', ', ')
# # remove new lines for all operating systems
# caption = caption.replace('\r', ', ')
# caption_split = caption.split(',')
# # remove empty strings
# caption_split = [p.strip() for p in caption_split if p.strip()]
# # join back together
# caption = ', '.join(caption_split)
return caption
@@ -115,22 +115,17 @@ class CaptionMixin:
# check if either has a prompt file
path_no_ext = os.path.splitext(img_path)[0]
prompt_path = None
for ext in caption_ext_list:
prompt_path = path_no_ext + '.' + ext
if os.path.exists(prompt_path):
break
ext = self.dataset_config.caption_ext
prompt_path = path_no_ext + ext
else:
img_path = img_path_or_tuple if isinstance(img_path_or_tuple, str) else img_path_or_tuple.path
# see if prompt file exists
path_no_ext = os.path.splitext(img_path)[0]
prompt_path = None
for ext in caption_ext_list:
prompt_path = path_no_ext + '.' + ext
if os.path.exists(prompt_path):
break
prompt_path = path_no_ext + ext
# allow folders to have a default prompt
default_prompt_path = os.path.join(os.path.dirname(img_path), 'default.txt')
default_prompt_path_with_ext = os.path.join(os.path.dirname(img_path), 'default' + ext)
if os.path.exists(prompt_path):
with open(prompt_path, 'r', encoding='utf-8') as f:
@@ -141,6 +136,10 @@ class CaptionMixin:
if 'caption' in prompt:
prompt = prompt['caption']
prompt = clean_caption(prompt)
elif os.path.exists(default_prompt_path_with_ext):
with open(default_prompt_path, 'r', encoding='utf-8') as f:
prompt = f.read()
prompt = clean_caption(prompt)
elif os.path.exists(default_prompt_path):
with open(default_prompt_path, 'r', encoding='utf-8') as f: