From 320e109c5fd64402a7f966457a193f072ae8f605 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Mon, 2 Oct 2023 13:25:09 -0600 Subject: [PATCH] Allow loading from a json detail file for captions --- toolkit/dataloader_mixins.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 7742d455..4c9c902f 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -33,7 +33,7 @@ transforms_dict = { 'RandomEqualize': transforms.RandomEqualize(p=0.2), } - +caption_ext_list = ['txt', 'json', 'caption'] class CaptionMixin: def get_caption_item(self: 'AiToolkitDataset', index): if not hasattr(self, 'caption_type'): @@ -45,20 +45,30 @@ class CaptionMixin: img_path = img_path_or_tuple[0] if isinstance(img_path_or_tuple[0], str) else img_path_or_tuple[0].path # check if either has a prompt file path_no_ext = os.path.splitext(img_path)[0] - prompt_path = path_no_ext + '.txt' - if not os.path.exists(prompt_path): - img_path = img_path_or_tuple[1] if isinstance(img_path_or_tuple[1], str) else img_path_or_tuple[1].path - path_no_ext = os.path.splitext(img_path)[0] - prompt_path = path_no_ext + '.txt' + prompt_path = None + for ext in caption_ext_list: + prompt_path = path_no_ext + '.' + ext + if os.path.exists(prompt_path): + break else: img_path = img_path_or_tuple if isinstance(img_path_or_tuple, str) else img_path_or_tuple.path # see if prompt file exists path_no_ext = os.path.splitext(img_path)[0] - prompt_path = path_no_ext + '.txt' + prompt_path = None + for ext in caption_ext_list: + prompt_path = path_no_ext + '.' + ext + if os.path.exists(prompt_path): + break if os.path.exists(prompt_path): with open(prompt_path, 'r', encoding='utf-8') as f: prompt = f.read() + # check if is json + if prompt_path.endswith('.json'): + prompt = json.loads(prompt) + if 'caption' in prompt: + prompt = prompt['caption'] + # remove any newlines prompt = prompt.replace('\n', ', ') # remove new lines for all operating systems @@ -173,6 +183,10 @@ class CaptionProcessingDTOMixin: if os.path.exists(prompt_path): with open(prompt_path, 'r', encoding='utf-8') as f: prompt = f.read() + if prompt_path.endswith('.json'): + prompt = json.loads(prompt) + if 'caption' in prompt: + prompt = prompt['caption'] # remove any newlines prompt = prompt.replace('\n', ', ') # remove new lines for all operating systems