mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Added ability to use two seperate folders for datasets when doing image reference sliders
This commit is contained in:
@@ -147,12 +147,43 @@ class PairedImageDataset(Dataset):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.size = self.get_config('size', 512)
|
||||
self.path = self.get_config('path', required=True)
|
||||
self.path = self.get_config('path', None)
|
||||
self.pos_folder = self.get_config('pos_folder', None)
|
||||
self.neg_folder = self.get_config('neg_folder', None)
|
||||
|
||||
self.default_prompt = self.get_config('default_prompt', '')
|
||||
self.network_weight = self.get_config('network_weight', 1.0)
|
||||
self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
|
||||
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
|
||||
print(f" - Found {len(self.file_list)} images")
|
||||
self.pos_weight = self.get_config('pos_weight', self.network_weight)
|
||||
self.neg_weight = self.get_config('neg_weight', self.network_weight)
|
||||
|
||||
supported_exts = ('.jpg', '.jpeg', '.png', '.webp', '.JPEG', '.JPG', '.PNG', '.WEBP')
|
||||
|
||||
if self.pos_folder is not None and self.neg_folder is not None:
|
||||
# find matching files
|
||||
self.pos_file_list = [os.path.join(self.pos_folder, file) for file in os.listdir(self.pos_folder) if
|
||||
file.lower().endswith(supported_exts)]
|
||||
self.neg_file_list = [os.path.join(self.neg_folder, file) for file in os.listdir(self.neg_folder) if
|
||||
file.lower().endswith(supported_exts)]
|
||||
|
||||
matched_files = []
|
||||
for pos_file in self.pos_file_list:
|
||||
pos_file_no_ext = os.path.splitext(pos_file)[0]
|
||||
for neg_file in self.neg_file_list:
|
||||
neg_file_no_ext = os.path.splitext(neg_file)[0]
|
||||
if os.path.basename(pos_file_no_ext) == os.path.basename(neg_file_no_ext):
|
||||
matched_files.append((neg_file, pos_file))
|
||||
break
|
||||
|
||||
# remove duplicates
|
||||
matched_files = [t for t in (set(tuple(i) for i in matched_files))]
|
||||
|
||||
|
||||
self.file_list = matched_files
|
||||
print(f" - Found {len(self.file_list)} matching pairs")
|
||||
else:
|
||||
self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
|
||||
file.lower().endswith(supported_exts)]
|
||||
print(f" - Found {len(self.file_list)} images")
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
@@ -172,12 +203,31 @@ class PairedImageDataset(Dataset):
|
||||
return default
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_path = self.file_list[index]
|
||||
img = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
img_path_or_tuple = self.file_list[index]
|
||||
if isinstance(img_path_or_tuple, tuple):
|
||||
# load both images
|
||||
img_path = img_path_or_tuple[0]
|
||||
img1 = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
img_path = img_path_or_tuple[1]
|
||||
img2 = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
# combine them side by side
|
||||
img = Image.new('RGB', (img1.width + img2.width, max(img1.height, img2.height)))
|
||||
img.paste(img1, (0, 0))
|
||||
img.paste(img2, (img1.width, 0))
|
||||
|
||||
# check if either has a prompt file
|
||||
path_no_ext = os.path.splitext(img_path_or_tuple[0])[0]
|
||||
prompt_path = path_no_ext + '.txt'
|
||||
if not os.path.exists(prompt_path):
|
||||
path_no_ext = os.path.splitext(img_path_or_tuple[1])[0]
|
||||
prompt_path = path_no_ext + '.txt'
|
||||
else:
|
||||
img_path = img_path_or_tuple
|
||||
img = exif_transpose(Image.open(img_path)).convert('RGB')
|
||||
# see if prompt file exists
|
||||
path_no_ext = os.path.splitext(img_path)[0]
|
||||
prompt_path = path_no_ext + '.txt'
|
||||
|
||||
# see if prompt file exists
|
||||
path_no_ext = os.path.splitext(img_path)[0]
|
||||
prompt_path = path_no_ext + '.txt'
|
||||
if os.path.exists(prompt_path):
|
||||
with open(prompt_path, 'r', encoding='utf-8') as f:
|
||||
prompt = f.read()
|
||||
@@ -201,5 +251,5 @@ class PairedImageDataset(Dataset):
|
||||
img = img.resize((width, height), Image.BICUBIC)
|
||||
img = self.transform(img)
|
||||
|
||||
return img, prompt, self.network_weight
|
||||
return img, prompt, (self.neg_weight, self.pos_weight)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user