diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index a7ae422c..7af2f0b3 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -229,9 +229,13 @@ class SDTrainer(BaseSDTrainProcess):
         prior_mask_multiplier = None
         target_mask_multiplier = None
 
+        dtype = get_torch_dtype(self.train_config.dtype)
         has_mask = batch.mask_tensor is not None
 
+        with torch.no_grad():
+            loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32)
+
         if self.train_config.match_noise_norm:
             # match the norm of the noise
             noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
 
@@ -364,6 +368,8 @@ class SDTrainer(BaseSDTrainProcess):
             # loss = loss + prior_loss
         # loss = loss + prior_loss
         loss = loss.mean([1, 2, 3])
+        # apply loss multiplier before prior loss
+        loss = loss * loss_multiplier
 
         if prior_loss is not None:
             loss = loss + prior_loss
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index c47c89ca..3f871149 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1480,10 +1480,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 else:
                     batch = None
 
-                # if we are doing a reg step, always accumulate
-                if is_reg_step:
-                    self.is_grad_accumulation_step = True
-
                 # setup accumulation
                 if self.train_config.gradient_accumulation_steps == -1:
                     # epoch is handling the accumulation, dont touch it
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index dec6803d..ebcdb963 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -496,6 +496,8 @@ class DatasetConfig:
         self.clip_image_path: str = kwargs.get('clip_image_path', None)  # depth maps, etc
         self.clip_image_augmentations: List[dict] = kwargs.get('clip_image_augmentations', None)
         self.clip_image_shuffle_augmentations: bool = kwargs.get('clip_image_shuffle_augmentations', False)
+        self.replacements: List[str] = kwargs.get('replacements', [])
+        self.loss_multiplier: float = kwargs.get('loss_multiplier', 1.0)
 
 
 def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py
index ba0a6e4b..abc846af 100644
--- a/toolkit/data_transfer_object/data_loader.py
+++ b/toolkit/data_transfer_object/data_loader.py
@@ -68,7 +68,7 @@ class FileItemDTO(
         self.flip_x: bool = kwargs.get('flip_x', False)
-        self.flip_y: bool = kwargs.get('flip_x', False)
+        self.flip_y: bool = kwargs.get('flip_y', False)
         self.augments: List[str] = self.dataset_config.augments
-
+        self.loss_multiplier: float = self.dataset_config.loss_multiplier
         self.network_weight: float = self.dataset_config.network_weight
         self.is_reg = self.dataset_config.is_reg
 
@@ -124,6 +124,8 @@ class DataLoaderBatchDTO:
                     control_tensors.append(x.control_tensor)
                 self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])
 
+            self.loss_multiplier_list: List[float] = [x.loss_multiplier for x in self.file_items]
+
             if any([x.clip_image_tensor is not None for x in self.file_items]):
                 # find one to use as a base
                 base_clip_image_tensor = None
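Note: the loss_multiplier plumbing above threads a per-item weight from DatasetConfig through FileItemDTO into DataLoaderBatchDTO, where SDTrainer multiplies it into the per-sample loss before any prior loss is added. A minimal sketch of that arithmetic; the shapes and multiplier values below are illustrative assumptions, not taken from the trainer:

    import torch

    # hypothetical batch of 4 samples with an SD-style latent shape [B, 4, 64, 64]
    pred = torch.randn(4, 4, 64, 64)
    target = torch.randn(4, 4, 64, 64)

    # one multiplier per file item, as collected into loss_multiplier_list
    loss_multiplier = torch.tensor([1.0, 1.0, 0.5, 2.0], dtype=torch.float32)

    loss = torch.nn.functional.mse_loss(pred, target, reduction="none")
    loss = loss.mean([1, 2, 3])    # reduce to one value per sample
    loss = loss * loss_multiplier  # per-item weighting, before any prior loss
    loss = loss.mean()             # final scalar for backward()
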
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 3aef4d14..9a32814b 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -137,6 +137,13 @@ class CaptionMixin:
             prompt = self.default_prompt
         if hasattr(self, 'default_caption'):
             prompt = self.default_caption
+
+        # handle replacements
+        replacement_list = self.dataset_config.replacements if isinstance(self.dataset_config.replacements, list) else []
+        for replacement in replacement_list:
+            from_string, to_string = replacement.split('|')
+            prompt = prompt.replace(from_string, to_string)
+
         return prompt
 
 
diff --git a/toolkit/guidance.py b/toolkit/guidance.py
index ef005c8f..d8e6ad5a 100644
--- a/toolkit/guidance.py
+++ b/toolkit/guidance.py
@@ -482,6 +482,87 @@ def get_guided_loss_polarity(
     return loss
 
 
+def get_guided_tnt(
+        noisy_latents: torch.Tensor,
+        conditional_embeds: PromptEmbeds,
+        match_adapter_assist: bool,
+        network_weight_list: list,
+        timesteps: torch.Tensor,
+        pred_kwargs: dict,
+        batch: 'DataLoaderBatchDTO',
+        noise: torch.Tensor,
+        sd: 'StableDiffusion',
+        **kwargs
+):
+    dtype = get_torch_dtype(sd.torch_dtype)
+    device = sd.device_torch
+    with torch.no_grad():
+        noise = noise.to(device, dtype=dtype).detach()
+
+        conditional_latents = batch.latents.to(device, dtype=dtype).detach()
+        unconditional_latents = batch.unconditional_latents.to(device, dtype=dtype).detach()
+
+        conditional_noisy_latents = sd.add_noise(
+            conditional_latents,
+            noise,
+            timesteps
+        ).detach()
+
+        unconditional_noisy_latents = sd.add_noise(
+            unconditional_latents,
+            noise,
+            timesteps
+        ).detach()
+
+        # double up everything to run it through all at once
+        cat_embeds = concat_prompt_embeds([conditional_embeds, conditional_embeds])
+        cat_latents = torch.cat([conditional_noisy_latents, unconditional_noisy_latents], dim=0)
+        cat_timesteps = torch.cat([timesteps, timesteps], dim=0)
+
+
+    # turn the LoRA network back on.
+    sd.unet.train()
+    if sd.network is not None:
+        cat_network_weight_list = [weight for weight in network_weight_list * 2]
+        sd.network.multiplier = cat_network_weight_list
+        sd.network.is_active = True
+
+    prediction = sd.predict_noise(
+        latents=cat_latents.to(device, dtype=dtype).detach(),
+        conditional_embeddings=cat_embeds.to(device, dtype=dtype).detach(),
+        timestep=cat_timesteps,
+        guidance_scale=1.0,
+        **pred_kwargs  # adapter residuals in here
+    )
+    this_prediction, that_prediction = torch.chunk(prediction, 2, dim=0)
+
+    this_loss = torch.nn.functional.mse_loss(
+        this_prediction.float(),
+        noise.float(),
+        reduction="none"
+    )
+
+    that_loss = torch.nn.functional.mse_loss(
+        that_prediction.float(),
+        noise.float(),
+        reduction="none"
+    ) * -1.0
+
+    loss = this_loss + that_loss
+
+    loss = loss.mean([1, 2, 3])
+    loss = loss.mean()
+
+    loss.backward()
+
+    # detach it so parent class can run backward on no grads without throwing error
+    loss = loss.detach()
+    loss.requires_grad_(True)
+
+    return loss
+
+
 # this processes all guidance losses based on the batch information
 def get_guidance_loss(
     noisy_latents: torch.Tensor,
@@ -529,6 +610,20 @@ def get_guidance_loss(
             sd,
             **kwargs
         )
+    elif guidance_type == "tnt":
+        assert unconditional_embeds is None, "Unconditional embeds are not supported for tnt guidance"
+        return get_guided_tnt(
+            noisy_latents,
+            conditional_embeds,
+            match_adapter_assist,
+            network_weight_list,
+            timesteps,
+            pred_kwargs,
+            batch,
+            noise,
+            sd,
+            **kwargs
+        )
 
     elif guidance_type == "targeted_polarity":
         assert unconditional_embeds is None, "Unconditional embeds are not supported for targeted polarity guidance"
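Note: get_guided_tnt pulls the prediction for the conditional ("this") latents toward the true noise while pushing the prediction for the unconditional ("that") latents away from it via a negated MSE term. A stripped-down sketch of just that loss arithmetic; pred_cond and pred_uncond are hypothetical stand-ins for this_prediction and that_prediction:

    import torch
    import torch.nn.functional as F

    def tnt_style_loss(pred_cond: torch.Tensor,
                       pred_uncond: torch.Tensor,
                       noise: torch.Tensor) -> torch.Tensor:
        # attract the conditional branch to the real noise
        this_loss = F.mse_loss(pred_cond.float(), noise.float(), reduction="none")
        # repel the unconditional branch from it (negated term)
        that_loss = F.mse_loss(pred_uncond.float(), noise.float(), reduction="none") * -1.0
        # per-sample reduction, then a scalar so backward() is valid
        return (this_loss + that_loss).mean([1, 2, 3]).mean()

The repulsive term is unbounded from below on its own, which is why both branches are computed on the same network pass and summed before reducing, as the diff does.
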
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 7177f519..5bd72254 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -1273,6 +1273,7 @@ class StableDiffusion:
 
         images = torch.stack(image_list)
         latents = self.vae.encode(images).latent_dist.sample()
+        # latents = self.vae.encode(images, return_dict=False)[0]
         latents = latents * self.vae.config['scaling_factor']
         latents = latents.to(device, dtype=dtype)
 
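Note: on the VAE encoding, a minimal sketch of the equivalent diffusers calls; the model id and image shape are illustrative assumptions. With return_dict=False, the first tuple element is still the latent distribution, so a .sample() (or .mode()) call is needed before scaling, which the commented-out alternative in the diff omits:

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    images = torch.randn(1, 3, 512, 512)  # pixel values expected in [-1, 1]

    with torch.no_grad():
        latent_dist = vae.encode(images).latent_dist
        # equivalent: latent_dist = vae.encode(images, return_dict=False)[0]
        latents = latent_dist.sample() * vae.config.scaling_factor  # 0.18215 for SD1.x VAEs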