Allow augmentations and targeting different loss types from the config file

This commit is contained in:
Jaret Burkett
2023-10-18 03:04:57 -06:00
parent da6302ada8
commit 07bf7bd7de
6 changed files with 216 additions and 50 deletions

View File

@@ -1,15 +1,9 @@
import os.path
from collections import OrderedDict from collections import OrderedDict
from typing import Union from typing import Union
from PIL import Image
from diffusers import T2IAdapter from diffusers import T2IAdapter
from torch.utils.data import DataLoader
from toolkit.basic import value_map from toolkit.basic import value_map
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.ip_adapter import IPAdapter from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight from toolkit.train_tools import get_torch_dtype, apply_snr_weight
import gc import gc
@@ -24,7 +18,6 @@ def flush():
adapter_transforms = transforms.Compose([ adapter_transforms = transforms.Compose([
# transforms.PILToTensor(),
transforms.ToTensor(), transforms.ToTensor(),
]) ])
@@ -62,6 +55,77 @@ class SDTrainer(BaseSDTrainProcess):
self.sd.vae.to('cpu') self.sd.vae.to('cpu')
flush() flush()
# you can expand these in a child class to make customization easier
def calculate_loss(
        self,
        noise_pred: torch.Tensor,
        noise: torch.Tensor,
        noisy_latents: torch.Tensor,
        timesteps: torch.Tensor,
        batch: 'DataLoaderBatchDTO',
        mask_multiplier: Union[torch.Tensor, float] = 1.0,
        control_pred: Union[torch.Tensor, None] = None,
        **kwargs
) -> torch.Tensor:
    """Compute the diffusion training loss for one batch.

    The loss target is chosen from ``self.train_config.loss_target``:

    - ``'noise'`` (default path): plain MSE between the model prediction and
      either the sampled noise (epsilon) or the scheduler velocity
      (v-prediction), or ``control_pred`` when an adapter prediction is given.
    - ``'source'`` / ``'unaugmented'``: denoise the latents with the batch
      sigmas and compare against the batch latents (or freshly-encoded
      unaugmented images), weighted by ``sigmas ** -2``.

    Args:
        noise_pred: model output for the noisy latents.
        noise: the noise that was added to produce ``noisy_latents``.
        noisy_latents: latents after noise was added.
        timesteps: per-sample scheduler timesteps.
        batch: the batch DTO; ``batch.sigmas`` must be set for the
            'source'/'unaugmented' targets, ``batch.unaugmented_tensor`` for
            'unaugmented'.
        mask_multiplier: per-pixel loss mask (or scalar 1.0 for no mask).
        control_pred: optional adapter prediction to match instead of noise.

    Returns:
        A scalar loss tensor.
    """
    loss_target = self.train_config.loss_target
    # default target selection (may be overridden below for
    # 'source'/'unaugmented' loss targets)
    if control_pred is not None:
        # matching adapter prediction
        target = control_pred
    elif self.sd.prediction_type == 'v_prediction':
        # v-parameterization training
        target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
    else:
        target = noise

    pred = noise_pred

    # NOTE(review): ignore_snr is currently never set to True (the assignment
    # below is commented out), so min-SNR weighting always applies when
    # configured — confirm that is intended for the 'source'/'unaugmented' path.
    ignore_snr = False

    if loss_target == 'source' or loss_target == 'unaugmented':
        # ignore_snr = True
        if batch.sigmas is None:
            raise ValueError("Batch sigmas is None. This should not happen")

        # src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190
        # reconstruct the denoised latents from the prediction and the sigmas
        denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents
        # per-sample weighting: low-sigma (low-noise) steps weigh more
        weighing = batch.sigmas ** -2.0
        if loss_target == 'source':
            # denoise the latent and compare to the latent in the batch
            target = batch.latents
        elif loss_target == 'unaugmented':
            # we have to encode images into latents for now
            # we also denoise as the unaugmented tensor is not a noisy differential
            with torch.no_grad():
                unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
                target = unaugmented_latents.detach()

        # Get the target for loss depending on the prediction type
        if self.sd.noise_scheduler.config.prediction_type == "epsilon":
            target = target  # no-op: loss is computed against denoised latents directly
        elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")

        # weighted mse loss without reduction
        loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
        loss = loss_per_element
    else:
        loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")

    # multiply by our mask (1.0 when no mask is used)
    loss = loss * mask_multiplier

    # reduce over channel/height/width, keeping one loss value per sample
    loss = loss.mean([1, 2, 3])

    if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
        # add min_snr_gamma
        loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)

    # final reduction over the batch dimension
    loss = loss.mean()
    return loss
def hook_train_loop(self, batch): def hook_train_loop(self, batch):
self.timer.start('preprocess_batch') self.timer.start('preprocess_batch')
@@ -251,26 +315,15 @@ class SDTrainer(BaseSDTrainProcess):
with self.timer('calculate_loss'): with self.timer('calculate_loss'):
noise = noise.to(self.device_torch, dtype=dtype).detach() noise = noise.to(self.device_torch, dtype=dtype).detach()
loss = self.calculate_loss(
if control_pred is not None: noise_pred=noise_pred,
# matching adapter prediction noise=noise,
target = control_pred noisy_latents=noisy_latents,
elif self.sd.prediction_type == 'v_prediction': timesteps=timesteps,
# v-parameterization training batch=batch,
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) mask_multiplier=mask_multiplier,
else: control_pred=control_pred,
target = noise )
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
# multiply by our mask
loss = loss * mask_multiplier
loss = loss.mean([1, 2, 3])
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
# add min_snr_gamma
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
# check if nan # check if nan
if torch.isnan(loss): if torch.isnan(loss):
raise ValueError("loss is nan") raise ValueError("loss is nan")

View File

@@ -442,6 +442,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
# override in subclass # override in subclass
return params return params
def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
    """Look up the scheduler sigmas for the given timesteps.

    Maps each entry of ``timesteps`` to its index in the scheduler's
    timestep schedule, gathers the matching sigmas, and pads the result
    with trailing singleton dimensions so it broadcasts against a tensor
    of rank ``n_dim`` (e.g. latents of shape [B, C, H, W]).
    """
    scheduler = self.sd.noise_scheduler
    all_sigmas = scheduler.sigmas.to(device=self.device, dtype=dtype)
    schedule = scheduler.timesteps.to(self.device)
    timesteps = timesteps.to(self.device)

    # index of each requested timestep within the schedule
    indices = []
    for t in timesteps:
        indices.append((schedule == t).nonzero().item())

    sigma = all_sigmas[indices].flatten()
    # pad with trailing singleton dims until sigma broadcasts at rank n_dim
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
with torch.no_grad(): with torch.no_grad():
with self.timer('prepare_prompt'): with self.timer('prepare_prompt'):
@@ -483,6 +495,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
batch.latents = latents batch.latents = latents
# flush() # todo check performance removing this # flush() # todo check performance removing this
unaugmented_latents = None
if self.train_config.loss_target == 'differential_noise':
# we determine noise from the differential of the latents
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
batch_size = latents.shape[0] batch_size = latents.shape[0]
with self.timer('prepare_noise'): with self.timer('prepare_noise'):
@@ -516,7 +533,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.train_config.max_denoising_steps self.train_config.max_denoising_steps
) )
timesteps = timesteps.long().clamp( timesteps = timesteps.long().clamp(
self.train_config.min_denoising_steps, self.train_config.min_denoising_steps + 1,
self.train_config.max_denoising_steps - 1 self.train_config.max_denoising_steps - 1
) )
@@ -539,6 +556,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise_offset=self.train_config.noise_offset noise_offset=self.train_config.noise_offset
).to(self.device_torch, dtype=dtype) ).to(self.device_torch, dtype=dtype)
if self.train_config.loss_target == 'differential_noise':
differential = latents - unaugmented_latents
# add noise to differential
# noise = noise + differential
noise = noise + (differential * 0.5)
# noise = value_map(differential, 0, torch.abs(differential).max(), 0, torch.abs(noise).max())
latents = unaugmented_latents
noise_multiplier = self.train_config.noise_multiplier noise_multiplier = self.train_config.noise_multiplier
noise = noise * noise_multiplier noise = noise * noise_multiplier
@@ -549,6 +574,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps) noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
# https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170C17-L1171C77
if self.train_config.loss_target == 'source' or self.train_config.loss_target == 'unaugmented':
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
# add it to the batch
batch.sigmas = sigmas
# todo is this for sdxl? find out where this came from originally
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
# remove grads for these # remove grads for these
noisy_latents.requires_grad = False noisy_latents.requires_grad = False
noisy_latents = noisy_latents.detach() noisy_latents = noisy_latents.detach()

View File

@@ -88,7 +88,7 @@ class EmbeddingConfig:
ContentOrStyleType = Literal['balanced', 'style', 'content'] ContentOrStyleType = Literal['balanced', 'style', 'content']
LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
class TrainConfig: class TrainConfig:
def __init__(self, **kwargs): def __init__(self, **kwargs):
@@ -127,6 +127,7 @@ class TrainConfig:
match_adapter_assist = kwargs.get('match_adapter_assist', False) match_adapter_assist = kwargs.get('match_adapter_assist', False)
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0) self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented,
# legacy # legacy
if match_adapter_assist and self.match_adapter_chance == 0.0: if match_adapter_assist and self.match_adapter_chance == 0.0:
@@ -258,7 +259,13 @@ class DatasetConfig:
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False) self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
if len(self.augments) > 0 and (self.cache_latents or self.cache_latents_to_disk): # https://albumentations.ai/docs/api_reference/augmentations/transforms
# augmentations are returned as a separate image and cannot currently be cached
self.augmentations: List[dict] = kwargs.get('augmentations', None)
has_augmentations = self.augmentations is not None and len(self.augmentations) > 0
if (len(self.augments) > 0 or has_augmentations) and (self.cache_latents or self.cache_latents_to_disk):
print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False") print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False")
self.cache_latents = False self.cache_latents = False
self.cache_latents_to_disk = False self.cache_latents_to_disk = False

View File

@@ -16,7 +16,7 @@ import albumentations as A
from toolkit.buckets import get_bucket_for_image_size, BucketResolution from toolkit.buckets import get_bucket_for_image_size, BucketResolution
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -111,21 +111,7 @@ class ImageDataset(Dataset, CaptionMixin):
return img return img
class Augments:
def __init__(self, **kwargs):
self.method_name = kwargs.get('method', None)
self.params = kwargs.get('params', {})
# convert kwargs enums for cv2
for key, value in self.params.items():
if isinstance(value, str):
# split the string
split_string = value.split('.')
if len(split_string) == 2 and split_string[0] == 'cv2':
if hasattr(cv2, split_string[1]):
self.params[key] = getattr(cv2, split_string[1].upper())
else:
raise ValueError(f"invalid cv2 enum: {split_string[1]}")
class AugmentedImageDataset(ImageDataset): class AugmentedImageDataset(ImageDataset):

View File

@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
from toolkit import image_utils from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \ from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin
if TYPE_CHECKING: if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig from toolkit.config_modules import DatasetConfig
@@ -28,6 +28,7 @@ class FileItemDTO(
ImageProcessingDTOMixin, ImageProcessingDTOMixin,
ControlFileItemDTOMixin, ControlFileItemDTOMixin,
MaskFileItemDTOMixin, MaskFileItemDTOMixin,
AugmentationFileItemDTOMixin,
PoiFileItemDTOMixin, PoiFileItemDTOMixin,
ArgBreakMixin, ArgBreakMixin,
): ):
@@ -80,6 +81,8 @@ class DataLoaderBatchDTO:
self.latents: Union[torch.Tensor, None] = None self.latents: Union[torch.Tensor, None] = None
self.control_tensor: Union[torch.Tensor, None] = None self.control_tensor: Union[torch.Tensor, None] = None
self.mask_tensor: Union[torch.Tensor, None] = None self.mask_tensor: Union[torch.Tensor, None] = None
self.unaugmented_tensor: Union[torch.Tensor, None] = None
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
if not is_latents_cached: if not is_latents_cached:
# only return a tensor if latents are not cached # only return a tensor if latents are not cached
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items]) self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
@@ -119,11 +122,26 @@ class DataLoaderBatchDTO:
else: else:
mask_tensors.append(x.mask_tensor) mask_tensors.append(x.mask_tensor)
self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors]) self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors])
# add unaugmented tensors for ones with augments
if any([x.unaugmented_tensor is not None for x in self.file_items]):
# find one to use as a base
base_unaugmented_tensor = None
for x in self.file_items:
if x.unaugmented_tensor is not None:
base_unaugmented_tensor = x.unaugmented_tensor
break
unaugmented_tensor = []
for x in self.file_items:
if x.unaugmented_tensor is None:
unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor))
else:
unaugmented_tensor.append(x.unaugmented_tensor)
self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
except Exception as e: except Exception as e:
print(e) print(e)
raise e raise e
def get_is_reg_list(self): def get_is_reg_list(self):
return [x.is_reg for x in self.file_items] return [x.is_reg for x in self.file_items]

View File

@@ -7,6 +7,7 @@ import random
from collections import OrderedDict from collections import OrderedDict
from typing import TYPE_CHECKING, List, Dict, Union from typing import TYPE_CHECKING, List, Dict, Union
import cv2
import numpy as np import numpy as np
import torch import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
@@ -19,6 +20,7 @@ from toolkit.prompt_utils import inject_trigger_into_prompt
from torchvision import transforms from torchvision import transforms
from PIL import Image, ImageFilter from PIL import Image, ImageFilter
from PIL.ImageOps import exif_transpose from PIL.ImageOps import exif_transpose
import albumentations as A
from toolkit.train_tools import get_torch_dtype from toolkit.train_tools import get_torch_dtype
@@ -26,8 +28,25 @@ if TYPE_CHECKING:
from toolkit.data_loader import AiToolkitDataset from toolkit.data_loader import AiToolkitDataset
from toolkit.data_transfer_object.data_loader import FileItemDTO from toolkit.data_transfer_object.data_loader import FileItemDTO
# def get_associated_caption_from_img_path(img_path): # def get_associated_caption_from_img_path(img_path):
class Augments:
    """Configuration holder for a single albumentations augmentation.

    Parses ``method`` (the albumentations transform class name) and
    ``params`` (its constructor kwargs) from a config dict, resolving string
    references like ``"cv2.INTER_LINEAR"`` into the actual cv2 constants.
    """

    def __init__(self, **kwargs):
        # name of the albumentations transform class, e.g. "Blur"
        self.method_name = kwargs.get('method', None)
        # keyword args forwarded to the transform constructor
        self.params = kwargs.get('params', {})

        # convert "cv2.<NAME>" string params into the cv2 constants
        for key, value in self.params.items():
            if isinstance(value, str):
                # split the string
                split_string = value.split('.')
                if len(split_string) == 2 and split_string[0] == 'cv2':
                    enum_name = split_string[1]
                    # Accept the name exactly as given, or its upper-case
                    # form (cv2 constants are upper-case). Previously the
                    # code checked hasattr on the raw name but fetched the
                    # upper-cased one, so a mixed-case name whose upper-case
                    # form exists was rejected, and a name existing only in
                    # its original case raised AttributeError on getattr.
                    if hasattr(cv2, enum_name):
                        self.params[key] = getattr(cv2, enum_name)
                    elif hasattr(cv2, enum_name.upper()):
                        self.params[key] = getattr(cv2, enum_name.upper())
                    else:
                        raise ValueError(f"invalid cv2 enum: {enum_name}")
transforms_dict = { transforms_dict = {
'ColorJitter': transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03), 'ColorJitter': transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03),
@@ -195,7 +214,6 @@ class BucketsMixin:
print(f'{len(self.buckets)} buckets made') print(f'{len(self.buckets)} buckets made')
class CaptionProcessingDTOMixin: class CaptionProcessingDTOMixin:
def __init__(self: 'FileItemDTO', *args, **kwargs): def __init__(self: 'FileItemDTO', *args, **kwargs):
if hasattr(super(), '__init__'): if hasattr(super(), '__init__'):
@@ -392,7 +410,10 @@ class ImageProcessingDTOMixin:
if augment in transforms_dict: if augment in transforms_dict:
img = transforms_dict[augment](img) img = transforms_dict[augment](img)
if transform: if self.has_augmentations:
# augmentations handles transforms
img = self.augment_image(img, transform=transform)
elif transform:
img = transform(img) img = transform(img)
self.tensor = img self.tensor = img
@@ -468,6 +489,54 @@ class ControlFileItemDTOMixin:
self.control_tensor = None self.control_tensor = None
class AugmentationFileItemDTOMixin:
    """FileItemDTO mixin that applies albumentations augmentations to images.

    When the dataset config declares ``augmentations``, the original
    (unaugmented) image tensor is kept on ``self.unaugmented_tensor`` so the
    trainer can use it as a loss target.
    """

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.has_augmentations = False
        # original image tensor, populated by augment_image when augmenting
        self.unaugmented_tensor: Union[torch.Tensor, None] = None
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        # Guard against a missing dataset_config: previously this
        # dereferenced dataset_config unconditionally and raised
        # AttributeError when no config was passed in kwargs.
        if dataset_config is not None and dataset_config.augmentations is not None \
                and len(dataset_config.augmentations) > 0:
            self.has_augmentations = True
            augmentations = [Augments(**aug) for aug in dataset_config.augmentations]
            augmentation_list = []
            for aug in augmentations:
                # make sure the configured method exists in albumentations
                assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
                # get the method
                method = getattr(A, aug.method_name)
                # add the instantiated transform to the list
                augmentation_list.append(method(**aug.params))
            self.aug_transform = A.Compose(augmentation_list)

    def augment_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose]):
        """Apply the configured augmentations to ``img``.

        Stores the unaugmented tensor on ``self.unaugmented_tensor`` and
        returns the augmented tensor. ``transform`` (when given) is applied
        to both; otherwise a plain ``ToTensor`` conversion is used.
        """
        # save the original (unaugmented) tensor for use as a loss target
        self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)
        open_cv_image = np.array(img)
        # Convert RGB (PIL) to BGR (OpenCV convention)
        open_cv_image = open_cv_image[:, :, ::-1].copy()
        # apply augmentations
        augmented = self.aug_transform(image=open_cv_image)["image"]
        # convert back to an RGB array, then to a PIL image
        augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
        augmented = Image.fromarray(augmented)
        augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented)
        return augmented_tensor

    def cleanup_control(self: 'FileItemDTO'):
        # NOTE(review): this name matches ControlFileItemDTOMixin.cleanup_control
        # and will shadow it in the FileItemDTO MRO — confirm that is intended;
        # a dedicated cleanup hook name may be safer.
        self.unaugmented_tensor = None
class MaskFileItemDTOMixin: class MaskFileItemDTOMixin:
def __init__(self: 'FileItemDTO', *args, **kwargs): def __init__(self: 'FileItemDTO', *args, **kwargs):
if hasattr(super(), '__init__'): if hasattr(super(), '__init__'):
@@ -558,6 +627,7 @@ class MaskFileItemDTOMixin:
def cleanup_mask(self: 'FileItemDTO'): def cleanup_mask(self: 'FileItemDTO'):
self.mask_tensor = None self.mask_tensor = None
class PoiFileItemDTOMixin: class PoiFileItemDTOMixin:
# Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject # Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
# items in the poi will always be inside the image when random cropping # items in the poi will always be inside the image when random cropping
@@ -799,7 +869,6 @@ class LatentCachingMixin:
del latent del latent
del file_item.tensor del file_item.tensor
flush(garbage_collect=False) flush(garbage_collect=False)
file_item.is_latent_cached = True file_item.is_latent_cached = True
# flush every 100 # flush every 100