Small updates and bug fixes for various things

Jaret Burkett
2025-06-03 20:08:35 -06:00
parent b6d25fcd10
commit adc31ec77d
4 changed files with 56 additions and 6 deletions

View File

@@ -764,6 +764,7 @@ class SDTrainer(BaseSDTrainProcess):
         conditional_embeds: Union[PromptEmbeds, None] = None,
         unconditional_embeds: Union[PromptEmbeds, None] = None,
         batch: Optional['DataLoaderBatchDTO'] = None,
+        is_primary_pred: bool = False,
         **kwargs,
     ):
         dtype = get_torch_dtype(self.train_config.dtype)
@@ -1553,6 +1554,7 @@ class SDTrainer(BaseSDTrainProcess):
                conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
                unconditional_embeds=unconditional_embeds,
                batch=batch,
+               is_primary_pred=True,
                **pred_kwargs
            )
            self.after_unet_predict()
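
Note: the new is_primary_pred flag defaults to False and is only set to True at the main prediction call above, presumably so the noise-prediction path can tell the primary training prediction apart from auxiliary passes. A minimal, hypothetical sketch of the pattern (not the trainer's actual code):

    import torch

    def predict_noise(noisy_latents: torch.Tensor, is_primary_pred: bool = False) -> torch.Tensor:
        # stand-in forward pass; the real trainer calls the diffusion model here
        pred = noisy_latents * 0.1
        if is_primary_pred:
            # behavior that should only run for the main prediction of a step
            pass
        return pred

    primary = predict_noise(torch.randn(1, 4, 64, 64), is_primary_pred=True)
    auxiliary = predict_noise(torch.randn(1, 4, 64, 64))  # defaults to False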

View File

@@ -18,7 +18,7 @@ from jobs.process import BaseTrainProcess
 from toolkit.image_utils import show_tensors
 from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
 from toolkit.data_loader import ImageDataset
-from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation
+from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation, total_variation_deltas
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
@@ -283,10 +283,33 @@ class TrainVAEProcess(BaseTrainProcess):
         else:
             return torch.tensor(0.0, device=self.device)
 
-    def get_ltv_loss(self, latent):
+    def get_ltv_loss(self, latent, images):
         # loss to reduce the latent space variance
         if self.ltv_weight > 0:
-            return total_variation(latent).mean()
+            with torch.no_grad():
+                images = images.to(latent.device, dtype=latent.dtype)
+                # resize down to latent size
+                images = torch.nn.functional.interpolate(images, size=(latent.shape[2], latent.shape[3]), mode='bilinear', align_corners=False)
+                # mean the color channel and then expand to latent size
+                images = images.mean(dim=1, keepdim=True)
+                images = images.repeat(1, latent.shape[1], 1, 1)
+                # normalize to a mean of 0 and std of 1
+                images_mean = images.mean(dim=(2, 3), keepdim=True)
+                images_std = images.std(dim=(2, 3), keepdim=True)
+                images = (images - images_mean) / (images_std + 1e-6)
+            # now we target the same std of the image for the latent space as to not reduce to 0
+            latent_tv = torch.abs(total_variation_deltas(latent))
+            images_tv = torch.abs(total_variation_deltas(images))
+            loss = torch.abs(latent_tv - images_tv)  # keep it spatially aware
+            loss = loss.mean(dim=2, keepdim=True)
+            loss = loss.mean(dim=3, keepdim=True)  # mean over height and width
+            loss = loss.mean(dim=1, keepdim=True)  # mean over channels
+            loss = loss.mean()
+            return loss
         else:
             return torch.tensor(0.0, device=self.device)
@@ -733,7 +756,7 @@ class TrainVAEProcess(BaseTrainProcess):
         mv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
         if self.ltv_weight > 0:
-            ltv_loss = self.get_ltv_loss(latents) * self.ltv_weight
+            ltv_loss = self.get_ltv_loss(latents, batch) * self.ltv_weight
         else:
             ltv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
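
Taken together, these hunks change the LTV term from plainly minimizing the latent's total variation to matching it against the per-pixel TV of the downsampled, grayscale, normalized input images. A self-contained sketch of that loss, mirroring the diff with total_variation_deltas inlined and illustrative tensor shapes:

    import torch
    import torch.nn.functional as F

    def total_variation_deltas(x: torch.Tensor) -> torch.Tensor:
        # per-pixel horizontal + vertical absolute differences, zero-padded to input shape
        dh = torch.zeros_like(x)
        dv = torch.zeros_like(x)
        dh[:, :, :, :-1] = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs()
        dv[:, :, :-1, :] = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs()
        return dh + dv

    def ltv_loss(latent: torch.Tensor, images: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            images = images.to(latent.device, dtype=latent.dtype)
            # resize images to the latent's spatial size, grayscale, expand to latent channels
            images = F.interpolate(images, size=latent.shape[2:], mode='bilinear', align_corners=False)
            images = images.mean(dim=1, keepdim=True).repeat(1, latent.shape[1], 1, 1)
            # per-sample normalization to mean 0 / std 1 so only the TV structure matters
            mean = images.mean(dim=(2, 3), keepdim=True)
            std = images.std(dim=(2, 3), keepdim=True)
            images = (images - mean) / (std + 1e-6)
        # L1 distance between the latent and image TV maps, reduced to a scalar
        return (total_variation_deltas(latent) - total_variation_deltas(images)).abs().mean()

    latent = torch.randn(2, 4, 32, 32, requires_grad=True)
    images = torch.rand(2, 3, 256, 256)
    loss = ltv_loss(latent, images)
    loss.backward()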

View File

@@ -7,7 +7,7 @@ import re
 import sys
 from typing import List, Optional, Dict, Type, Union
 import torch
-from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel
+from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel, WanTransformer3DModel
 from transformers import CLIPTextModel
 from toolkit.models.lokr import LokrModule
@@ -523,6 +523,14 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
            transformer.pos_embed = self.transformer_pos_embed
            transformer.proj_out = self.transformer_proj_out
+        elif base_model is not None and base_model.arch == "wan21":
+            transformer: WanTransformer3DModel = unet
+            self.transformer_pos_embed = copy.deepcopy(transformer.patch_embedding)
+            self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
+            transformer.patch_embedding = self.transformer_pos_embed
+            transformer.proj_out = self.transformer_proj_out
        else:
            unet: UNet2DConditionModel = unet
            unet_conv_in: torch.nn.Conv2d = unet.conv_in
@@ -539,7 +547,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
        all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
        if self.full_train_in_out:
-            if self.is_pixart or self.is_auraflow or self.is_flux:
+            base_model = self.base_model_ref() if self.base_model_ref is not None else None
+            if self.is_pixart or self.is_auraflow or self.is_flux or (base_model is not None and base_model.arch == "wan21"):
                all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
                all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
            else:
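
The new wan21 branch follows the same full_train_in_out pattern as the PixArt/AuraFlow/Flux branches above: deep-copy the transformer's input/output layers, swap the copies back into the model, and later give them their own optimizer parameter groups. A hedged, self-contained sketch with toy module shapes (not the real WanTransformer3DModel layout):

    import copy
    import torch

    class ToyTransformer(torch.nn.Module):
        # stand-in for the transformer; only the swapped layers matter here
        def __init__(self):
            super().__init__()
            self.patch_embedding = torch.nn.Conv3d(16, 64, kernel_size=2, stride=2)
            self.proj_out = torch.nn.Linear(64, 16)

    transformer = ToyTransformer()

    # deep-copy the in/out layers, swap the copies into the model,
    # and keep references so they can be fully trained alongside the LoRA
    transformer_pos_embed = copy.deepcopy(transformer.patch_embedding)
    transformer_proj_out = copy.deepcopy(transformer.proj_out)
    transformer.patch_embedding = transformer_pos_embed
    transformer.proj_out = transformer_proj_out

    # later, register those layers as their own optimizer parameter groups
    unet_lr = 1e-4
    all_params = [
        {"lr": unet_lr, "params": list(transformer_pos_embed.parameters())},
        {"lr": unet_lr, "params": list(transformer_proj_out.parameters())},
    ]
    optimizer = torch.optim.AdamW(all_params)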

View File

@@ -14,6 +14,22 @@ def total_variation(image):
     return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
             torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
 
+def total_variation_deltas(image):
+    """
+    Compute per-pixel total variation deltas.
+    Input:
+    - image: Tensor of shape (N, C, H, W)
+    Returns:
+    - Tensor with shape (N, C, H, W), padded to match input shape
+    """
+    dh = torch.zeros_like(image)
+    dv = torch.zeros_like(image)
+    dh[:, :, :, :-1] = torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1])
+    dv[:, :, :-1, :] = torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :])
+    return dh + dv
+
 class ComparativeTotalVariation(torch.nn.Module):
     """