From a48c9aba8d277d13dad0c0a4d69f90406065df4d Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 29 Aug 2024 12:34:18 -0600 Subject: [PATCH] Created a v2 trainer and moved all the training logic to a single torch model so it can be run in parallel --- extensions_built_in/sd_trainer/TrainerV2.py | 238 ++++ extensions_built_in/sd_trainer/__init__.py | 19 +- toolkit/models/unified_training_model.py | 1418 +++++++++++++++++++ 3 files changed, 1674 insertions(+), 1 deletion(-) create mode 100644 extensions_built_in/sd_trainer/TrainerV2.py create mode 100644 toolkit/models/unified_training_model.py diff --git a/extensions_built_in/sd_trainer/TrainerV2.py b/extensions_built_in/sd_trainer/TrainerV2.py new file mode 100644 index 00000000..8b4a1ebe --- /dev/null +++ b/extensions_built_in/sd_trainer/TrainerV2.py @@ -0,0 +1,238 @@ +import os +import random +from collections import OrderedDict +from typing import Union, List + +import numpy as np +from diffusers import T2IAdapter, ControlNetModel + +from toolkit.clip_vision_adapter import ClipVisionAdapter +from toolkit.data_loader import get_dataloader_datasets +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.stable_diffusion_model import BlankNetwork +from toolkit.train_tools import get_torch_dtype, add_all_snr_to_noise_scheduler +import gc +import torch +from jobs.process import BaseSDTrainProcess +from torchvision import transforms +from diffusers import EMAModel +import math +from toolkit.train_tools import precondition_model_outputs_flow_match +from toolkit.models.unified_training_model import UnifiedTrainingModel + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +adapter_transforms = transforms.Compose([ + transforms.ToTensor(), +]) + + +class TrainerV2(BaseSDTrainProcess): + + def __init__(self, process_id: int, job, config: OrderedDict, **kwargs): + super().__init__(process_id, job, config, **kwargs) + self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None] + self.do_prior_prediction = False + self.do_long_prompts = False + self.do_guided_loss = False + + self._clip_image_embeds_unconditional: Union[List[str], None] = None + self.negative_prompt_pool: Union[List[str], None] = None + self.batch_negative_prompt: Union[List[str], None] = None + + self.scaler = torch.cuda.amp.GradScaler() + + self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16" + + self.do_grad_scale = True + if self.is_fine_tuning: + self.do_grad_scale = False + if self.adapter_config is not None: + if self.adapter_config.train: + self.do_grad_scale = False + + if self.train_config.dtype in ["fp16", "float16"]: + # patch the scaler to allow fp16 training + org_unscale_grads = self.scaler._unscale_grads_ + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) + self.scaler._unscale_grads_ = _unscale_grads_replacer + + self.unified_training_model: UnifiedTrainingModel = None + + + def before_model_load(self): + pass + + def before_dataset_load(self): + self.assistant_adapter = None + # get adapter assistant if one is set + if self.train_config.adapter_assist_name_or_path is not None: + adapter_path = self.train_config.adapter_assist_name_or_path + + if self.train_config.adapter_assist_type == "t2i": + # don't name this adapter since we are not training it + self.assistant_adapter = T2IAdapter.from_pretrained( + adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype) + 
).to(self.device_torch) + elif self.train_config.adapter_assist_type == "control_net": + self.assistant_adapter = ControlNetModel.from_pretrained( + adapter_path, torch_dtype=get_torch_dtype(self.train_config.dtype) + ).to(self.device_torch, dtype=get_torch_dtype(self.train_config.dtype)) + else: + raise ValueError(f"Unknown adapter assist type {self.train_config.adapter_assist_type}") + + self.assistant_adapter.eval() + self.assistant_adapter.requires_grad_(False) + flush() + if self.train_config.train_turbo and self.train_config.show_turbo_outputs: + raise ValueError("Turbo outputs are not supported on MultiGPUSDTrainer") + + def hook_before_train_loop(self): + # if self.train_config.do_prior_divergence: + # self.do_prior_prediction = True + # move vae to device if we did not cache latents + if not self.is_latents_cached: + self.sd.vae.eval() + self.sd.vae.to(self.device_torch) + else: + # offload it. Already cached + self.sd.vae.to('cpu') + flush() + add_all_snr_to_noise_scheduler(self.sd.noise_scheduler, self.device_torch) + if self.adapter is not None: + self.adapter.to(self.device_torch) + + # check if we have regs and using adapter and caching clip embeddings + has_reg = self.datasets_reg is not None and len(self.datasets_reg) > 0 + is_caching_clip_embeddings = self.datasets is not None and any([self.datasets[i].cache_clip_vision_to_disk for i in range(len(self.datasets))]) + + if has_reg and is_caching_clip_embeddings: + # we need a list of unconditional clip image embeds from other datasets to handle regs + unconditional_clip_image_embeds = [] + datasets = get_dataloader_datasets(self.data_loader) + for i in range(len(datasets)): + unconditional_clip_image_embeds += datasets[i].clip_vision_unconditional_cache + + if len(unconditional_clip_image_embeds) == 0: + raise ValueError("No unconditional clip image embeds found. This should not happen") + + self._clip_image_embeds_unconditional = unconditional_clip_image_embeds + + if self.train_config.negative_prompt is not None: + raise ValueError("Negative prompt is not supported on MultiGPUSDTrainer") + + # setup the unified training model + self.unified_training_model = UnifiedTrainingModel( + sd=self.sd, + network=self.network, + adapter=self.adapter, + assistant_adapter=self.assistant_adapter, + train_config=self.train_config, + adapter_config=self.adapter_config, + embedding=self.embedding, + timer=self.timer, + trigger_word=self.trigger_word, + ) + + # call parent hook + super().hook_before_train_loop() + + # you can expand these in a child class to make customization easier + + def preprocess_batch(self, batch: 'DataLoaderBatchDTO'): + return self.unified_training_model.preprocess_batch(batch) + + + def before_unet_predict(self): + pass + + def after_unet_predict(self): + pass + + def end_of_training_loop(self): + pass + + def hook_train_loop(self, batch: 'DataLoaderBatchDTO'): + + self.optimizer.zero_grad(set_to_none=True) + + loss = self.unified_training_model(batch) + + if torch.isnan(loss): + print("loss is nan") + loss = torch.zeros_like(loss).requires_grad_(True) + + if self.network is not None: + network = self.network + else: + network = BlankNetwork() + + with (network): + with self.timer('backward'): + # todo we have multiplier seperated. works for now as res are not in same batch, but need to change + # IMPORTANT if gradient checkpointing do not leave with network when doing backward + # it will destroy the gradients. 
This is because the network is a context manager + # and will change the multipliers back to 0.0 when exiting. They will be + # 0.0 for the backward pass and the gradients will be 0.0 + # I spent weeks on fighting this. DON'T DO IT + # with fsdp_overlap_step_with_backward(): + # if self.is_bfloat: + # loss.backward() + # else: + if not self.do_grad_scale: + loss.backward() + else: + self.scaler.scale(loss).backward() + + if not self.is_grad_accumulation_step: + # fix this for multi params + if self.train_config.optimizer != 'adafactor': + if self.do_grad_scale: + self.scaler.unscale_(self.optimizer) + if isinstance(self.params[0], dict): + for i in range(len(self.params)): + torch.nn.utils.clip_grad_norm_(self.params[i]['params'], self.train_config.max_grad_norm) + else: + torch.nn.utils.clip_grad_norm_(self.params, self.train_config.max_grad_norm) + # only step if we are not accumulating + with self.timer('optimizer_step'): + # self.optimizer.step() + if not self.do_grad_scale: + self.optimizer.step() + else: + self.scaler.step(self.optimizer) + self.scaler.update() + + self.optimizer.zero_grad(set_to_none=True) + if self.ema is not None: + with self.timer('ema_update'): + self.ema.update() + else: + # gradient accumulation. Just a place for breakpoint + pass + + # TODO Should we only step scheduler on grad step? If so, need to recalculate last step + with self.timer('scheduler_step'): + self.lr_scheduler.step() + + if self.embedding is not None: + with self.timer('restore_embeddings'): + # Let's make sure we don't update any embedding weights besides the newly added token + self.embedding.restore_embeddings() + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter): + with self.timer('restore_adapter'): + # Let's make sure we don't update any embedding weights besides the newly added token + self.adapter.restore_embeddings() + + loss_dict = OrderedDict( + {'loss': loss.item()} + ) + + self.end_of_training_loop() + + return loss_dict diff --git a/extensions_built_in/sd_trainer/__init__.py b/extensions_built_in/sd_trainer/__init__.py index 45aa841e..fc505e3f 100644 --- a/extensions_built_in/sd_trainer/__init__.py +++ b/extensions_built_in/sd_trainer/__init__.py @@ -19,6 +19,23 @@ class SDTrainerExtension(Extension): return SDTrainer +# This is for generic training (LoRA, Dreambooth, FineTuning) +class MultiGPUSDTrainerExtension(Extension): + # uid must be unique, it is how the extension is identified + uid = "trainer_v2" + + # name is the name of the extension for printing + name = "Trainer V2" + + # This is where your process class is loaded + # keep your imports in here so they don't slow down the rest of the program + @classmethod + def get_process(cls): + # import your process class here so it is only loaded when needed and return it + from .TrainerV2 import TrainerV2 + return TrainerV2 + + # for backwards compatability class TextualInversionTrainer(SDTrainerExtension): uid = "textual_inversion_trainer" @@ -26,5 +43,5 @@ class TextualInversionTrainer(SDTrainerExtension): AI_TOOLKIT_EXTENSIONS = [ # you can put a list of extensions here - SDTrainerExtension, TextualInversionTrainer + SDTrainerExtension, TextualInversionTrainer, MultiGPUSDTrainerExtension ] diff --git a/toolkit/models/unified_training_model.py b/toolkit/models/unified_training_model.py new file mode 100644 index 00000000..aa6fdde1 --- /dev/null +++ b/toolkit/models/unified_training_model.py @@ -0,0 +1,1418 @@ +import random +from typing import Union, Optional + +import torch +from diffusers import 
T2IAdapter, ControlNetModel +from safetensors.torch import load_file +from torch import nn + +from toolkit.basic import value_map +from toolkit.clip_vision_adapter import ClipVisionAdapter +from toolkit.custom_adapter import CustomAdapter +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.embedding import Embedding +from toolkit.guidance import GuidanceType, get_guidance_loss +from toolkit.image_utils import reduce_contrast +from toolkit.ip_adapter import IPAdapter +from toolkit.network_mixins import Network +from toolkit.prompt_utils import PromptEmbeds +from toolkit.reference_adapter import ReferenceAdapter +from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork +from toolkit.timer import Timer +from toolkit.train_tools import get_torch_dtype, apply_learnable_snr_gos, apply_snr_weight + +from toolkit.config_modules import TrainConfig, AdapterConfig + +AdapterType = Union[ + T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, ControlNetModel +] + + +class UnifiedTrainingModel(nn.Module): + def __init__( + self, + sd: StableDiffusion, + network: Optional[Network] = None, + adapter: Optional[AdapterType] = None, + assistant_adapter: Optional[AdapterType] = None, + train_config: TrainConfig = None, + adapter_config: AdapterConfig = None, + embedding: Optional[Embedding] = None, + timer: Timer = None, + trigger_word: Optional[str] = None, + ): + super(UnifiedTrainingModel, self).__init__() + self.sd: StableDiffusion = sd + self.network: Optional[Network] = network + self.adapter: Optional[AdapterType] = adapter + self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None] = assistant_adapter + self.train_config: TrainConfig = train_config + self.adapter_config: AdapterConfig = adapter_config + self.embedding: Optional[Embedding] = embedding + self.timer: Timer = timer + self.trigger_word: Optional[str] = trigger_word + self.device_torch = torch.device("cuda") + + # misc config + self.do_long_prompts = False + self.do_prior_prediction = False + self.do_guided_loss = False + + if self.train_config.do_prior_divergence: + self.do_prior_prediction = True + + def before_unet_predict(self): + pass + + def after_unet_predict(self): + pass + + def get_prior_prediction( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + conditioned_prompts=None, + **kwargs + ): + # todo for embeddings, we need to run without trigger words + was_unet_training = self.sd.unet.training + was_network_active = False + if self.network is not None: + was_network_active = self.network.is_active + self.network.is_active = False + can_disable_adapter = False + was_adapter_active = False + if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or + isinstance(self.adapter, ReferenceAdapter) or + (isinstance(self.adapter, CustomAdapter)) + ): + can_disable_adapter = True + was_adapter_active = self.adapter.is_active + self.adapter.is_active = False + + # do a prediction here so we can match its output with network multiplier set to 0.0 + with torch.no_grad(): + dtype = get_torch_dtype(self.train_config.dtype) + + embeds_to_use = conditional_embeds.clone().detach() + # handle clip vision adapter by removing triggers from prompt and replacing with the class name + if (self.adapter is not None and 
isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None: + prompt_list = batch.get_caption_list() + class_name = '' + + triggers = ['[trigger]', '[name]'] + remove_tokens = [] + + if self.embed_config is not None: + triggers.append(self.embed_config.trigger) + for i in range(1, self.embed_config.tokens): + remove_tokens.append(f"{self.embed_config.trigger}_{i}") + if self.embed_config.trigger_class_name is not None: + class_name = self.embed_config.trigger_class_name + + if self.adapter is not None: + triggers.append(self.adapter_config.trigger) + for i in range(1, self.adapter_config.num_tokens): + remove_tokens.append(f"{self.adapter_config.trigger}_{i}") + if self.adapter_config.trigger_class_name is not None: + class_name = self.adapter_config.trigger_class_name + + for idx, prompt in enumerate(prompt_list): + for remove_token in remove_tokens: + prompt = prompt.replace(remove_token, '') + for trigger in triggers: + prompt = prompt.replace(trigger, class_name) + prompt_list[idx] = prompt + + embeds_to_use = self.sd.encode_prompt( + prompt_list, + long_prompts=self.do_long_prompts).to( + self.device_torch, + dtype=dtype).detach() + + # dont use network on this + # self.network.multiplier = 0.0 + self.sd.unet.eval() + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux: + # we need to remove the image embeds from the prompt except for flux + embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach() + end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens + embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :] + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.clone().detach() + unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos] + + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() + + prior_pred = self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(), + conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(), + unconditional_embeddings=unconditional_embeds, + timestep=timesteps, + guidance_scale=self.train_config.cfg_scale, + rescale_cfg=self.train_config.cfg_rescale, + **pred_kwargs # adapter residuals in here + ) + if was_unet_training: + self.sd.unet.train() + prior_pred = prior_pred.detach() + # remove the residuals as we wont use them on prediction when matching control + if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs: + del pred_kwargs['down_intrablock_additional_residuals'] + if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs: + del pred_kwargs['down_block_additional_residuals'] + if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs: + del pred_kwargs['mid_block_additional_residual'] + + if can_disable_adapter: + self.adapter.is_active = was_adapter_active + # restore network + # self.network.multiplier = network_weight_list + if self.network is not None: + self.network.is_active = was_network_active + return prior_pred + + def predict_noise( + self, + noisy_latents: torch.Tensor, + timesteps: Union[int, torch.Tensor] = 1, + conditional_embeds: Union[PromptEmbeds, None] = None, + unconditional_embeds: Union[PromptEmbeds, None] = None, + **kwargs, + ): + dtype = get_torch_dtype(self.train_config.dtype) + return self.sd.predict_noise( + latents=noisy_latents.to(self.device_torch, dtype=dtype), + 
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeddings=unconditional_embeds, + timestep=timesteps, + guidance_scale=self.train_config.cfg_scale, + detach_unconditional=False, + rescale_cfg=self.train_config.cfg_rescale, + **kwargs + ) + + def get_guided_loss( + self, + noisy_latents: torch.Tensor, + conditional_embeds: PromptEmbeds, + match_adapter_assist: bool, + network_weight_list: list, + timesteps: torch.Tensor, + pred_kwargs: dict, + batch: 'DataLoaderBatchDTO', + noise: torch.Tensor, + unconditional_embeds: Optional[PromptEmbeds] = None, + **kwargs + ): + loss = get_guidance_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + sd=self.sd, + unconditional_embeds=unconditional_embeds, + scaler=self.scaler, + **kwargs + ) + + return loss + + def get_noise(self, latents, batch_size, dtype=torch.float32): + # get noise + noise = self.sd.get_latent_noise( + height=latents.shape[2], + width=latents.shape[3], + batch_size=batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + if self.train_config.random_noise_shift > 0.0: + # get random noise -1 to 1 + noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device, + dtype=noise.dtype) * 2 - 1 + + # multiply by shift amount + noise_shift *= self.train_config.random_noise_shift + + # add to noise + noise += noise_shift + + # standardize the noise + std = noise.std(dim=(2, 3), keepdim=True) + normalizer = 1 / (std + 1e-6) + noise = noise * normalizer + + return noise + + def calculate_loss( + self, + noise_pred: torch.Tensor, + noise: torch.Tensor, + noisy_latents: torch.Tensor, + timesteps: torch.Tensor, + batch: 'DataLoaderBatchDTO', + mask_multiplier: Union[torch.Tensor, float] = 1.0, + prior_pred: Union[torch.Tensor, None] = None, + **kwargs + ): + loss_target = self.train_config.loss_target + is_reg = any(batch.get_is_reg_list()) + + prior_mask_multiplier = None + target_mask_multiplier = None + dtype = get_torch_dtype(self.train_config.dtype) + + has_mask = batch.mask_tensor is not None + + with torch.no_grad(): + loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32) + + if self.train_config.match_noise_norm: + # match the norm of the noise + noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True) + noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True) + noise_pred = noise_pred * (noise_norm / noise_pred_norm) + + if self.train_config.pred_scaler != 1.0: + noise_pred = noise_pred * self.train_config.pred_scaler + + target = None + + if self.train_config.target_noise_multiplier != 1.0: + noise = noise * self.train_config.target_noise_multiplier + + if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask): + if self.train_config.correct_pred_norm and not is_reg: + with torch.no_grad(): + # this only works if doing a prior pred + if prior_pred is not None: + prior_mean = prior_pred.mean([2,3], keepdim=True) + prior_std = prior_pred.std([2,3], keepdim=True) + noise_mean = noise_pred.mean([2,3], keepdim=True) + noise_std = noise_pred.std([2,3], keepdim=True) + + mean_adjust = prior_mean - noise_mean + std_adjust = prior_std - noise_std + + mean_adjust = 
mean_adjust * self.train_config.correct_pred_norm_multiplier + std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier + + target_mean = noise_mean + mean_adjust + target_std = noise_std + std_adjust + + eps = 1e-5 + # match the noise to the prior + noise = (noise - noise_mean) / (noise_std + eps) + noise = noise * (target_std + eps) + target_mean + noise = noise.detach() + + if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask: + assert not self.train_config.train_turbo + with torch.no_grad(): + # we need to make the noise prediction be a masked blending of noise and prior_pred + stretched_mask_multiplier = value_map( + mask_multiplier, + batch.file_items[0].dataset_config.mask_min_value, + 1.0, + 0.0, + 1.0 + ) + + prior_mask_multiplier = 1.0 - stretched_mask_multiplier + + + # target_mask_multiplier = mask_multiplier + # mask_multiplier = 1.0 + target = noise + # target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier) + # set masked multiplier to 1.0 so we dont double apply it + # mask_multiplier = 1.0 + elif prior_pred is not None and not self.train_config.do_prior_divergence: + assert not self.train_config.train_turbo + # matching adapter prediction + target = prior_pred + elif self.sd.prediction_type == 'v_prediction': + # v-parameterization training + target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps) + + elif self.sd.is_flow_matching: + target = (noise - batch.latents).detach() + else: + target = noise + + if target is None: + target = noise + + pred = noise_pred + + if self.train_config.train_turbo: + raise ValueError("Turbo training is not supported in MultiGPUSDTrainer") + + ignore_snr = False + + if loss_target == 'source' or loss_target == 'unaugmented': + assert not self.train_config.train_turbo + # ignore_snr = True + if batch.sigmas is None: + raise ValueError("Batch sigmas is None. 
This should not happen") + + # src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190 + denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents + weighing = batch.sigmas ** -2.0 + if loss_target == 'source': + # denoise the latent and compare to the latent in the batch + target = batch.latents + elif loss_target == 'unaugmented': + # we have to encode images into latents for now + # we also denoise as the unaugmented tensor is not a noisy diffirental + with torch.no_grad(): + unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype) + unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier + target = unaugmented_latents.detach() + + # Get the target for loss depending on the prediction type + if self.sd.noise_scheduler.config.prediction_type == "epsilon": + target = target # we are computing loss against denoise latents + elif self.sd.noise_scheduler.config.prediction_type == "v_prediction": + target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}") + + # mse loss without reduction + loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2) + loss = loss_per_element + else: + + if self.train_config.loss_type == "mae": + loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none") + else: + loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") + + # handle linear timesteps and only adjust the weight of the timesteps + if self.sd.is_flow_matching and self.train_config.linear_timesteps: + # calculate the weights for the timesteps + timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype) + timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach() + loss = loss * timestep_weight + + if self.train_config.do_prior_divergence and prior_pred is not None: + loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0) + + if self.train_config.train_turbo: + mask_multiplier = mask_multiplier[:, 3:, :, :] + # resize to the size of the loss + mask_multiplier = torch.nn.functional.interpolate(mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest') + + # multiply by our mask + loss = loss * mask_multiplier + + prior_loss = None + if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None: + assert not self.train_config.train_turbo + if self.train_config.loss_type == "mae": + prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none") + else: + prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") + + prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier + if torch.isnan(prior_loss).any(): + print("Prior loss is nan") + prior_loss = None + else: + prior_loss = prior_loss.mean([1, 2, 3]) + # loss = loss + prior_loss + # loss = loss + prior_loss + # loss = loss + prior_loss + loss = loss.mean([1, 2, 3]) + # apply loss multiplier before prior loss + loss = loss * loss_multiplier + if prior_loss is not None: + loss = loss + prior_loss + + if not self.train_config.train_turbo: + if self.train_config.learnable_snr_gos: + # add snr_gamma + loss = 
apply_learnable_snr_gos(loss, timesteps, self.snr_gos) + elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr: + # add snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, + fixed=True) + elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr: + # add min_snr_gamma + loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) + + loss = loss.mean() + + # check for additional losses + if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None: + + loss = loss + self.adapter.additional_loss.mean() + self.adapter.additional_loss = None + + if self.train_config.target_norm_std: + # seperate out the batch and channels + pred_std = noise_pred.std([2, 3], keepdim=True) + norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean() + loss = loss + norm_std_loss + + + return loss + + def preprocess_batch(self, batch: 'DataLoaderBatchDTO'): + return batch + + def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'): + with torch.no_grad(): + with self.timer('prepare_prompt'): + prompts = batch.get_caption_list() + is_reg_list = batch.get_is_reg_list() + + is_any_reg = any([is_reg for is_reg in is_reg_list]) + + do_double = self.train_config.short_and_long_captions and not is_any_reg + + if self.train_config.short_and_long_captions and do_double: + # dont do this with regs. No point + + # double batch and add short captions to the end + prompts = prompts + batch.get_caption_short_list() + is_reg_list = is_reg_list + is_reg_list + if self.sd.model_config.refiner_name_or_path is not None and self.train_config.train_unet: + prompts = prompts + prompts + is_reg_list = is_reg_list + is_reg_list + + conditioned_prompts = [] + + for prompt, is_reg in zip(prompts, is_reg_list): + + # make sure the embedding is in the prompts + if self.embedding is not None: + prompt = self.embedding.inject_embedding_to_prompt( + prompt, + expand_token=True, + add_if_not_present=not is_reg, + ) + + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + prompt = self.adapter.inject_trigger_into_prompt( + prompt, + expand_token=True, + add_if_not_present=not is_reg, + ) + + # make sure trigger is in the prompts if not a regularization run + if self.trigger_word is not None: + prompt = self.sd.inject_trigger_into_prompt( + prompt, + trigger=self.trigger_word, + add_if_not_present=not is_reg, + ) + + if not is_reg and self.train_config.prompt_saturation_chance > 0.0: + # do random prompt saturation by expanding the prompt to hit at least 77 tokens + if random.random() < self.train_config.prompt_saturation_chance: + est_num_tokens = len(prompt.split(' ')) + if est_num_tokens < 77: + num_repeats = int(77 / est_num_tokens) + 1 + prompt = ', '.join([prompt] * num_repeats) + + + conditioned_prompts.append(prompt) + + with self.timer('prepare_latents'): + dtype = get_torch_dtype(self.train_config.dtype) + imgs = None + is_reg = any(batch.get_is_reg_list()) + if batch.tensor is not None: + imgs = batch.tensor + imgs = imgs.to(self.device_torch, dtype=dtype) + # dont adjust for regs. 
+ if self.train_config.img_multiplier is not None and not is_reg: + # do it ad contrast + imgs = reduce_contrast(imgs, self.train_config.img_multiplier) + if batch.latents is not None: + latents = batch.latents.to(self.device_torch, dtype=dtype) + batch.latents = latents + else: + # normalize to + if self.train_config.standardize_images: + if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd: + target_mean_list = [0.0002, -0.1034, -0.1879] + target_std_list = [0.5436, 0.5116, 0.5033] + else: + target_mean_list = [-0.0739, -0.1597, -0.2380] + target_std_list = [0.5623, 0.5295, 0.5347] + # Mean: tensor([-0.0739, -0.1597, -0.2380]) + # Standard Deviation: tensor([0.5623, 0.5295, 0.5347]) + imgs_channel_mean = imgs.mean(dim=(2, 3), keepdim=True) + imgs_channel_std = imgs.std(dim=(2, 3), keepdim=True) + imgs = (imgs - imgs_channel_mean) / imgs_channel_std + target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype) + target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype) + # expand them to match dim + target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3) + target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3) + + imgs = imgs * target_std + target_mean + batch.tensor = imgs + + # show_tensors(imgs, 'imgs') + + latents = self.sd.encode_images(imgs) + batch.latents = latents + + if self.train_config.standardize_latents: + if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd: + target_mean_list = [-0.1075, 0.0231, -0.0135, 0.2164] + target_std_list = [0.8979, 0.7505, 0.9150, 0.7451] + else: + target_mean_list = [0.2949, -0.3188, 0.0807, 0.1929] + target_std_list = [0.8560, 0.9629, 0.7778, 0.6719] + + latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) + latents_channel_std = latents.std(dim=(2, 3), keepdim=True) + latents = (latents - latents_channel_mean) / latents_channel_std + target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype) + target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype) + # expand them to match dim + target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3) + target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3) + + latents = latents * target_std + target_mean + batch.latents = latents + + # show_latents(latents, self.sd.vae, 'latents') + + + if batch.unconditional_tensor is not None and batch.unconditional_latents is None: + unconditional_imgs = batch.unconditional_tensor + unconditional_imgs = unconditional_imgs.to(self.device_torch, dtype=dtype) + unconditional_latents = self.sd.encode_images(unconditional_imgs) + batch.unconditional_latents = unconditional_latents * self.train_config.latent_multiplier + + unaugmented_latents = None + if self.train_config.loss_target == 'differential_noise': + # we determine noise from the differential of the latents + unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor) + + batch_size = len(batch.file_items) + min_noise_steps = self.train_config.min_denoising_steps + max_noise_steps = self.train_config.max_denoising_steps + if self.sd.model_config.refiner_name_or_path is not None: + # if we are not training the unet, then we are only doing refiner and do not need to double up + if self.train_config.train_unet: + max_noise_steps = round(self.train_config.max_denoising_steps * self.sd.model_config.refiner_start_at) + do_double = True + else: + min_noise_steps = round(self.train_config.max_denoising_steps * self.sd.model_config.refiner_start_at) + do_double = False + + with 
self.timer('prepare_noise'): + num_train_timesteps = self.train_config.num_train_timesteps + + if self.train_config.noise_scheduler in ['custom_lcm']: + # we store this value on our custom one + self.sd.noise_scheduler.set_timesteps( + self.sd.noise_scheduler.train_timesteps, device=self.device_torch + ) + elif self.train_config.noise_scheduler in ['lcm']: + self.sd.noise_scheduler.set_timesteps( + num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps + ) + elif self.train_config.noise_scheduler == 'flowmatch': + self.sd.noise_scheduler.set_train_timesteps( + num_train_timesteps, + device=self.device_torch, + linear=self.train_config.linear_timesteps + ) + else: + self.sd.noise_scheduler.set_timesteps( + num_train_timesteps, device=self.device_torch + ) + + content_or_style = self.train_config.content_or_style + if is_reg: + content_or_style = self.train_config.content_or_style_reg + + # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content': + if content_or_style in ['style', 'content']: + # this is from diffusers training code + # Cubic sampling for favoring later or earlier timesteps + # For more details about why cubic sampling is used for content / structure, + # refer to section 3.4 of https://arxiv.org/abs/2302.08453 + + # for content / structure, it is best to favor earlier timesteps + # for style, it is best to favor later timesteps + + orig_timesteps = torch.rand((batch_size,), device=latents.device) + + if content_or_style == 'content': + timestep_indices = orig_timesteps ** 3 * self.train_config.num_train_timesteps + elif content_or_style == 'style': + timestep_indices = (1 - orig_timesteps ** 3) * self.train_config.num_train_timesteps + + timestep_indices = value_map( + timestep_indices, + 0, + self.train_config.num_train_timesteps - 1, + min_noise_steps, + max_noise_steps - 1 + ) + timestep_indices = timestep_indices.long().clamp( + min_noise_steps + 1, + max_noise_steps - 1 + ) + + elif content_or_style == 'balanced': + if min_noise_steps == max_noise_steps: + timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps + else: + # todo, some schedulers use indices, otheres use timesteps. Not sure what to do here + timestep_indices = torch.randint( + min_noise_steps + 1, + max_noise_steps - 1, + (batch_size,), + device=self.device_torch + ) + timestep_indices = timestep_indices.long() + else: + raise ValueError(f"Unknown content_or_style {content_or_style}") + + # do flow matching + # if self.sd.is_flow_matching: + # u = compute_density_for_timestep_sampling( + # weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"] + # batch_size=batch_size, + # logit_mean=0.0, + # logit_std=1.0, + # mode_scale=1.29, + # ) + # timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long() + # convert the timestep_indices to a timestep + timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices] + timesteps = torch.stack(timesteps, dim=0) + + # get noise + noise = self.get_noise(latents, batch_size, dtype=dtype) + + # add dynamic noise offset. 
Dynamic noise is offsetting the noise to the same channelwise mean as the latents + # this will negate any noise offsets + if self.train_config.dynamic_noise_offset and not is_reg: + latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) / 2 + # subtract channel mean so that we compensate for the mean of the latents on the noise offset per channel + noise = noise + latents_channel_mean + + if self.train_config.loss_target == 'differential_noise': + differential = latents - unaugmented_latents + # add noise to differential + # noise = noise + differential + noise = noise + (differential * 0.5) + # noise = value_map(differential, 0, torch.abs(differential).max(), 0, torch.abs(noise).max()) + latents = unaugmented_latents + + noise_multiplier = self.train_config.noise_multiplier + + noise = noise * noise_multiplier + + latent_multiplier = self.train_config.latent_multiplier + + # handle adaptive scaling based on std + if self.train_config.adaptive_scaling_factor: + std = latents.std(dim=(2, 3), keepdim=True) + normalizer = 1 / (std + 1e-6) + latent_multiplier = normalizer + + latents = latents * latent_multiplier + batch.latents = latents + + # normalize latents to a mean of 0 and an std of 1 + # mean_zero_latents = latents - latents.mean() + # latents = mean_zero_latents / mean_zero_latents.std() + + if batch.unconditional_latents is not None: + batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier + + + noisy_latents = self.sd.add_noise(latents, noise, timesteps) + + # determine scaled noise + # todo do we need to scale this or does it always predict full intensity + # noise = noisy_latents - latents + + # https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170C17-L1171C77 + if self.train_config.loss_target == 'source' or self.train_config.loss_target == 'unaugmented': + sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype) + # add it to the batch + batch.sigmas = sigmas + # todo is this for sdxl? 
find out where this came from originally + # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5) + + def double_up_tensor(tensor: torch.Tensor): + if tensor is None: + return None + return torch.cat([tensor, tensor], dim=0) + + if do_double: + if self.sd.model_config.refiner_name_or_path: + # apply refiner double up + refiner_timesteps = torch.randint( + max_noise_steps, + self.train_config.max_denoising_steps, + (batch_size,), + device=self.device_torch + ) + refiner_timesteps = refiner_timesteps.long() + # add our new timesteps on to end + timesteps = torch.cat([timesteps, refiner_timesteps], dim=0) + + refiner_noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, refiner_timesteps) + noisy_latents = torch.cat([noisy_latents, refiner_noisy_latents], dim=0) + + else: + # just double it + noisy_latents = double_up_tensor(noisy_latents) + timesteps = double_up_tensor(timesteps) + + noise = double_up_tensor(noise) + # prompts are already updated above + imgs = double_up_tensor(imgs) + batch.mask_tensor = double_up_tensor(batch.mask_tensor) + batch.control_tensor = double_up_tensor(batch.control_tensor) + + noisy_latent_multiplier = self.train_config.noisy_latent_multiplier + + if noisy_latent_multiplier != 1.0: + noisy_latents = noisy_latents * noisy_latent_multiplier + + # remove grads for these + noisy_latents.requires_grad = False + noisy_latents = noisy_latents.detach() + noise.requires_grad = False + noise = noise.detach() + + return noisy_latents, noise, timesteps, conditioned_prompts, imgs + + + def forward(self, batch: DataLoaderBatchDTO): + self.timer.start('preprocess_batch') + batch = self.preprocess_batch(batch) + dtype = get_torch_dtype(self.train_config.dtype) + # sanity check + if self.sd.vae.dtype != self.sd.vae_torch_dtype: + self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype) + if isinstance(self.sd.text_encoder, list): + for encoder in self.sd.text_encoder: + if encoder.dtype != self.sd.te_torch_dtype: + encoder.to(self.sd.te_torch_dtype) + else: + if self.sd.text_encoder.dtype != self.sd.te_torch_dtype: + self.sd.text_encoder.to(self.sd.te_torch_dtype) + + noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch) + if self.train_config.do_cfg or self.train_config.do_random_cfg: + # pick random negative prompts + if self.negative_prompt_pool is not None: + negative_prompts = [] + for i in range(noisy_latents.shape[0]): + num_neg = random.randint(1, self.train_config.max_negative_prompts) + this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)] + this_neg_prompt = ', '.join(this_neg_prompts) + negative_prompts.append(this_neg_prompt) + self.batch_negative_prompt = negative_prompts + else: + self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])] + + if self.adapter and isinstance(self.adapter, CustomAdapter): + # condition the prompt + # todo handle more than one adapter image + self.adapter.num_control_images = 1 + conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts) + + network_weight_list = batch.get_network_weight_list() + if self.train_config.single_item_batching: + network_weight_list = network_weight_list + network_weight_list + + has_adapter_img = batch.control_tensor is not None + has_clip_image = batch.clip_image_tensor is not None + has_clip_image_embeds = batch.clip_image_embeds is not None + # force it to be true if doing regs as we handle those differently + if any([batch.file_items[idx].is_reg for idx in 
range(len(batch.file_items))]): + has_clip_image = True + if self._clip_image_embeds_unconditional is not None: + has_clip_image_embeds = True # we are caching embeds, handle that differently + has_clip_image = False + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img: + raise ValueError( + "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ") + + match_adapter_assist = False + + # check if we are matching the adapter assistant + if self.assistant_adapter: + if self.train_config.match_adapter_chance == 1.0: + match_adapter_assist = True + elif self.train_config.match_adapter_chance > 0.0: + match_adapter_assist = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) < self.train_config.match_adapter_chance + + self.timer.stop('preprocess_batch') + + is_reg = False + with torch.no_grad(): + loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) + for idx, file_item in enumerate(batch.file_items): + if file_item.is_reg: + loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight + is_reg = True + + adapter_images = None + sigmas = None + if has_adapter_img and (self.adapter or self.assistant_adapter): + with self.timer('get_adapter_images'): + # todo move this to data loader + if batch.control_tensor is not None: + adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach() + # match in channels + if self.assistant_adapter is not None: + in_channels = self.assistant_adapter.config.in_channels + if adapter_images.shape[1] != in_channels: + # we need to match the channels + adapter_images = adapter_images[:, :in_channels, :, :] + else: + raise NotImplementedError("Adapter images now must be loaded with dataloader") + + clip_images = None + if has_clip_image: + with self.timer('get_clip_images'): + # todo move this to data loader + if batch.clip_image_tensor is not None: + clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach() + + mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) + if batch.mask_tensor is not None: + with self.timer('get_mask_multiplier'): + # upsampling not supported for bfloat16 + mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach() + # scale down to the size of the latents, mask multiplier shape(bs, 1, width, height), noisy_latents shape(bs, channels, width, height) + mask_multiplier = torch.nn.functional.interpolate( + mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3]) + ) + # expand to match latents + mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1) + mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach() + + def get_adapter_multiplier(): + if self.adapter and isinstance(self.adapter, T2IAdapter): + # training a t2i adapter, not using as assistant. + return 1.0 + elif match_adapter_assist: + # training a texture. 
We want it high + adapter_strength_min = 0.9 + adapter_strength_max = 1.0 + else: + # training with assistance, we want it low + # adapter_strength_min = 0.4 + # adapter_strength_max = 0.7 + adapter_strength_min = 0.5 + adapter_strength_max = 1.1 + + adapter_conditioning_scale = torch.rand( + (1,), device=self.device_torch, dtype=dtype + ) + + adapter_conditioning_scale = value_map( + adapter_conditioning_scale, + 0.0, + 1.0, + adapter_strength_min, + adapter_strength_max + ) + return adapter_conditioning_scale + + # flush() + with self.timer('grad_setup'): + + # text encoding + grad_on_text_encoder = False + if self.train_config.train_text_encoder: + grad_on_text_encoder = True + + if self.embedding is not None: + grad_on_text_encoder = True + + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + grad_on_text_encoder = True + + if self.adapter_config and self.adapter_config.type == 'te_augmenter': + grad_on_text_encoder = True + + # have a blank network so we can wrap it in a context and set multipliers without checking every time + if self.network is not None: + network = self.network + else: + network = BlankNetwork() + + # set the weights + network.multiplier = network_weight_list + + # activate network if it exits + + prompts_1 = conditioned_prompts + prompts_2 = None + if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl: + prompts_1 = batch.get_caption_short_list() + prompts_2 = conditioned_prompts + + # make the batch splits + if self.train_config.single_item_batching: + if self.sd.model_config.refiner_name_or_path is not None: + raise ValueError("Single item batching is not supported when training the refiner") + batch_size = noisy_latents.shape[0] + # chunk/split everything + noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0) + noise_list = torch.chunk(noise, batch_size, dim=0) + timesteps_list = torch.chunk(timesteps, batch_size, dim=0) + conditioned_prompts_list = [[prompt] for prompt in prompts_1] + if imgs is not None: + imgs_list = torch.chunk(imgs, batch_size, dim=0) + else: + imgs_list = [None for _ in range(batch_size)] + if adapter_images is not None: + adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0) + else: + adapter_images_list = [None for _ in range(batch_size)] + if clip_images is not None: + clip_images_list = torch.chunk(clip_images, batch_size, dim=0) + else: + clip_images_list = [None for _ in range(batch_size)] + mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0) + if prompts_2 is None: + prompt_2_list = [None for _ in range(batch_size)] + else: + prompt_2_list = [[prompt] for prompt in prompts_2] + + else: + noisy_latents_list = [noisy_latents] + noise_list = [noise] + timesteps_list = [timesteps] + conditioned_prompts_list = [prompts_1] + imgs_list = [imgs] + adapter_images_list = [adapter_images] + clip_images_list = [clip_images] + mask_multiplier_list = [mask_multiplier] + if prompts_2 is None: + prompt_2_list = [None] + else: + prompt_2_list = [prompts_2] + + for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, clip_images, mask_multiplier, prompt_2 in zip( + noisy_latents_list, + noise_list, + timesteps_list, + conditioned_prompts_list, + imgs_list, + adapter_images_list, + clip_images_list, + mask_multiplier_list, + prompt_2_list + ): + + with (network): + # encode clip adapter here so embeds are active for tokenizer + if self.adapter and isinstance(self.adapter, ClipVisionAdapter): + with self.timer('encode_clip_vision_embeds'): + 
if has_clip_image: + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + has_been_preprocessed=True + ) + else: + # just do a blank one + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + torch.zeros( + (noisy_latents.shape[0], 3, 512, 512), + device=self.device_torch, dtype=dtype + ), + is_training=True, + has_been_preprocessed=True, + drop=True + ) + # it will be injected into the tokenizer when called + self.adapter(conditional_clip_embeds) + + # do the custom adapter after the prior prediction + if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image: + quad_count = random.randint(1, 4) + self.adapter.train() + self.adapter.trigger_pre_te( + tensors_0_1=clip_images if not is_reg else None, # on regs we send none to get random noise + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count, + batch_size=noisy_latents.shape[0] + ) + + with self.timer('encode_prompt'): + unconditional_embeds = None + if grad_on_text_encoder: + with torch.set_grad_enabled(True): + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.sd.encode_prompt( + conditioned_prompts, prompt_2, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts).to( + self.device_torch, + dtype=dtype) + + if self.train_config.do_cfg: + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + # todo only do one and repeat it + unconditional_embeds = self.sd.encode_prompt( + self.batch_negative_prompt, + self.batch_negative_prompt, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts).to( + self.device_torch, + dtype=dtype) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + else: + with torch.set_grad_enabled(False): + # make sure it is in eval mode + if isinstance(self.sd.text_encoder, list): + for te in self.sd.text_encoder: + te.eval() + else: + self.sd.text_encoder.eval() + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.sd.encode_prompt( + conditioned_prompts, prompt_2, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts).to( + self.device_torch, + dtype=dtype) + if self.train_config.do_cfg: + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.sd.encode_prompt( + self.batch_negative_prompt, + dropout_prob=self.train_config.prompt_dropout_prob, + long_prompts=self.do_long_prompts).to( + self.device_torch, + dtype=dtype) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + # detach the embeddings + conditional_embeds = conditional_embeds.detach() + if self.train_config.do_cfg: + unconditional_embeds = unconditional_embeds.detach() + + # flush() + pred_kwargs = {} + + if has_adapter_img: + if (self.adapter and isinstance(self.adapter, T2IAdapter)) or ( + self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)): + with torch.set_grad_enabled(self.adapter is not None): + adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + down_block_additional_residuals = adapter(adapter_images) + if self.assistant_adapter: + # 
not training. detach + down_block_additional_residuals = [ + sample.to(dtype=dtype).detach() * adapter_multiplier for sample in + down_block_additional_residuals + ] + else: + down_block_additional_residuals = [ + sample.to(dtype=dtype) * adapter_multiplier for sample in + down_block_additional_residuals + ] + + pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals + + if self.adapter and isinstance(self.adapter, IPAdapter): + with self.timer('encode_adapter_embeds'): + # number of images to do if doing a quad image + quad_count = random.randint(1, 4) + image_size = self.adapter.input_size + if has_clip_image_embeds: + # todo handle reg images better than this + if is_reg: + # get unconditional image embeds from cache + embeds = [ + load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in + range(noisy_latents.shape[0]) + ] + conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + embeds, + quad_count=quad_count + ) + + if self.train_config.do_cfg: + embeds = [ + load_file(random.choice(batch.clip_image_embeds_unconditional)) for i in + range(noisy_latents.shape[0]) + ] + unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + embeds, + quad_count=quad_count + ) + + else: + conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + batch.clip_image_embeds, + quad_count=quad_count + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache( + batch.clip_image_embeds_unconditional, + quad_count=quad_count + ) + elif is_reg: + # we will zero it out in the img embedder + clip_images = torch.zeros( + (noisy_latents.shape[0], 3, image_size, image_size), + device=self.device_torch, dtype=dtype + ).detach() + # drop will zero it out + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images, + drop=True, + is_training=True, + has_been_preprocessed=False, + quad_count=quad_count + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + torch.zeros( + (noisy_latents.shape[0], 3, image_size, image_size), + device=self.device_torch, dtype=dtype + ).detach(), + is_training=True, + drop=True, + has_been_preprocessed=False, + quad_count=quad_count + ) + elif has_clip_image: + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count, + # do cfg on clip embeds to normalize the embeddings for when doing cfg + # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None + # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None + ) + if self.train_config.do_cfg: + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True, + drop=True, + has_been_preprocessed=True, + quad_count=quad_count + ) + else: + print("No Clip Image") + print([file_item.path for file_item in batch.file_items]) + raise ValueError("Could not find clip image") + + if not self.adapter_config.train_image_encoder: + # we are not training the image encoder, so we need to detach the embeds + conditional_clip_embeds = conditional_clip_embeds.detach() + if self.train_config.do_cfg: + unconditional_clip_embeds = unconditional_clip_embeds.detach() + + with self.timer('encode_adapter'): + self.adapter.train() + conditional_embeds = 
self.adapter( + conditional_embeds.detach(), + conditional_clip_embeds, + is_unconditional=False + ) + if self.train_config.do_cfg: + unconditional_embeds = self.adapter( + unconditional_embeds.detach(), + unconditional_clip_embeds, + is_unconditional=True + ) + else: + # wipe out unconsitional + self.adapter.last_unconditional = None + + if self.adapter and isinstance(self.adapter, ReferenceAdapter): + # pass in our scheduler + self.adapter.noise_scheduler = self.lr_scheduler + if has_clip_image or has_adapter_img: + img_to_use = clip_images if has_clip_image else adapter_images + # currently 0-1 needs to be -1 to 1 + reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype) + self.adapter.set_reference_images(reference_images) + self.adapter.noise_scheduler = self.sd.noise_scheduler + elif is_reg: + self.adapter.set_blank_reference_images(noisy_latents.shape[0]) + else: + self.adapter.set_reference_images(None) + + prior_pred = None + + do_reg_prior = False + # if is_reg and (self.network is not None or self.adapter is not None): + # # we are doing a reg image and we have a network or adapter + # do_reg_prior = True + + do_inverted_masked_prior = False + if self.train_config.inverted_mask_prior and batch.mask_tensor is not None: + do_inverted_masked_prior = True + + do_correct_pred_norm_prior = self.train_config.correct_pred_norm + + do_guidance_prior = False + + if batch.unconditional_latents is not None: + # for this not that, we need a prior pred to normalize + guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type + if guidance_type == 'tnt': + do_guidance_prior = True + + if (( + has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction or do_guidance_prior or do_reg_prior or do_inverted_masked_prior or self.train_config.correct_pred_norm): + with self.timer('prior predict'): + prior_pred = self.get_prior_prediction( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + noise=noise, + batch=batch, + unconditional_embeds=unconditional_embeds, + conditioned_prompts=conditioned_prompts + ) + if prior_pred is not None: + prior_pred = prior_pred.detach() + + # do the custom adapter after the prior prediction + if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image: + quad_count = random.randint(1, 4) + self.adapter.train() + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=clip_images, + prompt_embeds=conditional_embeds, + is_training=True, + has_been_preprocessed=True, + quad_count=quad_count + ) + if self.train_config.do_cfg and unconditional_embeds is not None: + unconditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=clip_images, + prompt_embeds=unconditional_embeds, + is_training=True, + has_been_preprocessed=True, + is_unconditional=True, + quad_count=quad_count + ) + + if self.adapter and isinstance(self.adapter, CustomAdapter) and batch.extra_values is not None: + self.adapter.add_extra_values(batch.extra_values.detach()) + + if self.train_config.do_cfg: + self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), + is_unconditional=True) + + if has_adapter_img: + if (self.adapter and isinstance(self.adapter, ControlNetModel)) or ( + self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)): + if self.train_config.do_cfg: + 
raise ValueError("ControlNetModel is not supported with CFG") + with torch.set_grad_enabled(self.adapter is not None): + adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter + adapter_multiplier = get_adapter_multiplier() + with self.timer('encode_adapter'): + # add_text_embeds is pooled_prompt_embeds for sdxl + added_cond_kwargs = {} + if self.sd.is_xl: + added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds + added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents) + down_block_res_samples, mid_block_res_sample = adapter( + noisy_latents, + timesteps, + encoder_hidden_states=conditional_embeds.text_embeds, + controlnet_cond=adapter_images, + conditioning_scale=1.0, + guess_mode=False, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + pred_kwargs['down_block_additional_residuals'] = down_block_res_samples + pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample + + self.before_unet_predict() + # do a prior pred if we have an unconditional image, we will swap out the giadance later + if batch.unconditional_latents is not None or self.do_guided_loss: + # do guided loss + loss = self.get_guided_loss( + noisy_latents=noisy_latents, + conditional_embeds=conditional_embeds, + match_adapter_assist=match_adapter_assist, + network_weight_list=network_weight_list, + timesteps=timesteps, + pred_kwargs=pred_kwargs, + batch=batch, + noise=noise, + unconditional_embeds=unconditional_embeds, + mask_multiplier=mask_multiplier, + prior_pred=prior_pred, + ) + + else: + with self.timer('predict_unet'): + if unconditional_embeds is not None: + unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach() + noise_pred = self.predict_noise( + noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), + timesteps=timesteps, + conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeds=unconditional_embeds, + **pred_kwargs + ) + self.after_unet_predict() + + with self.timer('calculate_loss'): + noise = noise.to(self.device_torch, dtype=dtype).detach() + loss = self.calculate_loss( + noise_pred=noise_pred, + noise=noise, + noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + mask_multiplier=mask_multiplier, + prior_pred=prior_pred, + ) + loss = loss * loss_multiplier.mean() + + + return loss \ No newline at end of file