Made a very basic vae trainer.

This commit is contained in:
Jaret Burkett
2023-07-17 19:03:50 -06:00
parent 78b59c5e99
commit 439310e4dc
11 changed files with 410 additions and 89 deletions

67
toolkit/data_loader.py Normal file
View File

@@ -0,0 +1,67 @@
import os
import random
from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
from torch.utils.data import Dataset
class ImageDataset(Dataset):
def __init__(self, config):
self.config = config
self.name = self.get_config('name', 'dataset')
self.path = self.get_config('path', required=True)
self.scale = self.get_config('scale', 1)
self.random_scale = self.get_config('random_scale', False)
# we always random crop if random scale is enabled
self.random_crop = self.random_scale if self.random_scale else self.get_config('random_crop', False)
self.resolution = self.get_config('resolution', 256)
self.file_list = [os.path.join(self.path, file) for file in os.listdir(self.path) if
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))]
# this might take a while
print(f" - Preprocessing image dimensions")
self.file_list = [file for file in self.file_list if
int(min(Image.open(file).size) * self.scale) >= self.resolution]
print(f" - Found {len(self.file_list)} images")
assert len(self.file_list) > 0, f"no images found in {self.path}"
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
])
def get_config(self, key, default=None, required=False):
if key in self.config:
value = self.config[key]
return value
elif required:
raise ValueError(f'config file error. Missing "config.dataset.{key}" key')
else:
return default
def __len__(self):
return len(self.file_list)
def __getitem__(self, index):
img_path = self.file_list[index]
img = exif_transpose(Image.open(img_path)).convert('RGB')
# Downscale the source image first
img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
if self.random_crop:
if self.random_scale:
scale_size = random.randint(int(img.size[0] * self.scale), self.resolution)
img = img.resize((scale_size, scale_size), Image.BICUBIC)
img = transforms.RandomCrop(self.resolution)(img)
else:
min_dimension = min(img.size)
img = transforms.CenterCrop(min_dimension)(img)
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
img = self.transform(img)
return img

View File

@@ -10,6 +10,10 @@ def get_job(config_path):
if job == 'extract':
from jobs import ExtractJob
return ExtractJob(config)
if job == 'train':
from jobs import TrainJob
return TrainJob(config)
# elif job == 'train':
# from jobs import TrainJob
# return TrainJob(config)

View File

@@ -3,3 +3,4 @@ import os
TOOLKIT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CONFIG_ROOT = os.path.join(TOOLKIT_ROOT, 'config')
SD_SCRIPTS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories", "sd-scripts")
REPOS_ROOT = os.path.join(TOOLKIT_ROOT, "repositories")