mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-25 08:49:14 +00:00
Added training for a custom version of the ESRGAN architecture. Testing training now.
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
import os
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
import albumentations as A
|
||||
|
||||
|
||||
class ImageDataset(Dataset):
|
||||
@@ -38,7 +42,7 @@ class ImageDataset(Dataset):
|
||||
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
|
||||
])
|
||||
|
||||
def get_config(self, key, default=None, required=False):
|
||||
@@ -65,7 +69,7 @@ class ImageDataset(Dataset):
|
||||
if self.random_scale and min_img_size > self.resolution:
|
||||
if min_img_size < self.resolution:
|
||||
print(
|
||||
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file}")
|
||||
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}")
|
||||
scale_size = self.resolution
|
||||
else:
|
||||
scale_size = random.randint(self.resolution, int(min_img_size))
|
||||
@@ -78,3 +82,61 @@ class ImageDataset(Dataset):
|
||||
img = self.transform(img)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class Augments:
    """One augmentation entry from the dataset config.

    Keyword Args:
        method: Name of the augmentation to apply; stored as ``method_name``
            (``None`` when omitted).
        params: Mapping of keyword arguments for that augmentation. String
            values of the form ``"cv2.<NAME>"`` are replaced by the matching
            ``cv2`` attribute. A missing or ``None`` value is treated as an
            empty mapping.

    Raises:
        ValueError: If a ``"cv2.<NAME>"`` string does not name a ``cv2``
            attribute.
    """

    def __init__(self, **kwargs):
        self.method_name = kwargs.get('method', None)
        # Copy so that resolving enums never rewrites the caller's config
        # dict; ``or {}`` also covers an explicit ``params: None``.
        self.params = dict(kwargs.get('params') or {})

        # convert kwargs enums for cv2
        for key, value in self.params.items():
            if not isinstance(value, str):
                continue
            pieces = value.split('.')
            if len(pieces) == 2 and pieces[0] == 'cv2':
                # Re-assigning an existing key does not resize the dict, so
                # it is safe while iterating over items().
                self.params[key] = self._resolve_cv2_enum(pieces[1])

    @staticmethod
    def _resolve_cv2_enum(enum_name):
        """Return the cv2 attribute named by *enum_name*.

        The upper-cased spelling is tried first, then the name exactly as
        written, so the existence check and the lookup always agree.
        """
        for candidate in (enum_name.upper(), enum_name):
            if hasattr(cv2, candidate):
                return getattr(cv2, candidate)
        raise ValueError(f"invalid cv2 enum: {enum_name}")
|
||||
|
||||
|
||||
class AugmentedImageDataset(ImageDataset):
    """ImageDataset that returns each image together with an augmented copy.

    The ``augmentations`` config entry is a list of keyword dicts, each
    passed to ``Augments``. Every entry's ``method`` must name an attribute
    of the albumentations module; it is instantiated with the entry's
    ``params`` and the results are chained with ``A.Compose``.
    """

    def __init__(self, config):
        super().__init__(config)
        raw_entries = self.get_config('augmentations', [])
        self.augmentations = [Augments(**entry) for entry in raw_entries]

        pipeline = []
        for spec in self.augmentations:
            name = spec.method_name
            # reject unknown albumentations names before looking them up
            assert hasattr(A, name), f"invalid augmentation method: {name}"
            factory = getattr(A, name)
            pipeline.append(factory(**spec.params))

        self.aug_transform = A.Compose(pipeline)
        # keep the parent's transform reachable, then swap in an empty
        # Compose so the parent's __getitem__ hands back the untransformed
        # image (presumably a PIL image) instead of a tensor
        self.original_transform = self.transform
        self.transform = transforms.Compose([])

    def __getitem__(self, index):
        """Return ``(original, augmented)`` for *index*, both via ToTensor."""
        source_image = super().__getitem__(index)

        # reverse the last axis (RGB -> BGR); copy() materialises the
        # reversed view into its own array
        bgr_pixels = np.array(source_image)[:, :, ::-1].copy()

        # run the albumentations pipeline on the BGR array
        result = self.aug_transform(image=bgr_pixels)

        # back to RGB, then wrap as a PIL image
        rgb_pixels = cv2.cvtColor(result["image"], cv2.COLOR_BGR2RGB)
        augmented_image = Image.fromarray(rgb_pixels)

        # both outputs are 0 - 1 tensors
        to_tensor = transforms.ToTensor()
        return to_tensor(source_image), to_tensor(augmented_image)
|
||||
|
||||
Reference in New Issue
Block a user