Small updates and bug fixes for various things

This commit is contained in:
Jaret Burkett
2025-06-03 20:08:35 -06:00
parent b6d25fcd10
commit adc31ec77d
4 changed files with 56 additions and 6 deletions

View File

@@ -7,7 +7,7 @@ import re
import sys
from typing import List, Optional, Dict, Type, Union
import torch
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel, WanTransformer3DModel
from transformers import CLIPTextModel
from toolkit.models.lokr import LokrModule
@@ -522,6 +522,14 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
transformer.pos_embed = self.transformer_pos_embed
transformer.proj_out = self.transformer_proj_out
elif base_model is not None and base_model.arch == "wan21":
transformer: WanTransformer3DModel = unet
self.transformer_pos_embed = copy.deepcopy(transformer.patch_embedding)
self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
transformer.patch_embedding = self.transformer_pos_embed
transformer.proj_out = self.transformer_proj_out
else:
unet: UNet2DConditionModel = unet
@@ -539,7 +547,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
if self.full_train_in_out:
if self.is_pixart or self.is_auraflow or self.is_flux:
base_model = self.base_model_ref() if self.base_model_ref is not None else None
if self.is_pixart or self.is_auraflow or self.is_flux or (base_model is not None and base_model.arch == "wan21"):
all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
else:

View File

@@ -13,6 +13,22 @@ def total_variation(image):
n_elements = image.shape[1] * image.shape[2] * image.shape[3]
return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
def total_variation_deltas(image):
    """
    Build a per-pixel map of absolute neighbour differences.

    Each position holds the sum of two terms: the absolute difference to
    the next element along the last axis (horizontal) and the absolute
    difference to the next element along the third axis (vertical).

    Input:
    - image: Tensor of shape (N, C, H, W)

    Returns:
    - Tensor with the same shape as `image`. The final column has no
      horizontal neighbour and the final row has no vertical neighbour,
      so those terms stay at zero there.
    """
    # Zero-filled buffers keep the output aligned with the input shape.
    horizontal = torch.zeros_like(image)
    vertical = torch.zeros_like(image)

    # Forward differences; each is one element shorter along its own axis.
    right_minus_left = image[:, :, :, 1:] - image[:, :, :, :-1]
    below_minus_above = image[:, :, 1:, :] - image[:, :, :-1, :]

    # Write the magnitudes into everything except the trailing column/row.
    horizontal[:, :, :, :-1] = right_minus_left.abs()
    vertical[:, :, :-1, :] = below_minus_above.abs()

    return torch.add(horizontal, vertical)
class ComparativeTotalVariation(torch.nn.Module):