mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Small updates and bug fixes for various things
This commit is contained in:
@@ -7,7 +7,7 @@ import re
|
||||
import sys
|
||||
from typing import List, Optional, Dict, Type, Union
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel
|
||||
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel, WanTransformer3DModel
|
||||
from transformers import CLIPTextModel
|
||||
from toolkit.models.lokr import LokrModule
|
||||
|
||||
@@ -522,6 +522,14 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
|
||||
transformer.pos_embed = self.transformer_pos_embed
|
||||
transformer.proj_out = self.transformer_proj_out
|
||||
|
||||
elif base_model is not None and base_model.arch == "wan21":
|
||||
transformer: WanTransformer3DModel = unet
|
||||
self.transformer_pos_embed = copy.deepcopy(transformer.patch_embedding)
|
||||
self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
|
||||
|
||||
transformer.patch_embedding = self.transformer_pos_embed
|
||||
transformer.proj_out = self.transformer_proj_out
|
||||
|
||||
else:
|
||||
unet: UNet2DConditionModel = unet
|
||||
@@ -539,7 +547,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
|
||||
|
||||
if self.full_train_in_out:
|
||||
if self.is_pixart or self.is_auraflow or self.is_flux:
|
||||
base_model = self.base_model_ref() if self.base_model_ref is not None else None
|
||||
if self.is_pixart or self.is_auraflow or self.is_flux or (base_model is not None and base_model.arch == "wan21"):
|
||||
all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
|
||||
all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
|
||||
else:
|
||||
|
||||
@@ -13,6 +13,22 @@ def total_variation(image):
|
||||
n_elements = image.shape[1] * image.shape[2] * image.shape[3]
|
||||
return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
|
||||
torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
|
||||
|
||||
def total_variation_deltas(image):
    """
    Compute per-pixel total variation deltas.

    Input:
    - image: Tensor of shape (N, C, H, W)

    Returns:
    - Tensor with shape (N, C, H, W), padded to match input shape
    """
    # Zero-filled buffers so the last column / last row stay 0,
    # keeping the output the same shape as the input.
    horizontal = torch.zeros_like(image)
    vertical = torch.zeros_like(image)

    # Absolute forward differences along width and height.
    horizontal[..., :-1] = (image[..., 1:] - image[..., :-1]).abs()
    vertical[..., :-1, :] = (image[..., 1:, :] - image[..., :-1, :]).abs()

    return horizontal + vertical
|
||||
|
||||
|
||||
class ComparativeTotalVariation(torch.nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user