diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index ca14ee89..56d0bae9 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1,15 +1,9 @@ -import os.path from collections import OrderedDict from typing import Union - -from PIL import Image from diffusers import T2IAdapter -from torch.utils.data import DataLoader - from toolkit.basic import value_map from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO from toolkit.ip_adapter import IPAdapter -from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork from toolkit.train_tools import get_torch_dtype, apply_snr_weight import gc @@ -24,7 +18,6 @@ def flush(): adapter_transforms = transforms.Compose([ - # transforms.PILToTensor(), transforms.ToTensor(), ]) @@ -62,6 +55,77 @@ class SDTrainer(BaseSDTrainProcess): self.sd.vae.to('cpu') flush() + # you can expand these in a child class to make customization easier + def calculate_loss( + self, + noise_pred: torch.Tensor, + noise: torch.Tensor, + noisy_latents: torch.Tensor, + timesteps: torch.Tensor, + batch: 'DataLoaderBatchDTO', + mask_multiplier: Union[torch.Tensor, float] = 1.0, + control_pred: Union[torch.Tensor, None] = None, + **kwargs + ): + loss_target = self.train_config.loss_target + # add latents and unaug latents + if control_pred is not None: + # matching adapter prediction + target = control_pred + elif self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) + else: + target = noise + + pred = noise_pred + + ignore_snr = False + + if loss_target == 'source' or loss_target == 'unaugmented': + # ignore_snr = True + if batch.sigmas is None: + raise ValueError("Batch sigmas is None. 
This should not happen") + + # src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190 + denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents + weighing = batch.sigmas ** -2.0 + if loss_target == 'source': + # denoise the latent and compare to the latent in the batch + target = batch.latents + elif loss_target == 'unaugmented': + # we have to encode images into latents for now + # we also denoise as the unaugmented tensor is not a noisy differential + with torch.no_grad(): + unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor) + target = unaugmented_latents.detach() + + # Get the target for loss depending on the prediction type + if self.sd.noise_scheduler.config.prediction_type == "epsilon": + target = target # we are computing loss against denoised latents + elif self.sd.noise_scheduler.config.prediction_type == "v_prediction": + target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}") + + # mse loss without reduction + loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2) + loss = loss_per_element + else: + loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") + + # multiply by our mask + loss = loss * mask_multiplier + + loss = loss.mean([1, 2, 3]) + + if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + return loss + def hook_train_loop(self, batch): self.timer.start('preprocess_batch') @@ -251,26 +315,15 @@ class SDTrainer(BaseSDTrainProcess): with self.timer('calculate_loss'): noise = noise.to(self.device_torch, dtype=dtype).detach() - - if control_pred is not
None: - # matching adapter prediction - target = control_pred - elif self.sd.prediction_type == 'v_prediction': - # v-parameterization training - target = self.sd.noise_scheduler.get_velocity(noisy_latents, noise, timesteps) - else: - target = noise - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - # multiply by our mask - loss = loss * mask_multiplier - - loss = loss.mean([1, 2, 3]) - - if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: - # add min_snr_gamma - loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) - - loss = loss.mean() + loss = self.calculate_loss( + noise_pred=noise_pred, + noise=noise, + noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + mask_multiplier=mask_multiplier, + control_pred=control_pred, + ) # check if nan if torch.isnan(loss): raise ValueError("loss is nan") diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index bbb9bb05..1476aaec 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -442,6 +442,18 @@ class BaseSDTrainProcess(BaseTrainProcess): # override in subclass return params + def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32): + sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device, dtype=dtype) + schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device) + timesteps = timesteps.to(self.device) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): with torch.no_grad(): with self.timer('prepare_prompt'): @@ -483,6 +495,11 @@ class BaseSDTrainProcess(BaseTrainProcess): batch.latents = latents # flush() # todo check performance removing this + 
unaugmented_latents = None + if self.train_config.loss_target == 'differential_noise': + # we determine noise from the differential of the latents + unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor) + batch_size = latents.shape[0] with self.timer('prepare_noise'): @@ -516,7 +533,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.train_config.max_denoising_steps ) timesteps = timesteps.long().clamp( - self.train_config.min_denoising_steps, + self.train_config.min_denoising_steps + 1, self.train_config.max_denoising_steps - 1 ) @@ -539,6 +556,14 @@ class BaseSDTrainProcess(BaseTrainProcess): noise_offset=self.train_config.noise_offset ).to(self.device_torch, dtype=dtype) + if self.train_config.loss_target == 'differential_noise': + differential = latents - unaugmented_latents + # add noise to differential + # noise = noise + differential + noise = noise + (differential * 0.5) + # noise = value_map(differential, 0, torch.abs(differential).max(), 0, torch.abs(noise).max()) + latents = unaugmented_latents + noise_multiplier = self.train_config.noise_multiplier noise = noise * noise_multiplier @@ -549,6 +574,14 @@ class BaseSDTrainProcess(BaseTrainProcess): noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, timesteps) + # https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170C17-L1171C77 + if self.train_config.loss_target == 'source' or self.train_config.loss_target == 'unaugmented': + sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype) + # add it to the batch + batch.sigmas = sigmas + # todo is this for sdxl? 
find out where this came from originally + # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5) + # remove grads for these noisy_latents.requires_grad = False noisy_latents = noisy_latents.detach() diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 60bfe601..f79007aa 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -88,7 +88,7 @@ class EmbeddingConfig: ContentOrStyleType = Literal['balanced', 'style', 'content'] - +LossTarget = Literal['noise', 'source', 'unaugmented', 'differential_noise'] class TrainConfig: def __init__(self, **kwargs): @@ -127,6 +127,7 @@ class TrainConfig: match_adapter_assist = kwargs.get('match_adapter_assist', False) self.match_adapter_chance = kwargs.get('match_adapter_chance', 0.0) + self.loss_target: LossTarget = kwargs.get('loss_target', 'noise') # noise, source, unaugmented, # legacy if match_adapter_assist and self.match_adapter_chance == 0.0: @@ -258,7 +259,13 @@ class DatasetConfig: # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False) - if len(self.augments) > 0 and (self.cache_latents or self.cache_latents_to_disk): + # https://albumentations.ai/docs/api_reference/augmentations/transforms + # augmentations are returned as a separate image and cannot currently be cached + self.augmentations: List[dict] = kwargs.get('augmentations', None) + + has_augmentations = self.augmentations is not None and len(self.augmentations) > 0 + + if (len(self.augments) > 0 or has_augmentations) and (self.cache_latents or self.cache_latents_to_disk): print(f"WARNING: Augments are not supported with caching latents. 
Setting cache_latents to False") self.cache_latents = False self.cache_latents_to_disk = False diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 7f5cfb62..504094d8 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -16,7 +16,7 @@ import albumentations as A from toolkit.buckets import get_bucket_for_image_size, BucketResolution from toolkit.config_modules import DatasetConfig, preprocess_dataset_raw_config -from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin +from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin, LatentCachingMixin, Augments from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO if TYPE_CHECKING: @@ -111,21 +111,7 @@ class ImageDataset(Dataset, CaptionMixin): return img -class Augments: - def __init__(self, **kwargs): - self.method_name = kwargs.get('method', None) - self.params = kwargs.get('params', {}) - # convert kwargs enums for cv2 - for key, value in self.params.items(): - if isinstance(value, str): - # split the string - split_string = value.split('.') - if len(split_string) == 2 and split_string[0] == 'cv2': - if hasattr(cv2, split_string[1]): - self.params[key] = getattr(cv2, split_string[1].upper()) - else: - raise ValueError(f"invalid cv2 enum: {split_string[1]}") class AugmentedImageDataset(ImageDataset): diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index a36df5d8..32012dc0 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose from toolkit import image_utils from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \ - ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin + ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, 
AugmentationFileItemDTOMixin if TYPE_CHECKING: from toolkit.config_modules import DatasetConfig @@ -28,6 +28,7 @@ class FileItemDTO( ImageProcessingDTOMixin, ControlFileItemDTOMixin, MaskFileItemDTOMixin, + AugmentationFileItemDTOMixin, PoiFileItemDTOMixin, ArgBreakMixin, ): @@ -80,6 +81,8 @@ class DataLoaderBatchDTO: self.latents: Union[torch.Tensor, None] = None self.control_tensor: Union[torch.Tensor, None] = None self.mask_tensor: Union[torch.Tensor, None] = None + self.unaugmented_tensor: Union[torch.Tensor, None] = None + self.sigmas: Union[torch.Tensor, None] = None # can be added elsewhere and passed along training code if not is_latents_cached: # only return a tensor if latents are not cached self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items]) @@ -119,11 +122,26 @@ class DataLoaderBatchDTO: else: mask_tensors.append(x.mask_tensor) self.mask_tensor = torch.cat([x.unsqueeze(0) for x in mask_tensors]) + + # add unaugmented tensors for ones with augments + if any([x.unaugmented_tensor is not None for x in self.file_items]): + # find one to use as a base + base_unaugmented_tensor = None + for x in self.file_items: + if x.unaugmented_tensor is not None: + base_unaugmented_tensor = x.unaugmented_tensor + break + unaugmented_tensor = [] + for x in self.file_items: + if x.unaugmented_tensor is None: + unaugmented_tensor.append(torch.zeros_like(base_unaugmented_tensor)) + else: + unaugmented_tensor.append(x.unaugmented_tensor) + self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor]) except Exception as e: print(e) raise e - def get_is_reg_list(self): return [x.is_reg for x in self.file_items] diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 444a3c32..3acc5c15 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -7,6 +7,7 @@ import random from collections import OrderedDict from typing import TYPE_CHECKING, List, Dict, Union +import cv2 import numpy
as np import torch from safetensors.torch import load_file, save_file @@ -19,6 +20,7 @@ from toolkit.prompt_utils import inject_trigger_into_prompt from torchvision import transforms from PIL import Image, ImageFilter from PIL.ImageOps import exif_transpose +import albumentations as A from toolkit.train_tools import get_torch_dtype @@ -26,8 +28,25 @@ if TYPE_CHECKING: from toolkit.data_loader import AiToolkitDataset from toolkit.data_transfer_object.data_loader import FileItemDTO + # def get_associated_caption_from_img_path(img_path): +class Augments: + def __init__(self, **kwargs): + self.method_name = kwargs.get('method', None) + self.params = kwargs.get('params', {}) + + # convert kwargs enums for cv2 + for key, value in self.params.items(): + if isinstance(value, str): + # split the string + split_string = value.split('.') + if len(split_string) == 2 and split_string[0] == 'cv2': + if hasattr(cv2, split_string[1]): + self.params[key] = getattr(cv2, split_string[1].upper()) + else: + raise ValueError(f"invalid cv2 enum: {split_string[1]}") + transforms_dict = { 'ColorJitter': transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03), @@ -195,7 +214,6 @@ class BucketsMixin: print(f'{len(self.buckets)} buckets made') - class CaptionProcessingDTOMixin: def __init__(self: 'FileItemDTO', *args, **kwargs): if hasattr(super(), '__init__'): @@ -392,7 +410,10 @@ class ImageProcessingDTOMixin: if augment in transforms_dict: img = transforms_dict[augment](img) - if transform: + if self.has_augmentations: + # augmentations handles transforms + img = self.augment_image(img, transform=transform) + elif transform: img = transform(img) self.tensor = img @@ -468,6 +489,54 @@ class ControlFileItemDTOMixin: self.control_tensor = None +class AugmentationFileItemDTOMixin: + def __init__(self: 'FileItemDTO', *args, **kwargs): + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + self.has_augmentations = False + self.unaugmented_tensor: 
Union[torch.Tensor, None] = None + # self.augmentations: Union[None, List[Augments]] = None + dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + if dataset_config.augmentations is not None and len(dataset_config.augmentations) > 0: + self.has_augmentations = True + augmentations = [Augments(**aug) for aug in dataset_config.augmentations] + augmentation_list = [] + for aug in augmentations: + # make sure method name is valid + assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}" + # get the method + method = getattr(A, aug.method_name) + # add the method to the list + augmentation_list.append(method(**aug.params)) + + self.aug_transform = A.Compose(augmentation_list) + + def augment_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ): + + # save the original tensor + self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img) + + open_cv_image = np.array(img) + # Convert RGB to BGR + open_cv_image = open_cv_image[:, :, ::-1].copy() + + # apply augmentations + augmented = self.aug_transform(image=open_cv_image)["image"] + + # convert back to RGB tensor + augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB) + + # convert to PIL image + augmented = Image.fromarray(augmented) + + augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented) + + return augmented_tensor + + def cleanup_control(self: 'FileItemDTO'): + self.unaugmented_tensor = None + + class MaskFileItemDTOMixin: def __init__(self: 'FileItemDTO', *args, **kwargs): if hasattr(super(), '__init__'): @@ -558,6 +627,7 @@ class MaskFileItemDTOMixin: def cleanup_mask(self: 'FileItemDTO'): self.mask_tensor = None + class PoiFileItemDTOMixin: # Point of interest bounding box. 
Allows for dynamic cropping without cropping out the main subject # items in the poi will always be inside the image when random cropping @@ -799,7 +869,6 @@ class LatentCachingMixin: del latent del file_item.tensor - flush(garbage_collect=False) file_item.is_latent_cached = True # flush every 100