Files
ai-toolkit/toolkit/dataloader_mixins.py
2023-08-22 21:02:38 -06:00

44 lines
1.8 KiB
Python

import os
class CaptionMixin:
def get_caption_item(self, index):
if not hasattr(self, 'caption_type'):
raise Exception('caption_type not found on class instance')
if not hasattr(self, 'file_list'):
raise Exception('file_list not found on class instance')
img_path_or_tuple = self.file_list[index]
if isinstance(img_path_or_tuple, tuple):
# check if either has a prompt file
path_no_ext = os.path.splitext(img_path_or_tuple[0])[0]
prompt_path = path_no_ext + '.txt'
if not os.path.exists(prompt_path):
path_no_ext = os.path.splitext(img_path_or_tuple[1])[0]
prompt_path = path_no_ext + '.txt'
else:
img_path = img_path_or_tuple
# see if prompt file exists
path_no_ext = os.path.splitext(img_path)[0]
prompt_path = path_no_ext + '.txt'
if os.path.exists(prompt_path):
with open(prompt_path, 'r', encoding='utf-8') as f:
prompt = f.read()
# remove any newlines
prompt = prompt.replace('\n', ', ')
# remove new lines for all operating systems
prompt = prompt.replace('\r', ', ')
prompt_split = prompt.split(',')
# remove empty strings
prompt_split = [p.strip() for p in prompt_split if p.strip()]
# join back together
prompt = ', '.join(prompt_split)
else:
prompt = ''
# get default_prompt if it exists on the class instance
if hasattr(self, 'default_prompt'):
prompt = self.default_prompt
if hasattr(self, 'default_caption'):
prompt = self.default_caption
return prompt