Added training for a custom version of ERSGAN arcitecture. Testing training now

This commit is contained in:
Jaret Burkett
2023-08-07 18:04:23 -06:00
parent 8c90fa86c6
commit 8bd536df7e
9 changed files with 1694 additions and 141 deletions

View File

@@ -1,10 +1,14 @@
import os
import random
import cv2
import numpy as np
from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
from torch.utils.data import Dataset
from tqdm import tqdm
import albumentations as A
class ImageDataset(Dataset):
@@ -38,7 +42,7 @@ class ImageDataset(Dataset):
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
])
def get_config(self, key, default=None, required=False):
@@ -65,7 +69,7 @@ class ImageDataset(Dataset):
if self.random_scale and min_img_size > self.resolution:
if min_img_size < self.resolution:
print(
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file}")
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}")
scale_size = self.resolution
else:
scale_size = random.randint(self.resolution, int(min_img_size))
@@ -78,3 +82,61 @@ class ImageDataset(Dataset):
img = self.transform(img)
return img
class Augments:
def __init__(self, **kwargs):
self.method_name = kwargs.get('method', None)
self.params = kwargs.get('params', {})
# convert kwargs enums for cv2
for key, value in self.params.items():
if isinstance(value, str):
# split the string
split_string = value.split('.')
if len(split_string) == 2 and split_string[0] == 'cv2':
if hasattr(cv2, split_string[1]):
self.params[key] = getattr(cv2, split_string[1].upper())
else:
raise ValueError(f"invalid cv2 enum: {split_string[1]}")
class AugmentedImageDataset(ImageDataset):
def __init__(self, config):
super().__init__(config)
self.augmentations = self.get_config('augmentations', [])
self.augmentations = [Augments(**aug) for aug in self.augmentations]
augmentation_list = []
for aug in self.augmentations:
# make sure method name is valid
assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
# get the method
method = getattr(A, aug.method_name)
# add the method to the list
augmentation_list.append(method(**aug.params))
self.aug_transform = A.Compose(augmentation_list)
self.original_transform = self.transform
# replace transform so we get raw pil image
self.transform = transforms.Compose([])
def __getitem__(self, index):
# get the original image
# image is a PIL image, convert to bgr
pil_image = super().__getitem__(index)
open_cv_image = np.array(pil_image)
# Convert RGB to BGR
open_cv_image = open_cv_image[:, :, ::-1].copy()
# apply augmentations
augmented = self.aug_transform(image=open_cv_image)["image"]
# convert back to RGB tensor
augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
# convert to PIL image
augmented = Image.fromarray(augmented)
# return both # return image as 0 - 1 tensor
return transforms.ToTensor()(pil_image), transforms.ToTensor()(augmented)