mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 09:44:02 +00:00
Added base for ultimate slider. WIP
This commit is contained in:
@@ -163,7 +163,7 @@ class PairedImageDataset(Dataset):
|
||||
self.pos_file_list = [os.path.join(self.pos_folder, file) for file in os.listdir(self.pos_folder) if
|
||||
file.lower().endswith(supported_exts)]
|
||||
self.neg_file_list = [os.path.join(self.neg_folder, file) for file in os.listdir(self.neg_folder) if
|
||||
file.lower().endswith(supported_exts)]
|
||||
file.lower().endswith(supported_exts)]
|
||||
|
||||
matched_files = []
|
||||
for pos_file in self.pos_file_list:
|
||||
@@ -177,7 +177,6 @@ class PairedImageDataset(Dataset):
|
||||
# remove duplicates
|
||||
matched_files = [t for t in (set(tuple(i) for i in matched_files))]
|
||||
|
||||
|
||||
self.file_list = matched_files
|
||||
print(f" - Found {len(self.file_list)} matching pairs")
|
||||
else:
|
||||
@@ -190,6 +189,15 @@ class PairedImageDataset(Dataset):
|
||||
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
|
||||
])
|
||||
|
||||
def get_all_prompts(self):
|
||||
prompts = []
|
||||
for index in range(len(self.file_list)):
|
||||
prompts.append(self.get_prompt_item(index))
|
||||
|
||||
# remove duplicates
|
||||
prompts = list(set(prompts))
|
||||
return prompts
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file_list)
|
||||
|
||||
@@ -202,19 +210,9 @@ class PairedImageDataset(Dataset):
|
||||
else:
|
||||
return default
|
||||
|
||||
def __getitem__(self, index):
|
||||
def get_prompt_item(self, index):
|
||||
img_path_or_tuple = self.file_list[index]
|
||||
if isinstance(img_path_or_tuple, tuple):
|
||||
# load both images
|
||||
img_path = img_path_or_tuple[0]
|
||||
img1 = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
img_path = img_path_or_tuple[1]
|
||||
img2 = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
# combine them side by side
|
||||
img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
|
||||
img.paste(img1, (0, 0))
|
||||
img.paste(img2, (img1.width, 0))
|
||||
|
||||
# check if either has a prompt file
|
||||
path_no_ext = os.path.splitext(img_path_or_tuple[0])[0]
|
||||
prompt_path = path_no_ext + '.txt'
|
||||
@@ -223,7 +221,6 @@ class PairedImageDataset(Dataset):
|
||||
prompt_path = path_no_ext + '.txt'
|
||||
else:
|
||||
img_path = img_path_or_tuple
|
||||
img = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
# see if prompt file exists
|
||||
path_no_ext = os.path.splitext(img_path)[0]
|
||||
prompt_path = path_no_ext + '.txt'
|
||||
@@ -242,6 +239,25 @@ class PairedImageDataset(Dataset):
|
||||
prompt = ', '.join(prompt_split)
|
||||
else:
|
||||
prompt = self.default_prompt
|
||||
return prompt
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_path_or_tuple = self.file_list[index]
|
||||
if isinstance(img_path_or_tuple, tuple):
|
||||
# load both images
|
||||
img_path = img_path_or_tuple[0]
|
||||
img1 = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
img_path = img_path_or_tuple[1]
|
||||
img2 = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
# combine them side by side
|
||||
img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
|
||||
img.paste(img1, (0, 0))
|
||||
img.paste(img2, (img1.width, 0))
|
||||
else:
|
||||
img_path = img_path_or_tuple
|
||||
img = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
|
||||
prompt = self.get_prompt_item(index)
|
||||
|
||||
height = self.size
|
||||
# determine width to keep aspect ratio
|
||||
@@ -252,4 +268,3 @@ class PairedImageDataset(Dataset):
|
||||
img = self.transform(img)
|
||||
|
||||
return img, prompt, (self.neg_weight, self.pos_weight)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user