mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-03 09:39:58 +00:00
Allow augmentations and targeting different loss types from the config file
This commit is contained in:
@@ -1,15 +1,9 @@
|
||||
import os.path
|
||||
from collections import OrderedDict
|
||||
from typing import Union
|
||||
|
||||
from PIL import Image
|
||||
from diffusers import T2IAdapter
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from toolkit.basic import value_map
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
|
||||
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
||||
import gc
|
||||
@@ -24,7 +18,6 @@ def flush():
|
||||
|
||||
|
||||
adapter_transforms = transforms.Compose([
|
||||
# transforms.PILToTensor(),
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
@@ -62,6 +55,77 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
self.sd.vae.to('cpu')
|
||||
flush()
|
||||
|
||||
# you can expand these in a child class to make customization easier
|
||||
def calculate_loss(
        self,
        noise_pred: torch.Tensor,
        noise: torch.Tensor,
        noisy_latents: torch.Tensor,
        timesteps: torch.Tensor,
        batch: 'DataLoaderBatchDTO',
        mask_multiplier: Union[torch.Tensor, float] = 1.0,
        control_pred: Union[torch.Tensor, None] = None,
        **kwargs
):
    """Compute the training loss for one batch, selected by self.train_config.loss_target.

    Args:
        noise_pred: model output for `noisy_latents` at `timesteps`.
        noise: the noise that was mixed into the latents.
        noisy_latents: latents after the scheduler added noise.
        timesteps: sampled diffusion timesteps, one per batch element.
        batch: the dataloader batch. Must carry `batch.sigmas` for the
            'source'/'unaugmented' loss targets, `batch.latents` for 'source',
            and `batch.unaugmented_tensor` for 'unaugmented'.
        mask_multiplier: scalar or per-pixel weight multiplied into the loss
            before spatial reduction.
        control_pred: optional adapter prediction; when given, it replaces
            noise/velocity as the regression target.

    Returns:
        A scalar loss tensor.
    """
    loss_target = self.train_config.loss_target
    # add latents and unaug latents
    if control_pred is not None:
        # matching adapter prediction
        target = control_pred
    elif self.sd.prediction_type == 'v_prediction':
        # v-parameterization training
        target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
    else:
        target = noise

    pred = noise_pred

    # currently always False; the 'source'/'unaugmented' branch below keeps its
    # SNR opt-out commented, so min-SNR weighting still applies there
    ignore_snr = False

    if loss_target == 'source' or loss_target == 'unaugmented':
        # ignore_snr = True
        if batch.sigmas is None:
            raise ValueError("Batch sigmas is None. This should not happen")

        # src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190
        # one-step denoise estimate: latents the model implies at sigma=0
        denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents
        # inverse-variance weighting per element (sigma^-2)
        weighing = batch.sigmas ** -2.0
        if loss_target == 'source':
            # denoise the latent and compare to the latent in the batch
            target = batch.latents
        elif loss_target == 'unaugmented':
            # we have to encode images into latents for now
            # we also denoise as the unaugmented tensor is not a noisy differential
            with torch.no_grad():
                unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
                target = unaugmented_latents.detach()

        # Get the target for loss depending on the prediction type
        if self.sd.noise_scheduler.config.prediction_type == "epsilon":
            target = target  # we are computing loss against denoise latents
        elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")

        # mse loss without reduction
        loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
        loss = loss_per_element
    else:
        loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")

    # multiply by our mask
    loss = loss * mask_multiplier

    # reduce over channel/height/width, keep per-sample losses
    # (assumes 4D latents of shape (batch, C, H, W) — TODO confirm)
    loss = loss.mean([1, 2, 3])

    if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
        # add min_snr_gamma
        loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)

    loss = loss.mean()
    return loss
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
|
||||
self.timer.start('preprocess_batch')
|
||||
@@ -251,26 +315,15 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
with self.timer('calculate_loss'):
|
||||
noise = noise.to(self.device_torch, dtype=dtype).detach()
|
||||
|
||||
if control_pred is not None:
|
||||
# matching adapter prediction
|
||||
target = control_pred
|
||||
elif self.sd.prediction_type == 'v_prediction':
|
||||
# v-parameterization training
|
||||
target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
# multiply by our mask
|
||||
loss = loss * mask_multiplier
|
||||
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||
# add min_snr_gamma
|
||||
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
|
||||
|
||||
loss = loss.mean()
|
||||
loss = self.calculate_loss(
|
||||
noise_pred=noise_pred,
|
||||
noise=noise,
|
||||
noisy_latents=noisy_latents,
|
||||
timesteps=timesteps,
|
||||
batch=batch,
|
||||
mask_multiplier=mask_multiplier,
|
||||
control_pred=control_pred,
|
||||
)
|
||||
# check if nan
|
||||
if torch.isnan(loss):
|
||||
raise ValueError("loss is nan")
|
||||
|
||||
@@ -442,6 +442,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# override in subclass
|
||||
return params
|
||||
|
||||
def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
    """Return the scheduler sigma for each timestep, reshaped for broadcasting.

    Looks up each entry of `timesteps` in the noise scheduler's timestep
    table, gathers the matching sigmas, and appends trailing singleton
    dimensions until the result has `n_dim` dimensions.
    """
    scheduler = self.sd.noise_scheduler
    all_sigmas = scheduler.sigmas.to(device=self.device, dtype=dtype)
    schedule_timesteps = scheduler.timesteps.to(self.device)
    timesteps = timesteps.to(self.device)

    # position of every requested timestep within the schedule
    indices = []
    for t in timesteps:
        indices.append((schedule_timesteps == t).nonzero().item())

    sigma = all_sigmas[indices].flatten()
    # pad with trailing singleton dims so sigma broadcasts against latents
    for _ in range(n_dim - sigma.dim()):
        sigma = sigma.unsqueeze(-1)
    return sigma
|
||||
|
||||
def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
|
||||
with torch.no_grad():
|
||||
with self.timer('prepare_prompt'):
|
||||
@@ -483,6 +495,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
batch.latents = latents
|
||||
# flush() # todo check performance removing this
|
||||
|
||||
unaugmented_latents = None
|
||||
if self.train_config.loss_target == 'differential_noise':
|
||||
# we determine noise from the differential of the latents
|
||||
unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
with self.timer('prepare_noise'):
|
||||
@@ -516,7 +533,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.train_config.max_denoising_steps
|
||||
)
|
||||
timesteps = timesteps.long().clamp(
|
||||
self.train_config.min_denoising_steps,
|
||||
self.train_config.min_denoising_steps + 1,
|
||||
self.train_config.max_denoising_steps - 1
|
||||
)
|
||||
|
||||
@@ -539,6 +556,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
noise_offset=self.train_config.noise_offset
|
||||
).to(self.device_torch, dtype=dtype)
|
||||
|
||||
if self.train_config.loss_target == 'differential_noise':
|
||||
differential = latents - unaugmented_latents
|
||||
# add noise to differential
|
||||
# noise = noise + differential
|
||||
noise = noise + (differential * 0.5)
|
||||
# noise = value_map(differential, 0, torch.abs(differential).max(), 0, torch.abs(noise).max())
|
||||
latents = unaugmented_latents
|
||||
|
||||
noise_multiplier = self.train_config.noise_multiplier
|
||||
|
||||
noise = noise * noise_multiplier
|
||||
@@ -549,6 +574,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170C17-L1171C77
|
||||
if self.train_config.loss_target == 'source' or self.train_config.loss_target == 'unaugmented':
|
||||
sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
|
||||
# add it to the batch
|
||||
batch.sigmas = sigmas
|
||||
# todo is this for sdxl? find out where this came from originally
|
||||
# noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)
|
||||
|
||||
# remove grads for these
|
||||
noisy_latents.requires_grad = False
|
||||
noisy_latents = noisy_latents.detach()
|
||||
|
||||
@@ -88,7 +88,7 @@ class EmbeddingConfig:
|
||||
|
||||
|
||||
ContentOrStyleType = Literal['balanced', 'style', 'content']
|
||||
|
||||
LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise']
|
||||
|
||||
class TrainConfig:
|
||||
def __init__(self, **kwargs):
|
||||
@@ -127,6 +127,7 @@ class TrainConfig:
|
||||
|
||||
match_adapter_assist = kwargs.get('match_adapter_assist', False)
|
||||
self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0)
|
||||
self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented,
|
||||
|
||||
# legacy
|
||||
if match_adapter_assist and self.match_adapter_chance == 0.0:
|
||||
@@ -258,7 +259,13 @@ class DatasetConfig:
|
||||
# cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory
|
||||
self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False)
|
||||
|
||||
if len(self.augments) > 0 and (self.cache_latents or self.cache_latents_to_disk):
|
||||
# https://albumentations.ai/docs/api_reference/augmentations/transforms
|
||||
# augmentations are returned as a separate image and cannot currently be cached
|
||||
self.augmentations: List[dict] = kwargs.get('augmentations', None)
|
||||
|
||||
has_augmentations = self.augmentations is not None and len(self.augmentations) > 0
|
||||
|
||||
if (len(self.augments) > 0 or has_augmentations) and (self.cache_latents or self.cache_latents_to_disk):
|
||||
print(f"WARNING: Augments are not supported with caching latents. Setting cache_latents to False")
|
||||
self.cache_latents = False
|
||||
self.cache_latents_to_disk = False
|
||||
|
||||
@@ -16,7 +16,7 @@ import albumentations as A
|
||||
|
||||
from toolkit.buckets import get_bucket_for_image_size, BucketResolution
|
||||
from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -111,21 +111,7 @@ class ImageDataset(Dataset, CaptionMixin):
|
||||
return img
|
||||
|
||||
|
||||
class Augments:
    """Config wrapper for a single albumentations augmentation.

    Kwargs:
        method: name of the albumentations transform class (e.g. 'GaussianBlur').
        params: kwargs forwarded to that transform. String values of the form
            'cv2.NAME' are resolved to the matching cv2 constant.
    """

    def __init__(self, **kwargs):
        self.method_name = kwargs.get('method', None)
        self.params = kwargs.get('params', {})

        # convert kwargs enums for cv2
        for key, value in self.params.items():
            if isinstance(value, str):
                # split the string
                split_string = value.split('.')
                if len(split_string) == 2 and split_string[0] == 'cv2':
                    enum_name = split_string[1]
                    # bug fix: the original checked hasattr(cv2, name) but then
                    # fetched name.upper(), which could raise AttributeError for
                    # lowercase names. Resolve the name as given first, then try
                    # its uppercase form (cv2 constants are uppercase).
                    if hasattr(cv2, enum_name):
                        self.params[key] = getattr(cv2, enum_name)
                    elif hasattr(cv2, enum_name.upper()):
                        self.params[key] = getattr(cv2, enum_name.upper())
                    else:
                        raise ValueError(f"invalid cv2 enum: {split_string[1]}")
|
||||
|
||||
|
||||
class AugmentedImageDataset(ImageDataset):
|
||||
|
||||
@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
|
||||
|
||||
from toolkit import image_utils
|
||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
@@ -28,6 +28,7 @@ class FileItemDTO(
|
||||
ImageProcessingDTOMixin,
|
||||
ControlFileItemDTOMixin,
|
||||
MaskFileItemDTOMixin,
|
||||
AugmentationFileItemDTOMixin,
|
||||
PoiFileItemDTOMixin,
|
||||
ArgBreakMixin,
|
||||
):
|
||||
@@ -80,6 +81,8 @@ class DataLoaderBatchDTO:
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
self.mask_tensor: Union[torch.Tensor, None] = None
|
||||
self.unaugmented_tensor: Union[torch.Tensor, None] = None
|
||||
self.sigmas: Union[torch.Tensor, None] = None # can be added elseware and passed along training code
|
||||
if not is_latents_cached:
|
||||
# only return a tensor if latents are not cached
|
||||
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
|
||||
@@ -119,11 +122,26 @@ class DataLoaderBatchDTO:
|
||||
else:
|
||||
mask_tensors.append(x.mask_tensor)
|
||||
self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors])
|
||||
|
||||
# add unaugmented tensors for ones with augments
|
||||
if any([x.unaugmented_tensor is not None for x in self.file_items]):
|
||||
# find one to use as a base
|
||||
base_unaugmented_tensor = None
|
||||
for x in self.file_items:
|
||||
if x.unaugmented_tensor is not None:
|
||||
base_unaugmented_tensor = x.unaugmented_tensor
|
||||
break
|
||||
unaugmented_tensor = []
|
||||
for x in self.file_items:
|
||||
if x.unaugmented_tensor is None:
|
||||
unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor))
|
||||
else:
|
||||
unaugmented_tensor.append(x.unaugmented_tensor)
|
||||
self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
|
||||
def get_is_reg_list(self):
|
||||
return [x.is_reg for x in self.file_items]
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import random
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, List, Dict, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
@@ -19,6 +20,7 @@ from toolkit.prompt_utils import inject_trigger_into_prompt
|
||||
from torchvision import transforms
|
||||
from PIL import Image, ImageFilter
|
||||
from PIL.ImageOps import exif_transpose
|
||||
import albumentations as A
|
||||
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
@@ -26,8 +28,25 @@ if TYPE_CHECKING:
|
||||
from toolkit.data_loader import AiToolkitDataset
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO
|
||||
|
||||
|
||||
# def get_associated_caption_from_img_path(img_path):
|
||||
|
||||
class Augments:
    """Config wrapper for a single albumentations augmentation.

    Kwargs:
        method: name of the albumentations transform class (e.g. 'GaussianBlur').
        params: kwargs forwarded to that transform. String values of the form
            'cv2.NAME' are resolved to the matching cv2 constant.
    """

    def __init__(self, **kwargs):
        self.method_name = kwargs.get('method', None)
        self.params = kwargs.get('params', {})

        # convert kwargs enums for cv2
        for key, value in self.params.items():
            if isinstance(value, str):
                # split the string
                split_string = value.split('.')
                if len(split_string) == 2 and split_string[0] == 'cv2':
                    enum_name = split_string[1]
                    # bug fix: the original checked hasattr(cv2, name) but then
                    # fetched name.upper(), which could raise AttributeError for
                    # lowercase names. Resolve the name as given first, then try
                    # its uppercase form (cv2 constants are uppercase).
                    if hasattr(cv2, enum_name):
                        self.params[key] = getattr(cv2, enum_name)
                    elif hasattr(cv2, enum_name.upper()):
                        self.params[key] = getattr(cv2, enum_name.upper())
                    else:
                        raise ValueError(f"invalid cv2 enum: {split_string[1]}")
|
||||
|
||||
|
||||
transforms_dict = {
|
||||
'ColorJitter': transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03),
|
||||
@@ -195,7 +214,6 @@ class BucketsMixin:
|
||||
print(f'{len(self.buckets)} buckets made')
|
||||
|
||||
|
||||
|
||||
class CaptionProcessingDTOMixin:
|
||||
def __init__(self: 'FileItemDTO', *args, **kwargs):
|
||||
if hasattr(super(), '__init__'):
|
||||
@@ -392,7 +410,10 @@ class ImageProcessingDTOMixin:
|
||||
if augment in transforms_dict:
|
||||
img = transforms_dict[augment](img)
|
||||
|
||||
if transform:
|
||||
if self.has_augmentations:
|
||||
# augmentations handles transforms
|
||||
img = self.augment_image(img, transform=transform)
|
||||
elif transform:
|
||||
img = transform(img)
|
||||
|
||||
self.tensor = img
|
||||
@@ -468,6 +489,54 @@ class ControlFileItemDTOMixin:
|
||||
self.control_tensor = None
|
||||
|
||||
|
||||
class AugmentationFileItemDTOMixin:
    """FileItemDTO mixin that applies albumentations-based augmentations from
    the dataset config, keeping the un-augmented tensor on the item so losses
    can be computed against the original image."""

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        # cooperative multiple inheritance: defer to the next mixin in the MRO
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.has_augmentations = False
        # original (pre-augmentation) image tensor; set by augment_image
        self.unaugmented_tensor: Union[torch.Tensor, None] = None
        # self.augmentations: Union[None, List[Augments]] = None
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        if dataset_config.augmentations is not None and len(dataset_config.augmentations) > 0:
            self.has_augmentations = True
            # build one albumentations transform per configured augmentation
            augmentations = [Augments(**aug) for aug in dataset_config.augmentations]
            augmentation_list = []
            for aug in augmentations:
                # make sure method name is valid
                assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
                # get the method
                method = getattr(A, aug.method_name)
                # add the method to the list
                augmentation_list.append(method(**aug.params))

            self.aug_transform = A.Compose(augmentation_list)

    def augment_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ):
        """Apply the configured augmentation pipeline to `img` and return the
        augmented tensor; the un-augmented tensor (same `transform` applied) is
        stored on self.unaugmented_tensor as a side effect."""
        # save the original tensor
        self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)

        open_cv_image = np.array(img)
        # Convert RGB to BGR
        open_cv_image = open_cv_image[:, :, ::-1].copy()

        # apply augmentations
        augmented = self.aug_transform(image=open_cv_image)["image"]

        # convert back to RGB tensor
        augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)

        # convert to PIL image
        augmented = Image.fromarray(augmented)

        augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented)

        return augmented_tensor

    # NOTE(review): this clears augmentation state but is named cleanup_control,
    # matching ControlFileItemDTOMixin.cleanup_control — confirm whether this
    # name collision in the FileItemDTO MRO is intentional.
    def cleanup_control(self: 'FileItemDTO'):
        self.unaugmented_tensor = None
|
||||
|
||||
|
||||
class MaskFileItemDTOMixin:
|
||||
def __init__(self: 'FileItemDTO', *args, **kwargs):
|
||||
if hasattr(super(), '__init__'):
|
||||
@@ -558,6 +627,7 @@ class MaskFileItemDTOMixin:
|
||||
def cleanup_mask(self: 'FileItemDTO'):
|
||||
self.mask_tensor = None
|
||||
|
||||
|
||||
class PoiFileItemDTOMixin:
|
||||
# Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
|
||||
# items in the poi will always be inside the image when random cropping
|
||||
@@ -799,7 +869,6 @@ class LatentCachingMixin:
|
||||
del latent
|
||||
del file_item.tensor
|
||||
|
||||
|
||||
flush(garbage_collect=False)
|
||||
file_item.is_latent_cached = True
|
||||
# flush every 100
|
||||
|
||||
Reference in New Issue
Block a user