Added base for ultimate slider. WIP

This commit is contained in:
Jaret Burkett
2023-08-19 15:35:24 -06:00
parent c6675e2801
commit b77b9acc0b
6 changed files with 568 additions and 36 deletions

View File

@@ -163,7 +163,7 @@ class PairedImageDataset(Dataset):
self.pos_file_list = [os.path.join(self.pos_folder, file) for file in os.listdir(self.pos_folder) if
file.lower().endswith(supported_exts)]
self.neg_file_list = [os.path.join(self.neg_folder, file) for file in os.listdir(self.neg_folder) if
file.lower().endswith(supported_exts)]
file.lower().endswith(supported_exts)]
matched_files = []
for pos_file in self.pos_file_list:
@@ -177,7 +177,6 @@ class PairedImageDataset(Dataset):
# remove duplicates
matched_files = [t for t in (set(tuple(i) for i in matched_files))]
self.file_list = matched_files
print(f" - Found {len(self.file_list)} matching pairs")
else:
@@ -190,6 +189,15 @@ class PairedImageDataset(Dataset):
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
])
def get_all_prompts(self):
prompts = []
for index in range(len(self.file_list)):
prompts.append(self.get_prompt_item(index))
# remove duplicates
prompts = list(set(prompts))
return prompts
def __len__(self):
return len(self.file_list)
@@ -202,19 +210,9 @@ class PairedImageDataset(Dataset):
else:
return default
def __getitem__(self, index):
def get_prompt_item(self, index):
img_path_or_tuple = self.file_list[index]
if isinstance(img_path_or_tuple, tuple):
# load both images
img_path = img_path_or_tuple[0]
img1 = exif_transpose(Image.open(img_path)).convert('RGB')
img_path = img_path_or_tuple[1]
img2 = exif_transpose(Image.open(img_path)).convert('RGB')
# combine them side by side
img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
img.paste(img1, (0, 0))
img.paste(img2, (img1.width, 0))
# check if either has a prompt file
path_no_ext = os.path.splitext(img_path_or_tuple[0])[0]
prompt_path = path_no_ext + '.txt'
@@ -223,7 +221,6 @@ class PairedImageDataset(Dataset):
prompt_path = path_no_ext + '.txt'
else:
img_path = img_path_or_tuple
img = exif_transpose(Image.open(img_path)).convert('RGB')
# see if prompt file exists
path_no_ext = os.path.splitext(img_path)[0]
prompt_path = path_no_ext + '.txt'
@@ -242,6 +239,25 @@ class PairedImageDataset(Dataset):
prompt = ', '.join(prompt_split)
else:
prompt = self.default_prompt
return prompt
def __getitem__(self, index):
img_path_or_tuple = self.file_list[index]
if isinstance(img_path_or_tuple, tuple):
# load both images
img_path = img_path_or_tuple[0]
img1 = exif_transpose(Image.open(img_path)).convert('RGB')
img_path = img_path_or_tuple[1]
img2 = exif_transpose(Image.open(img_path)).convert('RGB')
# combine them side by side
img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
img.paste(img1, (0, 0))
img.paste(img2, (img1.width, 0))
else:
img_path = img_path_or_tuple
img = exif_transpose(Image.open(img_path)).convert('RGB')
prompt = self.get_prompt_item(index)
height = self.size
# determine width to keep aspect ratio
@@ -252,4 +268,3 @@ class PairedImageDataset(Dataset):
img = self.transform(img)
return img, prompt, (self.neg_weight, self.pos_weight)