Allow augmentations and targeting different loss types from the config file

This commit is contained in:
Jaret Burkett
2023-10-18 03:04:57 -06:00
parent da6302ada8
commit 07bf7bd7de
6 changed files with 216 additions and 50 deletions

View File

@@ -1,15 +1,9 @@
import os.path
from collections import OrderedDict
from typing import Union
from PIL import Image
from diffusers import T2IAdapter
from torch.utils.data import DataLoader
from toolkit.basic import value_map
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc
@@ -24,7 +18,6 @@ def flush():
# Transform applied to adapter conditioning images before feeding the T2IAdapter.
adapter_transforms = transforms.Compose([
    # transforms.PILToTensor(),
    transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
])
@@ -62,6 +55,77 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.vae.to('cpu')
flush()
# you can expand these in a child class to make customization easier
def calculate_loss(
        self,
        noise_pred: torch.Tensor,
        noise: torch.Tensor,
        noisy_latents: torch.Tensor,
        timesteps: torch.Tensor,
        batch: 'DataLoaderBatchDTO',
        mask_multiplier: Union[torch.Tensor, float] = 1.0,
        control_pred: Union[torch.Tensor, None] = None,
        **kwargs
):
    """Compute the training loss for one batch.

    The target depends on ``self.train_config.loss_target``:
      - ``'source'`` / ``'unaugmented'``: denoise the prediction back to a
        latent and compare it against the batch latents (or freshly encoded
        unaugmented images), weighted by ``1 / sigma**2`` as in the diffusers
        t2i-adapter example. Requires ``batch.sigmas`` to be populated.
      - otherwise: plain MSE between the prediction and the target, where the
        target is the adapter prediction (``control_pred``), the scheduler
        velocity (v-prediction models), or the raw noise.

    Args:
        noise_pred: model output for the noisy latents.
        noise: the noise that was added to the latents.
        noisy_latents: latents after noise was added.
        timesteps: sampled diffusion timesteps for this batch.
        batch: the batch DTO (provides latents / sigmas / unaugmented images).
        mask_multiplier: per-pixel loss mask (or scalar) applied before reduction.
        control_pred: optional adapter prediction to match instead of noise.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if sigmas are missing for source/unaugmented targets, or
            the scheduler reports an unknown prediction type.
    """
    loss_target = self.train_config.loss_target

    if control_pred is not None:
        # matching adapter prediction
        target = control_pred
    elif self.sd.prediction_type == 'v_prediction':
        # v-parameterization training
        target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
    else:
        target = noise

    pred = noise_pred
    ignore_snr = False

    if loss_target in ('source', 'unaugmented'):
        # ignore_snr = True
        if batch.sigmas is None:
            raise ValueError("Batch sigmas is None. This should not happen")

        # src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190
        denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents
        weighing = batch.sigmas ** -2.0

        if loss_target == 'source':
            # denoise the latent and compare to the latent in the batch
            target = batch.latents
        elif loss_target == 'unaugmented':
            # we have to encode images into latents for now
            # we also denoise as the unaugmented tensor is not a noisy differential
            with torch.no_grad():
                unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
                target = unaugmented_latents.detach()

        # Get the target for loss depending on the prediction type
        if self.sd.noise_scheduler.config.prediction_type == "epsilon":
            target = target  # we are computing loss against denoised latents
        elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")

        # sigma-weighted mse loss without reduction
        loss = weighing.float() * (denoised_latents.float() - target.float()) ** 2
    else:
        # mse loss without reduction
        loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")

    # multiply by our mask
    loss = loss * mask_multiplier
    loss = loss.mean([1, 2, 3])

    if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
        # add min_snr_gamma
        loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)

    loss = loss.mean()
    return loss
def hook_train_loop(self, batch):
self.timer.start('preprocess_batch')
@@ -251,26 +315,15 @@ class SDTrainer(BaseSDTrainProcess):
with self.timer('calculate_loss'):
noise = noise.to(self.device_torch, dtype=dtype).detach()
if control_pred is not None:
# matching adapter prediction
target = control_pred
elif self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
# multiply by our mask
loss = loss * mask_multiplier
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
mask_multiplier=mask_multiplier,
control_pred=control_pred,
)
# check if nan
if torch.isnan(loss):
raise ValueError("loss is nan")

View File

@@ -442,6 +442,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
# override in subclass
return params
def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
    """Look up the scheduler sigma for each timestep, shaped for broadcasting.

    Args:
        timesteps: tensor of timesteps to resolve against the scheduler table.
        n_dim: desired rank of the result (trailing singleton dims are added).
        dtype: dtype of the returned sigmas.

    Returns:
        Tensor of sigmas with shape ``(len(timesteps), 1, ..., 1)`` of rank
        ``n_dim``, on ``self.device``.
    """
    scheduler = self.sd.noise_scheduler
    all_sigmas = scheduler.sigmas.to(device=self.device, dtype=dtype)
    schedule_timesteps = scheduler.timesteps.to(self.device)
    timesteps = timesteps.to(self.device)
    # map each requested timestep to its index in the scheduler's timestep table
    indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    sigma = all_sigmas[indices].flatten()
    # append trailing singleton dimensions until the tensor has rank n_dim
    while sigma.dim() < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
with torch.no_grad():
with self.timer('prepare_prompt'):
@@ -483,6 +495,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
batch.latents = latents
# flush() # todo check performance removing this
unaugmented_latents = None
if self.train_config.loss_target == 'differential_noise':
# we determine noise from the differential of the latents
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
batch_size = latents.shape[0]
with self.timer('prepare_noise'):
@@ -516,7 +533,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.train_config.max_denoising_steps
)
timesteps = timesteps.long().clamp(
self.train_config.min_denoising_steps,
self.train_config.min_denoising_steps + 1,
self.train_config.max_denoising_steps - 1
)
@@ -539,6 +556,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype)
if self.train_config.loss_target == 'differential_noise':
differential = latents - unaugmented_latents
# add noise to differential
# noise = noise + differential
noise = noise + (differential * 0.5)
# noise = value_map(differential, 0, torch.abs(differential).max(), 0, torch.abs(noise).max())
latents = unaugmented_latents
noise_multiplier = self.train_config.noise_multiplier
noise = noise * noise_multiplier
@@ -549,6 +574,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
# https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170C17-L1171C77
if self.train_config.loss_target == 'source' or self.train_config.loss_target == 'unaugmented':
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
# add it to the batch
batch.sigmas = sigmas
# todo is this for sdxl? find out where this came from originally
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
# remove grads for these
noisy_latents.requires_grad = False
noisy_latents = noisy_latents.detach()

View File

@@ -88,7 +88,7 @@ class EmbeddingConfig:
ContentOrStyleType = Literal['balanced', 'style', 'content']
LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
class TrainConfig:
def __init__(self, **kwargs):
@@ -127,6 +127,7 @@ class TrainConfig:
match_adapter_assist = kwargs.get('match_adapter_assist', False)
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented,
# legacy
if match_adapter_assist and self.match_adapter_chance == 0.0:
@@ -258,7 +259,13 @@ class DatasetConfig:
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
if len(self.augments) > 0 and (self.cache_latents or self.cache_latents_to_disk):
# https://albumentations.ai/docs/api_reference/augmentations/transforms
# augmentations are returned as a separate image and cannot currently be cached
self.augmentations: List[dict] = kwargs.get('augmentations', None)
has_augmentations = self.augmentations is not None and len(self.augmentations) > 0
if (len(self.augments) > 0 or has_augmentations) and (self.cache_latents or self.cache_latents_to_disk):
print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False")
self.cache_latents = False
self.cache_latents_to_disk = False

View File

@@ -16,7 +16,7 @@ import albumentations as A
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
if TYPE_CHECKING:
@@ -111,21 +111,7 @@ class ImageDataset(Dataset, CaptionMixin):
return img
class Augments:
    """Holds one augmentation config entry: a method name plus its params."""

    def __init__(self, **kwargs):
        # name of the augmentation method to apply
        self.method_name = kwargs.get('method', None)
        # keyword params forwarded to the augmentation method
        self.params = kwargs.get('params', {})
        # convert kwargs enums for cv2
        for key, value in self.params.items():
            if isinstance(value, str):
                # split the string
                split_string = value.split('.')
                if len(split_string) == 2 and split_string[0] == 'cv2':
                    # NOTE(review): hasattr checks the name as written, but getattr
                    # fetches the upper-cased name — a lower/mixed-case input that
                    # passes the check may fail the fetch; verify intended casing
                    if hasattr(cv2, split_string[1]):
                        self.params[key] = getattr(cv2, split_string[1].upper())
                    else:
                        raise ValueError(f"invalid cv2 enum: {split_string[1]}")
class AugmentedImageDataset(ImageDataset):

View File

@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
@@ -28,6 +28,7 @@ class FileItemDTO(
ImageProcessingDTOMixin,
ControlFileItemDTOMixin,
MaskFileItemDTOMixin,
AugmentationFileItemDTOMixin,
PoiFileItemDTOMixin,
ArgBreakMixin,
):
@@ -80,6 +81,8 @@ class DataLoaderBatchDTO:
self.latents: Union[torch.Tensor, None] = None
self.control_tensor: Union[torch.Tensor, None] = None
self.mask_tensor: Union[torch.Tensor, None] = None
self.unaugmented_tensor: Union[torch.Tensor, None] = None
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
if not is_latents_cached:
# only return a tensor if latents are not cached
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
@@ -119,11 +122,26 @@ class DataLoaderBatchDTO:
else:
mask_tensors.append(x.mask_tensor)
self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors])
# add unaugmented tensors for ones with augments
if any([x.unaugmented_tensor is not None for x in self.file_items]):
# find one to use as a base
base_unaugmented_tensor = None
for x in self.file_items:
if x.unaugmented_tensor is not None:
base_unaugmented_tensor = x.unaugmented_tensor
break
unaugmented_tensor = []
for x in self.file_items:
if x.unaugmented_tensor is None:
unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor))
else:
unaugmented_tensor.append(x.unaugmented_tensor)
self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
except Exception as e:
print(e)
raise e
def get_is_reg_list(self):
return [x.is_reg for x in self.file_items]

View File

@@ -7,6 +7,7 @@ import random
from collections import OrderedDict
from typing import TYPE_CHECKING, List, Dict, Union
import cv2
import numpy as np
import torch
from safetensors.torch import load_file, save_file
@@ -19,6 +20,7 @@ from toolkit.prompt_utils import inject_trigger_into_prompt
from torchvision import transforms
from PIL import Image, ImageFilter
from PIL.ImageOps import exif_transpose
import albumentations as A
from toolkit.train_tools import get_torch_dtype
@@ -26,8 +28,25 @@ if TYPE_CHECKING:
from toolkit.data_loader import AiToolkitDataset
from toolkit.data_transfer_object.data_loader import FileItemDTO
# def get_associated_caption_from_img_path(img_path):
class Augments:
    """Parsed augmentation config entry: an albumentations method name plus params.

    String params of the form ``cv2.SOME_ENUM`` are resolved to the actual
    cv2 constant so they can be passed straight to the augmentation method.
    """

    def __init__(self, **kwargs):
        self.method_name = kwargs.get('method', None)
        self.params = kwargs.get('params', {})

        # convert "cv2.X" string params to the corresponding cv2 enum value
        for key, value in self.params.items():
            if isinstance(value, str):
                # split the string
                split_string = value.split('.')
                if len(split_string) == 2 and split_string[0] == 'cv2':
                    # bug fix: check the same (upper-cased) attribute that is fetched;
                    # previously hasattr used the raw name while getattr upper-cased it,
                    # so lower-case enum names could pass the check and then fail the fetch
                    enum_name = split_string[1].upper()
                    if hasattr(cv2, enum_name):
                        self.params[key] = getattr(cv2, enum_name)
                    else:
                        raise ValueError(f"invalid cv2 enum: {split_string[1]}")
transforms_dict = {
'ColorJitter': transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03),
@@ -195,7 +214,6 @@ class BucketsMixin:
print(f'{len(self.buckets)} buckets made')
class CaptionProcessingDTOMixin:
def __init__(self: 'FileItemDTO', *args, **kwargs):
if hasattr(super(), '__init__'):
@@ -392,7 +410,10 @@ class ImageProcessingDTOMixin:
if augment in transforms_dict:
img = transforms_dict[augment](img)
if transform:
if self.has_augmentations:
# augmentations handles transforms
img = self.augment_image(img, transform=transform)
elif transform:
img = transform(img)
self.tensor = img
@@ -468,6 +489,54 @@ class ControlFileItemDTOMixin:
self.control_tensor = None
class AugmentationFileItemDTOMixin:
    """Mixin for FileItemDTO that applies albumentations-based augmentations to
    an image while keeping the original (unaugmented) tensor around so it can
    be used as a loss target during training."""

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.has_augmentations = False
        # original image tensor, populated by augment_image when augmenting
        self.unaugmented_tensor: Union[torch.Tensor, None] = None
        # self.augmentations: Union[None, List[Augments]] = None
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        # robustness fix: tolerate a missing dataset_config kwarg
        # (previously raised AttributeError on None)
        if dataset_config is not None and dataset_config.augmentations is not None \
                and len(dataset_config.augmentations) > 0:
            self.has_augmentations = True
            augmentations = [Augments(**aug) for aug in dataset_config.augmentations]
            augmentation_list = []
            for aug in augmentations:
                # make sure method name is a valid albumentations transform
                assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
                # get the method
                method = getattr(A, aug.method_name)
                # add the configured transform to the list
                augmentation_list.append(method(**aug.params))
            self.aug_transform = A.Compose(augmentation_list)

    def augment_image(self: 'FileItemDTO', img: 'Image', transform: 'Union[None, transforms.Compose]'):
        """Apply the configured augmentations to ``img`` and return the tensor.

        Also stores the untouched image (as a tensor, run through the same
        ``transform`` when given) in ``self.unaugmented_tensor`` so callers can
        use it as an unaugmented loss target.
        """
        # save the original tensor
        self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)
        open_cv_image = np.array(img)
        # Convert RGB to BGR for the cv2-based augmentation pipeline
        open_cv_image = open_cv_image[:, :, ::-1].copy()
        # apply augmentations
        augmented = self.aug_transform(image=open_cv_image)["image"]
        # convert back to an RGB PIL image, then to a tensor
        augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
        augmented = Image.fromarray(augmented)
        augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented)
        return augmented_tensor

    # NOTE(review): this name likely collides with ControlFileItemDTOMixin's
    # cleanup_control in the FileItemDTO MRO, so only one implementation runs.
    # Kept for backward compatibility; consider renaming to cleanup_augmentations.
    def cleanup_control(self: 'FileItemDTO'):
        self.unaugmented_tensor = None
class MaskFileItemDTOMixin:
def __init__(self: 'FileItemDTO', *args, **kwargs):
if hasattr(super(), '__init__'):
@@ -558,6 +627,7 @@ class MaskFileItemDTOMixin:
def cleanup_mask(self: 'FileItemDTO'):
self.mask_tensor = None
class PoiFileItemDTOMixin:
# Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
# items in the poi will always be inside the image when random cropping
@@ -799,7 +869,6 @@ class LatentCachingMixin:
del latent
del file_item.tensor
flush(garbage_collect=False)
file_item.is_latent_cached = True
# flush every 100