Allow augmentations and targeting different loss types from the config file

This commit is contained in:
Jaret Burkett
2023-10-18 03:04:57 -06:00
parent da6302ada8
commit 07bf7bd7de
6 changed files with 216 additions and 50 deletions

View File

@@ -88,7 +88,7 @@ class EmbeddingConfig:
ContentOrStyleType = Literal['balanced', 'style', 'content']
LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
class TrainConfig:
def __init__(self, **kwargs):
@@ -127,6 +127,7 @@ class TrainConfig:
match_adapter_assist = kwargs.get('match_adapter_assist', False)
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented,
# legacy
if match_adapter_assist and self.match_adapter_chance == 0.0:
@@ -258,7 +259,13 @@ class DatasetConfig:
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
if len(self.augments) > 0 and (self.cache_latents or self.cache_latents_to_disk):
# https://albumentations.ai/docs/api_reference/augmentations/transforms
# augmentations are returned as a separate image and cannot currently be cached
self.augmentations: List[dict] = kwargs.get('augmentations', None)
has_augmentations = self.augmentations is not None and len(self.augmentations) > 0
if (len(self.augments) > 0 or has_augmentations) and (self.cache_latents or self.cache_latents_to_disk):
print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False")
self.cache_latents = False
self.cache_latents_to_disk = False

View File

@@ -16,7 +16,7 @@ import albumentations as A
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
if TYPE_CHECKING:
@@ -111,21 +111,7 @@ class ImageDataset(Dataset, CaptionMixin):
return img
class Augments:
    """Parsed augmentation spec from config: a method name plus keyword params.

    String params of the form ``cv2.<NAME>`` are resolved to the actual cv2
    constant (e.g. ``"cv2.INTER_LINEAR"`` -> ``cv2.INTER_LINEAR``).
    """

    def __init__(self, **kwargs):
        self.method_name = kwargs.get('method', None)
        self.params = kwargs.get('params', {})
        # convert "cv2.X" string params into the real cv2 enum values
        for key, value in self.params.items():
            if isinstance(value, str):
                # split the string
                split_string = value.split('.')
                if len(split_string) == 2 and split_string[0] == 'cv2':
                    # Bug fix: the original checked hasattr(cv2, name) but then
                    # fetched getattr(cv2, name.upper()), so an exact-case name
                    # whose upper-cased form does not exist raised AttributeError,
                    # while lower-case config values were rejected even though the
                    # upper-cased constant exists. Try the canonical upper-cased
                    # name first, then the exact spelling as written.
                    enum_name = split_string[1].upper()
                    if hasattr(cv2, enum_name):
                        self.params[key] = getattr(cv2, enum_name)
                    elif hasattr(cv2, split_string[1]):
                        self.params[key] = getattr(cv2, split_string[1])
                    else:
                        raise ValueError(f"invalid cv2 enum: {split_string[1]}")
class AugmentedImageDataset(ImageDataset):

View File

@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
@@ -28,6 +28,7 @@ class FileItemDTO(
ImageProcessingDTOMixin,
ControlFileItemDTOMixin,
MaskFileItemDTOMixin,
AugmentationFileItemDTOMixin,
PoiFileItemDTOMixin,
ArgBreakMixin,
):
@@ -80,6 +81,8 @@ class DataLoaderBatchDTO:
self.latents: Union[torch.Tensor, None] = None
self.control_tensor: Union[torch.Tensor, None] = None
self.mask_tensor: Union[torch.Tensor, None] = None
self.unaugmented_tensor: Union[torch.Tensor, None] = None
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
if not is_latents_cached:
# only return a tensor if latents are not cached
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
@@ -119,11 +122,26 @@ class DataLoaderBatchDTO:
else:
mask_tensors.append(x.mask_tensor)
self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors])
# add unaugmented tensors for ones with augments
if any([x.unaugmented_tensor is not None for x in self.file_items]):
# find one to use as a base
base_unaugmented_tensor = None
for x in self.file_items:
if x.unaugmented_tensor is not None:
base_unaugmented_tensor = x.unaugmented_tensor
break
unaugmented_tensor = []
for x in self.file_items:
if x.unaugmented_tensor is None:
unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor))
else:
unaugmented_tensor.append(x.unaugmented_tensor)
self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
except Exception as e:
print(e)
raise e
def get_is_reg_list(self):
return [x.is_reg for x in self.file_items]

View File

@@ -7,6 +7,7 @@ import random
from collections import OrderedDict
from typing import TYPE_CHECKING, List, Dict, Union
import cv2
import numpy as np
import torch
from safetensors.torch import load_file, save_file
@@ -19,6 +20,7 @@ from toolkit.prompt_utils import inject_trigger_into_prompt
from torchvision import transforms
from PIL import Image, ImageFilter
from PIL.ImageOps import exif_transpose
import albumentations as A
from toolkit.train_tools import get_torch_dtype
@@ -26,8 +28,25 @@ if TYPE_CHECKING:
from toolkit.data_loader import AiToolkitDataset
from toolkit.data_transfer_object.data_loader import FileItemDTO
# def get_associated_caption_from_img_path(img_path):
class Augments:
    """Parsed augmentation spec from config: a method name plus keyword params.

    String params of the form ``cv2.<NAME>`` are resolved to the actual cv2
    constant (e.g. ``"cv2.INTER_LINEAR"`` -> ``cv2.INTER_LINEAR``).
    """

    def __init__(self, **kwargs):
        self.method_name = kwargs.get('method', None)
        self.params = kwargs.get('params', {})
        # convert "cv2.X" string params into the real cv2 enum values
        for key, value in self.params.items():
            if isinstance(value, str):
                # split the string
                split_string = value.split('.')
                if len(split_string) == 2 and split_string[0] == 'cv2':
                    # Bug fix: the original checked hasattr(cv2, name) but then
                    # fetched getattr(cv2, name.upper()), so an exact-case name
                    # whose upper-cased form does not exist raised AttributeError,
                    # while lower-case config values were rejected even though the
                    # upper-cased constant exists. Try the canonical upper-cased
                    # name first, then the exact spelling as written.
                    enum_name = split_string[1].upper()
                    if hasattr(cv2, enum_name):
                        self.params[key] = getattr(cv2, enum_name)
                    elif hasattr(cv2, split_string[1]):
                        self.params[key] = getattr(cv2, split_string[1])
                    else:
                        raise ValueError(f"invalid cv2 enum: {split_string[1]}")
transforms_dict = {
'ColorJitter': transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03),
@@ -195,7 +214,6 @@ class BucketsMixin:
print(f'{len(self.buckets)} buckets made')
class CaptionProcessingDTOMixin:
def __init__(self: 'FileItemDTO', *args, **kwargs):
if hasattr(super(), '__init__'):
@@ -392,7 +410,10 @@ class ImageProcessingDTOMixin:
if augment in transforms_dict:
img = transforms_dict[augment](img)
if transform:
if self.has_augmentations:
# augmentations handles transforms
img = self.augment_image(img, transform=transform)
elif transform:
img = transform(img)
self.tensor = img
@@ -468,6 +489,54 @@ class ControlFileItemDTOMixin:
self.control_tensor = None
class AugmentationFileItemDTOMixin:
    """Mixin that applies config-driven albumentations transforms to a file item.

    When the dataset config declares ``augmentations``, builds an
    ``A.Compose`` pipeline once at construction; ``augment_image`` then returns
    the augmented tensor while stashing the original (unaugmented) tensor on
    ``self.unaugmented_tensor`` so the trainer can target it
    (``loss_target: 'unaugmented'``).
    """

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.has_augmentations = False
        # original (pre-augmentation) image tensor, populated by augment_image
        self.unaugmented_tensor: Union[torch.Tensor, None] = None
        # always bind aug_transform so attribute access never raises
        self.aug_transform = None
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        # Bug fix: guard against a missing dataset_config kwarg — the original
        # dereferenced dataset_config.augmentations unconditionally and raised
        # AttributeError when it was None.
        if dataset_config is not None and dataset_config.augmentations is not None \
                and len(dataset_config.augmentations) > 0:
            self.has_augmentations = True
            augmentations = [Augments(**aug) for aug in dataset_config.augmentations]
            augmentation_list = []
            for aug in augmentations:
                # make sure method name is a valid albumentations transform
                assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
                # get the method
                method = getattr(A, aug.method_name)
                # add the instantiated transform to the pipeline
                augmentation_list.append(method(**aug.params))
            self.aug_transform = A.Compose(augmentation_list)

    def augment_image(self: 'FileItemDTO', img: "Image.Image", transform: "Union[None, transforms.Compose]", ):
        """Return the augmented tensor for ``img``; stash the unaugmented one.

        ``transform`` (when given) is applied to both the original and the
        augmented PIL image; otherwise a plain ``ToTensor`` conversion is used.
        """
        # save the original tensor before augmenting
        self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)
        open_cv_image = np.array(img)
        # Convert RGB to BGR (albumentations pipeline here is fed BGR)
        open_cv_image = open_cv_image[:, :, ::-1].copy()
        # apply augmentations
        augmented = self.aug_transform(image=open_cv_image)["image"]
        # convert back to RGB
        augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
        # convert to PIL image
        augmented = Image.fromarray(augmented)
        augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented)
        return augmented_tensor

    def cleanup_control(self: 'FileItemDTO'):
        # NOTE(review): this method name collides with the cleanup method on
        # ControlFileItemDTOMixin; in FileItemDTO's MRO only one implementation
        # is reachable via self.cleanup_control(), so one of the two tensors may
        # never be released. Consider renaming (e.g. cleanup_augmentation) or
        # making the chain cooperative via super() — verify against callers.
        self.unaugmented_tensor = None
class MaskFileItemDTOMixin:
def __init__(self: 'FileItemDTO', *args, **kwargs):
if hasattr(super(), '__init__'):
@@ -558,6 +627,7 @@ class MaskFileItemDTOMixin:
def cleanup_mask(self: 'FileItemDTO'):
self.mask_tensor = None
class PoiFileItemDTOMixin:
# Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
# items in the poi will always be inside the image when random cropping
@@ -799,7 +869,6 @@ class LatentCachingMixin:
del latent
del file_item.tensor
flush(garbage_collect=False)
file_item.is_latent_cached = True
# flush every 100