diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 56b008d2..d8c6c6bb 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -714,7 +714,10 @@ class DatasetConfig: random_triggers = [line for line in random_triggers if line.strip() != ''] self.random_triggers: List[str] = random_triggers self.random_triggers_max: int = kwargs.get('random_triggers_max', 1) - self.caption_ext: str = kwargs.get('caption_ext', None) + self.caption_ext: str = kwargs.get('caption_ext', '.txt') + # if caption_ext doesnt start with a dot, add it + if self.caption_ext and not self.caption_ext.startswith('.'): + self.caption_ext = '.' + self.caption_ext self.random_scale: bool = kwargs.get('random_scale', False) self.random_crop: bool = kwargs.get('random_crop', False) self.resolution: int = kwargs.get('resolution', 512) diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 79131c7b..395747e6 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -62,7 +62,6 @@ transforms_dict = { 'RandomEqualize': transforms.RandomEqualize(p=0.2), } -caption_ext_list = ['txt', 'json', 'caption'] img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] @@ -91,15 +90,16 @@ def standardize_images(images): return standardized_images def clean_caption(caption): - # remove any newlines - caption = caption.replace('\n', ', ') - # remove new lines for all operating systems - caption = caption.replace('\r', ', ') - caption_split = caption.split(',') - # remove empty strings - caption_split = [p.strip() for p in caption_split if p.strip()] - # join back together - caption = ', '.join(caption_split) + # this doesnt make any sense anymore in a world that is not based on comma seperated tokens + # # remove any newlines + # caption = caption.replace('\n', ', ') + # # remove new lines for all operating systems + # caption = caption.replace('\r', ', ') + # caption_split = caption.split(',') + # # remove empty strings + # caption_split = [p.strip() for p in caption_split if p.strip()] + # # join back together + # caption = ', '.join(caption_split) return caption @@ -115,22 +115,17 @@ class CaptionMixin: # check if either has a prompt file path_no_ext = os.path.splitext(img_path)[0] prompt_path = None - for ext in caption_ext_list: - prompt_path = path_no_ext + '.' + ext - if os.path.exists(prompt_path): - break + ext = self.dataset_config.caption_ext + prompt_path = path_no_ext + ext else: img_path = img_path_or_tuple if isinstance(img_path_or_tuple, str) else img_path_or_tuple.path # see if prompt file exists path_no_ext = os.path.splitext(img_path)[0] - prompt_path = None - for ext in caption_ext_list: - prompt_path = path_no_ext + '.' + ext - if os.path.exists(prompt_path): - break + prompt_path = path_no_ext + ext # allow folders to have a default prompt default_prompt_path = os.path.join(os.path.dirname(img_path), 'default.txt') + default_prompt_path_with_ext = os.path.join(os.path.dirname(img_path), 'default' + ext) if os.path.exists(prompt_path): with open(prompt_path, 'r', encoding='utf-8') as f: @@ -141,6 +136,10 @@ class CaptionMixin: if 'caption' in prompt: prompt = prompt['caption'] + prompt = clean_caption(prompt) + elif os.path.exists(default_prompt_path_with_ext): + with open(default_prompt_path, 'r', encoding='utf-8') as f: + prompt = f.read() prompt = clean_caption(prompt) elif os.path.exists(default_prompt_path): with open(default_prompt_path, 'r', encoding='utf-8') as f: