mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-04 20:49:58 +00:00
44 lines
1.8 KiB
Python
44 lines
1.8 KiB
Python
import os
|
|
|
|
|
|
class CaptionMixin:
|
|
def get_caption_item(self, index):
|
|
if not hasattr(self, 'caption_type'):
|
|
raise Exception('caption_type not found on class instance')
|
|
if not hasattr(self, 'file_list'):
|
|
raise Exception('file_list not found on class instance')
|
|
img_path_or_tuple = self.file_list[index]
|
|
if isinstance(img_path_or_tuple, tuple):
|
|
# check if either has a prompt file
|
|
path_no_ext = os.path.splitext(img_path_or_tuple[0])[0]
|
|
prompt_path = path_no_ext + '.txt'
|
|
if not os.path.exists(prompt_path):
|
|
path_no_ext = os.path.splitext(img_path_or_tuple[1])[0]
|
|
prompt_path = path_no_ext + '.txt'
|
|
else:
|
|
img_path = img_path_or_tuple
|
|
# see if prompt file exists
|
|
path_no_ext = os.path.splitext(img_path)[0]
|
|
prompt_path = path_no_ext + '.txt'
|
|
|
|
if os.path.exists(prompt_path):
|
|
with open(prompt_path, 'r', encoding='utf-8') as f:
|
|
prompt = f.read()
|
|
# remove any newlines
|
|
prompt = prompt.replace('\n', ', ')
|
|
# remove new lines for all operating systems
|
|
prompt = prompt.replace('\r', ', ')
|
|
prompt_split = prompt.split(',')
|
|
# remove empty strings
|
|
prompt_split = [p.strip() for p in prompt_split if p.strip()]
|
|
# join back together
|
|
prompt = ', '.join(prompt_split)
|
|
else:
|
|
prompt = ''
|
|
# get default_prompt if it exists on the class instance
|
|
if hasattr(self, 'default_prompt'):
|
|
prompt = self.default_prompt
|
|
if hasattr(self, 'default_caption'):
|
|
prompt = self.default_caption
|
|
return prompt
|