mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Allow full control of caption extensions
This commit is contained in:
@@ -714,7 +714,10 @@ class DatasetConfig:
|
||||
random_triggers = [line for line in random_triggers if line.strip() != '']
|
||||
self.random_triggers: List[str] = random_triggers
|
||||
self.random_triggers_max: int = kwargs.get('random_triggers_max', 1)
|
||||
self.caption_ext: str = kwargs.get('caption_ext', None)
|
||||
self.caption_ext: str = kwargs.get('caption_ext', '.txt')
|
||||
# if caption_ext doesn't start with a dot, add it
|
||||
if self.caption_ext and not self.caption_ext.startswith('.'):
|
||||
self.caption_ext = '.' + self.caption_ext
|
||||
self.random_scale: bool = kwargs.get('random_scale', False)
|
||||
self.random_crop: bool = kwargs.get('random_crop', False)
|
||||
self.resolution: int = kwargs.get('resolution', 512)
|
||||
|
||||
@@ -62,7 +62,6 @@ transforms_dict = {
|
||||
'RandomEqualize': transforms.RandomEqualize(p=0.2),
|
||||
}
|
||||
|
||||
caption_ext_list = ['txt', 'json', 'caption']
|
||||
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
|
||||
|
||||
|
||||
@@ -91,15 +90,16 @@ def standardize_images(images):
|
||||
return standardized_images
|
||||
|
||||
def clean_caption(caption):
|
||||
# remove any newlines
|
||||
caption = caption.replace('\n', ', ')
|
||||
# remove new lines for all operating systems
|
||||
caption = caption.replace('\r', ', ')
|
||||
caption_split = caption.split(',')
|
||||
# remove empty strings
|
||||
caption_split = [p.strip() for p in caption_split if p.strip()]
|
||||
# join back together
|
||||
caption = ', '.join(caption_split)
|
||||
# this doesn't make any sense anymore in a world that is not based on comma-separated tokens
|
||||
# # remove any newlines
|
||||
# caption = caption.replace('\n', ', ')
|
||||
# # remove new lines for all operating systems
|
||||
# caption = caption.replace('\r', ', ')
|
||||
# caption_split = caption.split(',')
|
||||
# # remove empty strings
|
||||
# caption_split = [p.strip() for p in caption_split if p.strip()]
|
||||
# # join back together
|
||||
# caption = ', '.join(caption_split)
|
||||
return caption
|
||||
|
||||
|
||||
@@ -115,22 +115,17 @@ class CaptionMixin:
|
||||
# check if either has a prompt file
|
||||
path_no_ext = os.path.splitext(img_path)[0]
|
||||
prompt_path = None
|
||||
for ext in caption_ext_list:
|
||||
prompt_path = path_no_ext + '.' + ext
|
||||
if os.path.exists(prompt_path):
|
||||
break
|
||||
ext = self.dataset_config.caption_ext
|
||||
prompt_path = path_no_ext + ext
|
||||
else:
|
||||
img_path = img_path_or_tuple if isinstance(img_path_or_tuple, str) else img_path_or_tuple.path
|
||||
# see if prompt file exists
|
||||
path_no_ext = os.path.splitext(img_path)[0]
|
||||
prompt_path = None
|
||||
for ext in caption_ext_list:
|
||||
prompt_path = path_no_ext + '.' + ext
|
||||
if os.path.exists(prompt_path):
|
||||
break
|
||||
prompt_path = path_no_ext + ext
|
||||
|
||||
# allow folders to have a default prompt
|
||||
default_prompt_path = os.path.join(os.path.dirname(img_path), 'default.txt')
|
||||
default_prompt_path_with_ext = os.path.join(os.path.dirname(img_path), 'default' + ext)
|
||||
|
||||
if os.path.exists(prompt_path):
|
||||
with open(prompt_path, 'r', encoding='utf-8') as f:
|
||||
@@ -141,6 +136,10 @@ class CaptionMixin:
|
||||
if 'caption' in prompt:
|
||||
prompt = prompt['caption']
|
||||
|
||||
prompt = clean_caption(prompt)
|
||||
elif os.path.exists(default_prompt_path_with_ext):
|
||||
with open(default_prompt_path, 'r', encoding='utf-8') as f:
|
||||
prompt = f.read()
|
||||
prompt = clean_caption(prompt)
|
||||
elif os.path.exists(default_prompt_path):
|
||||
with open(default_prompt_path, 'r', encoding='utf-8') as f:
|
||||
|
||||
Reference in New Issue
Block a user