mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-22 13:23:56 +00:00
Allow loading captions from a JSON detail file
This commit is contained in:
@@ -33,7 +33,7 @@ transforms_dict = {
|
||||
'RandomEqualize': transforms.RandomEqualize(p=0.2),
|
||||
}
|
||||
|
||||
|
||||
# Extensions tried, in this order, when looking for a caption file that shares
# an image's base name; the lookup loops stop at the first path that exists.
caption_ext_list = ['txt', 'json', 'caption']
|
||||
class CaptionMixin:
|
||||
def get_caption_item(self: 'AiToolkitDataset', index):
|
||||
if not hasattr(self, 'caption_type'):
|
||||
@@ -45,20 +45,30 @@ class CaptionMixin:
|
||||
img_path = img_path_or_tuple[0] if isinstance(img_path_or_tuple[0], str) else img_path_or_tuple[0].path
|
||||
# check if either has a prompt file
|
||||
path_no_ext = os.path.splitext(img_path)[0]
|
||||
prompt_path = path_no_ext + '.txt'
|
||||
if not os.path.exists(prompt_path):
|
||||
img_path = img_path_or_tuple[1] if isinstance(img_path_or_tuple[1], str) else img_path_or_tuple[1].path
|
||||
path_no_ext = os.path.splitext(img_path)[0]
|
||||
prompt_path = path_no_ext + '.txt'
|
||||
prompt_path = None
|
||||
for ext in caption_ext_list:
|
||||
prompt_path = path_no_ext + '.' + ext
|
||||
if os.path.exists(prompt_path):
|
||||
break
|
||||
else:
|
||||
img_path = img_path_or_tuple if isinstance(img_path_or_tuple, str) else img_path_or_tuple.path
|
||||
# see if prompt file exists
|
||||
path_no_ext = os.path.splitext(img_path)[0]
|
||||
prompt_path = path_no_ext + '.txt'
|
||||
prompt_path = None
|
||||
for ext in caption_ext_list:
|
||||
prompt_path = path_no_ext + '.' + ext
|
||||
if os.path.exists(prompt_path):
|
||||
break
|
||||
|
||||
if os.path.exists(prompt_path):
|
||||
with open(prompt_path, 'r', encoding='utf-8') as f:
|
||||
prompt = f.read()
|
||||
# check if is json
|
||||
if prompt_path.endswith('.json'):
|
||||
prompt = json.loads(prompt)
|
||||
if 'caption' in prompt:
|
||||
prompt = prompt['caption']
|
||||
|
||||
# remove any newlines
|
||||
prompt = prompt.replace('\n', ', ')
|
||||
# remove new lines for all operating systems
|
||||
@@ -173,6 +183,10 @@ class CaptionProcessingDTOMixin:
|
||||
if os.path.exists(prompt_path):
|
||||
with open(prompt_path, 'r', encoding='utf-8') as f:
|
||||
prompt = f.read()
|
||||
if prompt_path.endswith('.json'):
|
||||
prompt = json.loads(prompt)
|
||||
if 'caption' in prompt:
|
||||
prompt = prompt['caption']
|
||||
# remove any newlines
|
||||
prompt = prompt.replace('\n', ', ')
|
||||
# remove new lines for all operating systems
|
||||
|
||||
Reference in New Issue
Block a user