From adc31ec77d92dd2191441172fe73a9a728da834b Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 3 Jun 2025 20:08:35 -0600
Subject: [PATCH] Small updates and bug fixes for various things

---
 extensions_built_in/sd_trainer/SDTrainer.py |  2 ++
 jobs/process/TrainVAEProcess.py             | 31 ++++++++++++++++++---
 toolkit/lora_special.py                     | 13 +++++++--
 toolkit/losses.py                           | 16 +++++++++++
 4 files changed, 56 insertions(+), 6 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 7468e0fc..ba44ff8d 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -764,6 +764,7 @@ class SDTrainer(BaseSDTrainProcess):
             conditional_embeds: Union[PromptEmbeds, None] = None,
             unconditional_embeds: Union[PromptEmbeds, None] = None,
             batch: Optional['DataLoaderBatchDTO'] = None,
+            is_primary_pred: bool = False,
             **kwargs,
     ):
         dtype = get_torch_dtype(self.train_config.dtype)
@@ -1553,6 +1554,7 @@ class SDTrainer(BaseSDTrainProcess):
                 conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
                 unconditional_embeds=unconditional_embeds,
                 batch=batch,
+                is_primary_pred=True,
                 **pred_kwargs
             )
             self.after_unet_predict()
diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index 0b6d5fab..4530b7cd 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -18,7 +18,7 @@ from jobs.process import BaseTrainProcess
 from toolkit.image_utils import show_tensors
 from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
 from toolkit.data_loader import ImageDataset
-from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation
+from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation, total_variation_deltas
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
@@ -283,10 +283,33 @@ class TrainVAEProcess(BaseTrainProcess):
         else:
             return torch.tensor(0.0, device=self.device)
 
-    def get_ltv_loss(self, latent):
+    def get_ltv_loss(self, latent, images):
         # loss to reduce the latent space variance
         if self.ltv_weight > 0:
-            return total_variation(latent).mean()
+            with torch.no_grad():
+                images = images.to(latent.device, dtype=latent.dtype)
+                # resize down to latent size
+                images = torch.nn.functional.interpolate(images, size=(latent.shape[2], latent.shape[3]), mode='bilinear', align_corners=False)
+
+                # mean the color channel and then expand to latent size
+                images = images.mean(dim=1, keepdim=True)
+                images = images.repeat(1, latent.shape[1], 1, 1)
+
+                # normalize to a mean of 0 and std of 1
+                images_mean = images.mean(dim=(2, 3), keepdim=True)
+                images_std = images.std(dim=(2, 3), keepdim=True)
+                images = (images - images_mean) / (images_std + 1e-6)
+
+            # now we target the same std of the image for the latent space as to not reduce to 0
+
+            latent_tv = torch.abs(total_variation_deltas(latent))
+            images_tv = torch.abs(total_variation_deltas(images))
+            loss = torch.abs(latent_tv - images_tv)  # keep it spatially aware
+            loss = loss.mean(dim=2, keepdim=True)
+            loss = loss.mean(dim=3, keepdim=True)  # mean over height and width
+            loss = loss.mean(dim=1, keepdim=True)  # mean over channels
+            loss = loss.mean()
+            return loss
         else:
             return torch.tensor(0.0, device=self.device)
 
@@ -733,7 +756,7 @@
                     mv_loss = torch.tensor(0.0,
                                            device=self.device, dtype=self.torch_dtype)
 
                 if self.ltv_weight > 0:
-                    ltv_loss = self.get_ltv_loss(latents) * self.ltv_weight
+                    ltv_loss = self.get_ltv_loss(latents, batch) * self.ltv_weight
                 else:
                     ltv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index b0c2d7f4..cd443735 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -7,7 +7,7 @@ import re
 import sys
 from typing import List, Optional, Dict, Type, Union
 import torch
-from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel
+from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel, WanTransformer3DModel
 from transformers import CLIPTextModel
 
 from toolkit.models.lokr import LokrModule
@@ -522,6 +522,14 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
 
             transformer.pos_embed = self.transformer_pos_embed
             transformer.proj_out = self.transformer_proj_out
+
+        elif base_model is not None and base_model.arch == "wan21":
+            transformer: WanTransformer3DModel = unet
+            self.transformer_pos_embed = copy.deepcopy(transformer.patch_embedding)
+            self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
+
+            transformer.patch_embedding = self.transformer_pos_embed
+            transformer.proj_out = self.transformer_proj_out
         else:
             unet: UNet2DConditionModel = unet
 
@@ -539,7 +547,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
 
         if self.full_train_in_out:
-            if self.is_pixart or self.is_auraflow or self.is_flux:
+            base_model = self.base_model_ref() if self.base_model_ref is not None else None
+            if self.is_pixart or self.is_auraflow or self.is_flux or (base_model is not None and base_model.arch == "wan21"):
                 all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
                 all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
             else:
diff --git a/toolkit/losses.py b/toolkit/losses.py
index eeea3571..fef9310d 100644
--- a/toolkit/losses.py
+++ b/toolkit/losses.py
@@ -13,6 +13,22 @@ def total_variation(image):
     n_elements = image.shape[1] * image.shape[2] * image.shape[3]
     return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
              torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
 
+
+def total_variation_deltas(image):
+    """
+    Compute per-pixel total variation deltas.
+    Input:
+    - image: Tensor of shape (N, C, H, W)
+    Returns:
+    - Tensor with shape (N, C, H, W), padded to match input shape
+    """
+    dh = torch.zeros_like(image)
+    dv = torch.zeros_like(image)
+
+    dh[:, :, :, :-1] = torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1])
+    dv[:, :, :-1, :] = torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :])
+
+    return dh + dv
 
 class ComparativeTotalVariation(torch.nn.Module):
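
A quick sketch, not part of the patch itself, showing how the new per-pixel total_variation_deltas helper relates to the existing scalar total_variation reduction in toolkit/losses.py: summing the per-pixel deltas and dividing by C*H*W reproduces the scalar value. The tensor x and its shape are illustrative; it assumes toolkit.losses is importable from the repo root.

    # sketch: total_variation_deltas vs. total_variation
    import torch
    from toolkit.losses import total_variation, total_variation_deltas

    x = torch.randn(2, 4, 32, 32)  # (N, C, H, W), e.g. a batch of latents

    deltas = total_variation_deltas(x)  # per-pixel |dh| + |dv|, same shape as x
    n_elements = x.shape[1] * x.shape[2] * x.shape[3]

    # summing the per-pixel deltas and normalizing matches the scalar reduction
    assert torch.allclose(deltas.sum() / n_elements, total_variation(x))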