Added ability to use two seperate folders for datasets when doing image reference sliders

This commit is contained in:
Jaret Burkett
2023-08-18 11:44:33 -06:00
parent 8d09eb44ec
commit d51c4ca704
3 changed files with 134 additions and 59 deletions

View File

@@ -147,12 +147,43 @@ class PairedImageDataset(Dataset):
super().__init__()
self.config = config
self.size = self.get_config('size', 512)
self.path = self.get_config('path', required=True)
self.path = self.get_config('path', None)
self.pos_folder = self.get_config('pos_folder', None)
self.neg_folder = self.get_config('neg_folder', None)
self.default_prompt = self.get_config('default_prompt', '')
self.network_weight = self.get_config('network_weight', 1.0)
self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
print(f" - Found {len(self.file_list)} images")
self.pos_weight = self.get_config('pos_weight', self.network_weight)
self.neg_weight = self.get_config('neg_weight', self.network_weight)
supported_exts = ('.jpg', '.jpeg', '.png', '.webp', '.JPEG', '.JPG', '.PNG', '.WEBP')
if self.pos_folder is not None and self.neg_folder is not None:
# find matching files
self.pos_file_list = [os.path.join(self.pos_folder, file) for file in os.listdir(self.pos_folder) if
file.lower().endswith(supported_exts)]
self.neg_file_list = [os.path.join(self.neg_folder, file) for file in os.listdir(self.neg_folder) if
file.lower().endswith(supported_exts)]
matched_files = []
for pos_file in self.pos_file_list:
pos_file_no_ext = os.path.splitext(pos_file)[0]
for neg_file in self.neg_file_list:
neg_file_no_ext = os.path.splitext(neg_file)[0]
if os.path.basename(pos_file_no_ext) == os.path.basename(neg_file_no_ext):
matched_files.append((neg_file, pos_file))
break
# remove duplicates
matched_files = [t for t in (set(tuple(i) for i in matched_files))]
self.file_list = matched_files
print(f" - Found {len(self.file_list)} matching pairs")
else:
self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
file.lower().endswith(supported_exts)]
print(f" - Found {len(self.file_list)} images")
self.transform = transforms.Compose([
transforms.ToTensor(),
@@ -172,12 +203,31 @@ class PairedImageDataset(Dataset):
return default
def __getitem__(self, index):
img_path = self.file_list[index]
img = exif_transpose(Image.open(img_path)).convert('RGB')
img_path_or_tuple = self.file_list[index]
if isinstance(img_path_or_tuple, tuple):
# load both images
img_path = img_path_or_tuple[0]
img1 = exif_transpose(Image.open(img_path)).convert('RGB')
img_path = img_path_or_tuple[1]
img2 = exif_transpose(Image.open(img_path)).convert('RGB')
# combine them side by side
img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
img.paste(img1, (0, 0))
img.paste(img2, (img1.width, 0))
# check if either has a prompt file
path_no_ext = os.path.splitext(img_path_or_tuple[0])[0]
prompt_path = path_no_ext + '.txt'
if not os.path.exists(prompt_path):
path_no_ext = os.path.splitext(img_path_or_tuple[1])[0]
prompt_path = path_no_ext + '.txt'
else:
img_path = img_path_or_tuple
img = exif_transpose(Image.open(img_path)).convert('RGB')
# see if prompt file exists
path_no_ext = os.path.splitext(img_path)[0]
prompt_path = path_no_ext + '.txt'
# see if prompt file exists
path_no_ext = os.path.splitext(img_path)[0]
prompt_path = path_no_ext + '.txt'
if os.path.exists(prompt_path):
with open(prompt_path, 'r', encoding='utf-8') as f:
prompt = f.read()
@@ -201,5 +251,5 @@ class PairedImageDataset(Dataset):
img = img.resize((width, height), Image.BICUBIC)
img = self.transform(img)
return img, prompt, self.network_weight
return img, prompt, (self.neg_weight, self.pos_weight)