Added built-in plugins and made one for image reference. WIP

This commit is contained in:
Jaret Burkett
2023-08-10 16:02:44 -06:00
parent df48f0a843
commit 1a7e346b41
12 changed files with 338 additions and 26 deletions

View File

@@ -140,3 +140,65 @@ class AugmentedImageDataset(ImageDataset):
# return both # return image as 0 - 1 tensor
return transforms.ToTensor()(pil_image), transforms.ToTensor()(augmented)
class PairedImageDataset(Dataset):
    """Dataset of images from a single folder, each with an optional caption.

    Every file in ``config['path']`` whose extension is listed in
    ``IMAGE_EXTENSIONS`` becomes one sample. A caption is read from a
    sibling ``<image name>.txt`` file when it exists; otherwise
    ``config['default_prompt']`` is used.

    Recognised config keys:
        size: target image height in pixels (default 512).
        path: folder to scan for images (required).
        default_prompt: caption used when no ``.txt`` file exists
            (default ``''``).
    """

    # Lower-case extensions treated as images when scanning the folder.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp')

    def __init__(self, config):
        """Read settings from ``config`` and index the image folder.

        Raises:
            ValueError: if ``config`` has no ``path`` key.
        """
        super().__init__()
        self.config = config
        self.size = self.get_config('size', 512)
        self.path = self.get_config('path', required=True)
        self.default_prompt = self.get_config('default_prompt', '')
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index maps to the same file on every run and every machine.
        self.file_list = sorted(
            os.path.join(self.path, file) for file in os.listdir(self.path)
            if file.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        print(f"  -  Found {len(self.file_list)} images")
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
        ])

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.file_list)

    def get_config(self, key, default=None, required=False):
        """Look up ``key`` in the dataset config.

        Args:
            key: config key to read.
            default: value returned when the key is absent and not required.
            required: when True, a missing key is an error.

        Returns:
            The stored value (even if it is falsy or None) when the key is
            present, otherwise ``default``.

        Raises:
            ValueError: if ``required`` is True and the key is missing.
        """
        if key in self.config:
            return self.config[key]
        if required:
            raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
        return default

    @staticmethod
    def _normalize_prompt(prompt):
        """Collapse a raw caption into a single ``', '``-separated line."""
        # Treat every line break (\n, \r, or \r\n) as a tag separator.
        prompt = prompt.replace('\n', ', ').replace('\r', ', ')
        # Drop the empty fragments left by blank lines or doubled commas.
        parts = [part.strip() for part in prompt.split(',')]
        return ', '.join(part for part in parts if part)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A ``(image, prompt)`` tuple: the image resized to ``self.size``
            pixels high (aspect ratio kept) and passed through
            ``self.transform``, and the caption string.
        """
        img_path = self.file_list[index]
        # Use a context manager so the underlying file handle is released
        # right away instead of lingering until garbage collection; the
        # converted image is a fully loaded, independent copy.
        with Image.open(img_path) as raw:
            img = exif_transpose(raw).convert('RGB')

        # see if prompt file exists
        prompt_path = os.path.splitext(img_path)[0] + '.txt'
        if os.path.exists(prompt_path):
            with open(prompt_path, 'r', encoding='utf-8') as f:
                prompt = self._normalize_prompt(f.read())
        else:
            prompt = self.default_prompt

        height = self.size
        # determine width to keep aspect ratio; clamp to 1 so an extremely
        # tall image cannot truncate to a zero-width (invalid) resize target
        width = max(1, int(img.size[0] * height / img.size[1]))

        # Downscale the source image first
        img = img.resize((width, height), Image.BICUBIC)
        img = self.transform(img)

        return img, prompt