diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index f17747a2..d7289d84 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -1,6 +1,8 @@
 from collections import OrderedDict
 from typing import Union
 from diffusers import T2IAdapter
+
+from toolkit import train_tools
 from toolkit.basic import value_map
 from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 from toolkit.ip_adapter import IPAdapter
@@ -30,6 +32,7 @@ class SDTrainer(BaseSDTrainProcess):
         super().__init__(process_id, job, config, **kwargs)
         self.assistant_adapter: Union['T2IAdapter', None]
         self.do_prior_prediction = False
+        self.target_class = self.get_conf('target_class', '')
 
         if self.train_config.inverted_mask_prior:
             self.do_prior_prediction = True
@@ -171,6 +174,99 @@ class SDTrainer(BaseSDTrainProcess):
     def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
         return batch
 
+    def get_guided_loss(
+            self,
+            noisy_latents: torch.Tensor,
+            conditional_embeds: PromptEmbeds,
+            match_adapter_assist: bool,
+            network_weight_list: list,
+            timesteps: torch.Tensor,
+            pred_kwargs: dict,
+            batch: 'DataLoaderBatchDTO',
+            noise: torch.Tensor,
+            **kwargs
+    ):
+        with torch.no_grad():
+            dtype = get_torch_dtype(self.train_config.dtype)
+            # target class is unconditional
+            target_class_embeds = self.sd.encode_prompt(self.target_class).detach()
+
+            if batch.unconditional_latents is not None:
+                # do the unconditional prediction here instead of a prior prediction
+                unconditional_noisy_latents = self.sd.noise_scheduler.add_noise(batch.unconditional_latents, noise,
+                                                                                timesteps)
+
+            was_network_active = self.network.is_active
+            self.network.is_active = False
+            self.sd.unet.eval()
+
+            guidance_scale = 1.0
+
+            def cfg(uncon, con):
+                return uncon + guidance_scale * (
+                        con - uncon
+                )
+
+            target_conditional = self.sd.predict_noise(
+                latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
+                conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
+                timestep=timesteps,
+                guidance_scale=1.0,
+                **pred_kwargs  # adapter residuals in here
+            ).detach()
+
+            target_unconditional = self.sd.predict_noise(
+                latents=unconditional_noisy_latents.to(self.device_torch, dtype=dtype).detach(),
+                conditional_embeddings=target_class_embeds.to(self.device_torch, dtype=dtype).detach(),
+                timestep=timesteps,
+                guidance_scale=1.0,
+                **pred_kwargs  # adapter residuals in here
+            ).detach()
+
+            neutral_latents = (noisy_latents + unconditional_noisy_latents) / 2.0
+
+            target_noise = cfg(target_unconditional, target_conditional)
+            # latents = self.noise_scheduler.step(target_noise, timesteps, noisy_latents, return_dict=False)[0]
+
+            # target_pred = target_pred - noisy_latents + (unconditional_noisy_latents - noise)
+
+            # target_noise_res = noisy_latents - unconditional_noisy_latents
+
+            # target_pred = cfg(unconditional_noisy_latents, target_pred)
+            # target_pred = target_pred + target_noise_res
+
+        self.network.is_active = True
+        self.sd.unet.train()
+
+        prediction = self.sd.predict_noise(
+            latents=neutral_latents.to(self.device_torch, dtype=dtype).detach(),
+            conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
+            timestep=timesteps,
+            guidance_scale=1.0,
+            **pred_kwargs  # adapter residuals in here
+        )
+
+        # prediction_res = target_pred - prediction
+
+
+        # prediction = cfg(prediction, target_pred)
+
+        loss = torch.nn.functional.mse_loss(prediction.float(), target_noise.float(), reduction="none")
+        loss = loss.mean([1, 2, 3])
+
+        if self.train_config.learnable_snr_gos:
+            # add snr_gamma
+            loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
+        elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001:
+            # add snr_gamma
+            loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
+        elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
+            # add min_snr_gamma
+            loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
+
+        loss = loss.mean()
+        return loss
+
     def get_prior_prediction(
             self,
             noisy_latents: torch.Tensor,
@@ -369,8 +465,6 @@ class SDTrainer(BaseSDTrainProcess):
         else:
             prompt_2_list = [prompts_2]
 
-
-
         for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip(
                 noisy_latents_list,
                 noise_list,
@@ -386,8 +480,9 @@ class SDTrainer(BaseSDTrainProcess):
             with self.timer('encode_prompt'):
                 if grad_on_text_encoder:
                     with torch.set_grad_enabled(True):
-                        conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
-                            # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
+                        conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
+                                                                   long_prompts=True).to(
+                            # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
                             self.device_torch,
                             dtype=dtype)
                 else:
@@ -398,8 +493,9 @@ class SDTrainer(BaseSDTrainProcess):
                                 te.eval()
                         else:
                             self.sd.text_encoder.eval()
-                        conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to(
-                            # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
+                        conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2,
+                                                                   long_prompts=True).to(
+                            # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to(
                             self.device_torch,
                             dtype=dtype)
 
@@ -450,27 +546,42 @@ class SDTrainer(BaseSDTrainProcess):
                     conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
 
             self.before_unet_predict()
-            with self.timer('predict_unet'):
-                noise_pred = self.sd.predict_noise(
-                    latents=noisy_latents.to(self.device_torch, dtype=dtype),
-                    conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
-                    timestep=timesteps,
-                    guidance_scale=1.0,
-                    **pred_kwargs
-                )
-            self.after_unet_predict()
-
-            with self.timer('calculate_loss'):
-                noise = noise.to(self.device_torch, dtype=dtype).detach()
-                loss = self.calculate_loss(
-                    noise_pred=noise_pred,
-                    noise=noise,
+            # do a prior pred if we have an unconditional image, we will swap out the guidance later
+            if batch.unconditional_latents is not None:
+                # do guided loss
+                loss = self.get_guided_loss(
                     noisy_latents=noisy_latents,
+                    conditional_embeds=conditional_embeds,
+                    match_adapter_assist=match_adapter_assist,
+                    network_weight_list=network_weight_list,
                     timesteps=timesteps,
+                    pred_kwargs=pred_kwargs,
                     batch=batch,
-                    mask_multiplier=mask_multiplier,
-                    prior_pred=prior_pred,
+                    noise=noise,
                 )
+
+            else:
+                with self.timer('predict_unet'):
+                    noise_pred = self.sd.predict_noise(
+                        latents=noisy_latents.to(self.device_torch, dtype=dtype),
+                        conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
+                        timestep=timesteps,
+                        guidance_scale=1.0,
+                        **pred_kwargs
+                    )
+                self.after_unet_predict()
+
+                with self.timer('calculate_loss'):
+                    noise = noise.to(self.device_torch, dtype=dtype).detach()
+                    loss = self.calculate_loss(
+                        noise_pred=noise_pred,
+                        noise=noise,
+                        noisy_latents=noisy_latents,
+                        timesteps=timesteps,
+                        batch=batch,
+                        mask_multiplier=mask_multiplier,
+                        prior_pred=prior_pred,
+                    )
             # check if nan
             if torch.isnan(loss):
                 raise ValueError("loss is nan")
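
For reference, `get_guided_loss` above builds its training target with a classifier-free-guidance style combination, `cfg(uncon, con) = uncon + guidance_scale * (con - uncon)`, which collapses to the conditional prediction at `guidance_scale = 1.0`, and then compares the network's prediction on the averaged "neutral" latents against that target. A minimal standalone sketch of the target construction, using plain tensors instead of the trainer's `StableDiffusion` wrapper (shapes and variable names here are illustrative only, and the SNR weighting is omitted):

```python
import torch

def cfg_combine(uncond: torch.Tensor, cond: torch.Tensor, guidance_scale: float = 1.0) -> torch.Tensor:
    # classifier-free-guidance style blend; at guidance_scale == 1.0 this is just `cond`
    return uncond + guidance_scale * (cond - uncond)

# illustrative stand-ins for the two detached noise predictions made with the network disabled
target_conditional = torch.randn(2, 4, 64, 64)    # prompt + original noisy latents
target_unconditional = torch.randn(2, 4, 64, 64)  # target class prompt + unconditional noisy latents
prediction = torch.randn(2, 4, 64, 64)            # network prediction on the averaged "neutral" latents

target_noise = cfg_combine(target_unconditional, target_conditional)
loss = torch.nn.functional.mse_loss(prediction, target_noise, reduction="none").mean([1, 2, 3]).mean()
```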
reduction="none") + loss = loss.mean([1, 2, 3]) + + if self.train_config.learnable_snr_gos: + # add snr_gamma + loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos) + elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001: + # add snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True) + elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + return loss + def get_prior_prediction( self, noisy_latents: torch.Tensor, @@ -369,8 +465,6 @@ class SDTrainer(BaseSDTrainProcess): else: prompt_2_list = [prompts_2] - - for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip( noisy_latents_list, noise_list, @@ -386,8 +480,9 @@ class SDTrainer(BaseSDTrainProcess): with self.timer('encode_prompt'): if grad_on_text_encoder: with torch.set_grad_enabled(True): - conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to( - # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to( + conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, + long_prompts=True).to( + # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to( self.device_torch, dtype=dtype) else: @@ -398,8 +493,9 @@ class SDTrainer(BaseSDTrainProcess): te.eval() else: self.sd.text_encoder.eval() - conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=True).to( - # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to( + conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, + long_prompts=True).to( + # conditional_embeds = self.sd.encode_prompt(conditioned_prompts, prompt_2, long_prompts=False).to( self.device_torch, dtype=dtype) @@ -450,27 +546,42 @@ class SDTrainer(BaseSDTrainProcess): conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds) self.before_unet_predict() - with self.timer('predict_unet'): - noise_pred = self.sd.predict_noise( - latents=noisy_latents.to(self.device_torch, dtype=dtype), - conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), - timestep=timesteps, - guidance_scale=1.0, - **pred_kwargs - ) - self.after_unet_predict() - - with self.timer('calculate_loss'): - noise = noise.to(self.device_torch, dtype=dtype).detach() - loss = self.calculate_loss( - noise_pred=noise_pred, - noise=noise, + # do a prior pred if we have an unconditional image, we will swap out the giadance later + if batch.unconditional_latents is not None: + # do guided loss + loss = self.get_guided_loss( noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, timesteps=timesteps, + pred_kwargs=pred_kwargs, batch=batch, - mask_multiplier=mask_multiplier, - prior_pred=prior_pred, + noise=noise, ) + + else: + with self.timer('predict_unet'): + noise_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + timestep=timesteps, + guidance_scale=1.0, + **pred_kwargs + ) + self.after_unet_predict() + + with 
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 51b02065..d03eb625 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -351,6 +351,7 @@ class DatasetConfig:
 
         self.alpha_mask: bool = kwargs.get('alpha_mask', False)  # if true, will use alpha channel as mask
         self.mask_path: str = kwargs.get('mask_path', None)  # focus mask (black and white. White has higher loss than black)
+        self.unconditional_path: str = kwargs.get('unconditional_path', None)  # path where matching unconditional images are located
         self.invert_mask: bool = kwargs.get('invert_mask', False)  # invert mask
         self.mask_min_value: float = kwargs.get('mask_min_value', 0.01)  # min value for . 0 - 1
         self.poi: Union[str, None] = kwargs.get('poi',
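
The new `unconditional_path` option points at a sibling folder that holds a matching unconditional image for each training image, looked up by base filename (see the dataloader mixin later in this diff). A hedged usage sketch; only `unconditional_path` and `buckets` come from this diff, while `folder_path` and the folder layout are assumptions for illustration:

```python
from toolkit.config_modules import DatasetConfig

# hypothetical folder layout:
#   /data/person/img_001.jpg                 <- conditional (training) image
#   /data/person_unconditional/img_001.jpg   <- matching unconditional image (same base filename)
dataset = DatasetConfig(
    folder_path='/data/person',                       # assumed existing option for the training images
    unconditional_path='/data/person_unconditional',  # new option added in this diff
    buckets=True,                                     # unconditional images require bucketed datasets
)
```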
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index 4dea2439..c6081467 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -7,7 +7,8 @@ from PIL.ImageOps import exif_transpose
 
 from toolkit import image_utils
 from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
-    ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin
+    ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
+    UnconditionalFileItemDTOMixin
 
 if TYPE_CHECKING:
     from toolkit.config_modules import DatasetConfig
@@ -29,6 +30,7 @@ class FileItemDTO(
     ControlFileItemDTOMixin,
     MaskFileItemDTOMixin,
     AugmentationFileItemDTOMixin,
+    UnconditionalFileItemDTOMixin,
     PoiFileItemDTOMixin,
     ArgBreakMixin,
 ):
@@ -70,6 +72,7 @@ class FileItemDTO(
         self.cleanup_latent()
         self.cleanup_control()
         self.cleanup_mask()
+        self.cleanup_unconditional()
 
 
 class DataLoaderBatchDTO:
@@ -82,6 +85,8 @@ class DataLoaderBatchDTO:
         self.control_tensor: Union[torch.Tensor, None] = None
         self.mask_tensor: Union[torch.Tensor, None] = None
         self.unaugmented_tensor: Union[torch.Tensor, None] = None
+        self.unconditional_tensor: Union[torch.Tensor, None] = None
+        self.unconditional_latents: Union[torch.Tensor, None] = None
         self.sigmas: Union[torch.Tensor, None] = None  # can be added elseware and passed along training code
         if not is_latents_cached:
             # only return a tensor if latents are not cached
@@ -138,6 +143,22 @@ class DataLoaderBatchDTO:
                     else:
                         unaugmented_tensor.append(x.unaugmented_tensor)
                 self.unaugmented_tensor = torch.cat([x.unsqueeze(0) for x in unaugmented_tensor])
+
+            # add unconditional tensors
+            if any([x.unconditional_tensor is not None for x in self.file_items]):
+                # find one to use as a base
+                base_unconditional_tensor = None
+                for x in self.file_items:
+                    if x.unconditional_tensor is not None:
+                        base_unconditional_tensor = x.unconditional_tensor
+                        break
+                unconditional_tensor = []
+                for x in self.file_items:
+                    if x.unconditional_tensor is None:
+                        unconditional_tensor.append(torch.zeros_like(base_unconditional_tensor))
+                    else:
+                        unconditional_tensor.append(x.unconditional_tensor)
+                self.unconditional_tensor = torch.cat([x.unsqueeze(0) for x in unconditional_tensor])
         except Exception as e:
             print(e)
             raise e
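
The collation above pads file items that have no unconditional image with a zero tensor shaped like the first available one, so the batch stays rectangular. The same pattern in isolation, with a plain list of optional tensors standing in for `file_items`:

```python
import torch

def collate_optional(tensors):
    # tensors: list of Tensor-or-None; at least one non-None entry is expected
    base = next(t for t in tensors if t is not None)
    filled = [t if t is not None else torch.zeros_like(base) for t in tensors]
    return torch.cat([t.unsqueeze(0) for t in filled])

batch = collate_optional([torch.ones(3, 64, 64), None, torch.ones(3, 64, 64) * 2])
print(batch.shape)  # torch.Size([3, 3, 64, 64]); the middle item is all zeros
```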
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 52fe3a64..394dac91 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -351,6 +351,8 @@ class ImageProcessingDTOMixin:
                 self.load_control_image()
             if self.has_mask_image:
                 self.load_mask_image()
+            if self.has_unconditional:
+                self.load_unconditional_image()
             return
         try:
             img = Image.open(self.path)
@@ -442,6 +444,8 @@ class ImageProcessingDTOMixin:
                 self.load_control_image()
             if self.has_mask_image:
                 self.load_mask_image()
+            if self.has_unconditional:
+                self.load_unconditional_image()
 
 
 class ControlFileItemDTOMixin:
@@ -661,6 +665,80 @@ class MaskFileItemDTOMixin:
         self.mask_tensor = None
 
+
+class UnconditionalFileItemDTOMixin:
+    def __init__(self: 'FileItemDTO', *args, **kwargs):
+        if hasattr(super(), '__init__'):
+            super().__init__(*args, **kwargs)
+        self.has_unconditional = False
+        self.unconditional_path: Union[str, None] = None
+        self.unconditional_tensor: Union[torch.Tensor, None] = None
+        self.unconditional_latent: Union[torch.Tensor, None] = None
+        self.unconditional_transforms = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
+
+        if dataset_config.unconditional_path is not None:
+            # we are using unconditional images
+            img_path = kwargs.get('path', None)
+            img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
+            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
+            for ext in img_ext_list:
+                if os.path.exists(os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext)):
+                    self.unconditional_path = os.path.join(dataset_config.unconditional_path, file_name_no_ext + ext)
+                    self.has_unconditional = True
+                    break
+
+    def load_unconditional_image(self: 'FileItemDTO'):
+        try:
+            img = Image.open(self.unconditional_path)
+            img = exif_transpose(img)
+        except Exception as e:
+            print(f"Error: {e}")
+            print(f"Error loading image: {self.unconditional_path}")
+
+        img = img.convert('RGB')
+        w, h = img.size
+        if w > h and self.scale_to_width < self.scale_to_height:
+            # throw error, they should match
+            raise ValueError(
+                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+        elif h > w and self.scale_to_height < self.scale_to_width:
+            # throw error, they should match
+            raise ValueError(
+                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+
+        if self.flip_x:
+            # do a flip
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+        if self.flip_y:
+            # do a flip
+            img = img.transpose(Image.FLIP_TOP_BOTTOM)
+
+        if self.dataset_config.buckets:
+            # scale and crop based on file item
+            img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
+            # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
+            # crop
+            img = img.crop((
+                self.crop_x,
+                self.crop_y,
+                self.crop_x + self.crop_width,
+                self.crop_y + self.crop_height
+            ))
+        else:
+            raise Exception("Unconditional images are not supported for non-bucket datasets")
+
+        self.unconditional_tensor = self.unconditional_transforms(img)
+
+    def cleanup_unconditional(self: 'FileItemDTO'):
+        self.unconditional_tensor = None
+        self.unconditional_latent = None
+
+
 class PoiFileItemDTOMixin:
     # Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
     # items in the poi will always be inside the image when random cropping
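
`UnconditionalFileItemDTOMixin` locates the unconditional counterpart purely by base filename, probing each allowed extension in order. The same lookup in isolation (the paths are hypothetical):

```python
import os

def find_unconditional(img_path: str, unconditional_dir: str):
    # mirror of the extension probe in __init__ above
    img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
    file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
    for ext in img_ext_list:
        candidate = os.path.join(unconditional_dir, file_name_no_ext + ext)
        if os.path.exists(candidate):
            return candidate
    return None

print(find_unconditional('/data/person/img_001.jpg', '/data/person_unconditional'))
```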