mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Allow augmentations and targeting different loss types fron the config file
This commit is contained in:
@@ -88,7 +88,7 @@ class EmbeddingConfig:
|
||||
|
||||
|
||||
ContentOrStyleType = Literal['balanced', 'style', 'content']
|
||||
|
||||
LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
|
||||
|
||||
class TrainConfig:
|
||||
def __init__(self, **kwargs):
|
||||
@@ -127,6 +127,7 @@ class TrainConfig:
|
||||
|
||||
match_adapter_assist = kwargs.get('match_adapter_assist', False)
|
||||
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
|
||||
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented,
|
||||
|
||||
# legacy
|
||||
if match_adapter_assist and self.match_adapter_chance == 0.0:
|
||||
@@ -258,7 +259,13 @@ class DatasetConfig:
|
||||
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
|
||||
self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
|
||||
|
||||
if len(self.augments) > 0 and (self.cache_latents or self.cache_latents_to_disk):
|
||||
# https://albumentations.ai/docs/api_reference/augmentations/transforms
|
||||
# augmentations are returned as a separate image and cannot currently be cached
|
||||
self.augmentations: List[dict] = kwargs.get('augmentations', None)
|
||||
|
||||
has_augmentations = self.augmentations is not None and len(self.augmentations) > 0
|
||||
|
||||
if (len(self.augments) > 0 or has_augmentations) and (self.cache_latents or self.cache_latents_to_disk):
|
||||
print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False")
|
||||
self.cache_latents = False
|
||||
self.cache_latents_to_disk = False
|
||||
|
||||
@@ -16,7 +16,7 @@ import albumentations as A
|
||||
|
||||
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
|
||||
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -111,21 +111,7 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
return img
|
||||
|
||||
|
||||
class Augments:
|
||||
def __init__(self, **kwargs):
|
||||
self.method_name = kwargs.get('method', None)
|
||||
self.params = kwargs.get('params', {})
|
||||
|
||||
# convert kwargs enums for cv2
|
||||
for key, value in self.params.items():
|
||||
if isinstance(value, str):
|
||||
# split the string
|
||||
split_string = value.split('.')
|
||||
if len(split_string) == 2 and split_string[0] == 'cv2':
|
||||
if hasattr(cv2, split_string[1]):
|
||||
self.params[key] = getattr(cv2, split_string[1].upper())
|
||||
else:
|
||||
raise ValueError(f"invalid cv2 enum: {split_string[1]}")
|
||||
|
||||
|
||||
class AugmentedImageDataset(ImageDataset):
|
||||
|
||||
@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
|
||||
|
||||
from toolkit import image_utils
|
||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
@@ -28,6 +28,7 @@ class FileItemDTO(
|
||||
ImageProcessingDTOMixin,
|
||||
ControlFileItemDTOMixin,
|
||||
MaskFileItemDTOMixin,
|
||||
AugmentationFileItemDTOMixin,
|
||||
PoiFileItemDTOMixin,
|
||||
ArgBreakMixin,
|
||||
):
|
||||
@@ -80,6 +81,8 @@ class DataLoaderBatchDTO:
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
self.mask_tensor: Union[torch.Tensor, None] = None
|
||||
self.unaugmented_tensor: Union[torch.Tensor, None] = None
|
||||
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
|
||||
if not is_latents_cached:
|
||||
# only return a tensor if latents are not cached
|
||||
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
|
||||
@@ -119,11 +122,26 @@ class DataLoaderBatchDTO:
|
||||
else:
|
||||
mask_tensors.append(x.mask_tensor)
|
||||
self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors])
|
||||
|
||||
# add unaugmented tensors for ones with augments
|
||||
if any([x.unaugmented_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_unaugmented_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.unaugmented_tensor is not None:
|
||||
base_unaugmented_tensor = x.unaugmented_tensor
|
||||
break
|
||||
unaugmented_tensor = []
|
||||
for x in self.file_items:
|
||||
if x.unaugmented_tensor is None:
|
||||
unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor))
|
||||
else:
|
||||
unaugmented_tensor.append(x.unaugmented_tensor)
|
||||
self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
|
||||
def get_is_reg_list(self):
|
||||
return [x.is_reg for x in self.file_items]
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import random
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, List, Dict, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
@@ -19,6 +20,7 @@ from toolkit.prompt_utils import inject_trigger_into_prompt
|
||||
from torchvision import transforms
|
||||
from PIL import Image, ImageFilter
|
||||
from PIL.ImageOps import exif_transpose
|
||||
import albumentations as A
|
||||
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
@@ -26,8 +28,25 @@ if TYPE_CHECKING:
|
||||
from toolkit.data_loader import AiToolkitDataset
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO
|
||||
|
||||
|
||||
# def get_associated_caption_from_img_path(img_path):
|
||||
|
||||
class Augments:
|
||||
def __init__(self, **kwargs):
|
||||
self.method_name = kwargs.get('method', None)
|
||||
self.params = kwargs.get('params', {})
|
||||
|
||||
# convert kwargs enums for cv2
|
||||
for key, value in self.params.items():
|
||||
if isinstance(value, str):
|
||||
# split the string
|
||||
split_string = value.split('.')
|
||||
if len(split_string) == 2 and split_string[0] == 'cv2':
|
||||
if hasattr(cv2, split_string[1]):
|
||||
self.params[key] = getattr(cv2, split_string[1].upper())
|
||||
else:
|
||||
raise ValueError(f"invalid cv2 enum: {split_string[1]}")
|
||||
|
||||
|
||||
transforms_dict = {
|
||||
'ColorJitter': transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03),
|
||||
@@ -195,7 +214,6 @@ class BucketsMixin:
|
||||
print(f'{len(self.buckets)} buckets made')
|
||||
|
||||
|
||||
|
||||
class CaptionProcessingDTOMixin:
|
||||
def __init__(self: 'FileItemDTO', *args, **kwargs):
|
||||
if hasattr(super(), '__init__'):
|
||||
@@ -392,7 +410,10 @@ class ImageProcessingDTOMixin:
|
||||
if augment in transforms_dict:
|
||||
img = transforms_dict[augment](img)
|
||||
|
||||
if transform:
|
||||
if self.has_augmentations:
|
||||
# augmentations handles transforms
|
||||
img = self.augment_image(img, transform=transform)
|
||||
elif transform:
|
||||
img = transform(img)
|
||||
|
||||
self.tensor = img
|
||||
@@ -468,6 +489,54 @@ class ControlFileItemDTOMixin:
|
||||
self.control_tensor = None
|
||||
|
||||
|
||||
class AugmentationFileItemDTOMixin:
|
||||
def __init__(self: 'FileItemDTO', *args, **kwargs):
|
||||
if hasattr(super(), '__init__'):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.has_augmentations = False
|
||||
self.unaugmented_tensor: Union[torch.Tensor, None] = None
|
||||
# self.augmentations: Union[None, List[Augments]] = None
|
||||
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
|
||||
if dataset_config.augmentations is not None and len(dataset_config.augmentations) > 0:
|
||||
self.has_augmentations = True
|
||||
augmentations = [Augments(**aug) for aug in dataset_config.augmentations]
|
||||
augmentation_list = []
|
||||
for aug in augmentations:
|
||||
# make sure method name is valid
|
||||
assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
|
||||
# get the method
|
||||
method = getattr(A, aug.method_name)
|
||||
# add the method to the list
|
||||
augmentation_list.append(method(**aug.params))
|
||||
|
||||
self.aug_transform = A.Compose(augmentation_list)
|
||||
|
||||
def augment_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ):
|
||||
|
||||
# save the original tensor
|
||||
self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)
|
||||
|
||||
open_cv_image = np.array(img)
|
||||
# Convert RGB to BGR
|
||||
open_cv_image = open_cv_image[:, :, ::-1].copy()
|
||||
|
||||
# apply augmentations
|
||||
augmented = self.aug_transform(image=open_cv_image)["image"]
|
||||
|
||||
# convert back to RGB tensor
|
||||
augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# convert to PIL image
|
||||
augmented = Image.fromarray(augmented)
|
||||
|
||||
augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented)
|
||||
|
||||
return augmented_tensor
|
||||
|
||||
def cleanup_control(self: 'FileItemDTO'):
|
||||
self.unaugmented_tensor = None
|
||||
|
||||
|
||||
class MaskFileItemDTOMixin:
|
||||
def __init__(self: 'FileItemDTO', *args, **kwargs):
|
||||
if hasattr(super(), '__init__'):
|
||||
@@ -558,6 +627,7 @@ class MaskFileItemDTOMixin:
|
||||
def cleanup_mask(self: 'FileItemDTO'):
|
||||
self.mask_tensor = None
|
||||
|
||||
|
||||
class PoiFileItemDTOMixin:
|
||||
# Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
|
||||
# items in the poi will always be inside the image when random cropping
|
||||
@@ -799,7 +869,6 @@ class LatentCachingMixin:
|
||||
del latent
|
||||
del file_item.tensor
|
||||
|
||||
|
||||
flush(garbage_collect=False)
|
||||
file_item.is_latent_cached = True
|
||||
# flush every 100
|
||||
|
||||
Reference in New Issue
Block a user