mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 08:29:45 +00:00
Small updates and bug fixes for various things
@@ -764,6 +764,7 @@ class SDTrainer(BaseSDTrainProcess):
         conditional_embeds: Union[PromptEmbeds, None] = None,
         unconditional_embeds: Union[PromptEmbeds, None] = None,
         batch: Optional['DataLoaderBatchDTO'] = None,
+        is_primary_pred: bool = False,
         **kwargs,
     ):
         dtype = get_torch_dtype(self.train_config.dtype)
@@ -1553,6 +1554,7 @@ class SDTrainer(BaseSDTrainProcess):
                 conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
                 unconditional_embeds=unconditional_embeds,
                 batch=batch,
+                is_primary_pred=True,
                 **pred_kwargs
             )
             self.after_unet_predict()
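The two hunks above thread a new `is_primary_pred` flag through SDTrainer's prediction path: the signature defaults it to `False`, so existing call sites keep their old behavior, and only the primary prediction call opts in with `True`. A minimal sketch of the default-False flag pattern, using hypothetical names (`predict`, `train_step`) since the surrounding method bodies are not part of this diff:

# Hypothetical illustration only; `predict` and `train_step` are
# stand-in names, not ai-toolkit APIs.
def predict(batch, is_primary_pred: bool = False, **kwargs):
    if is_primary_pred:
        # work reserved for the main prediction only
        pass
    return batch

def train_step(batch):
    predict(batch, is_primary_pred=True)  # the primary prediction opts in
    predict(batch)                        # auxiliary calls keep the default
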
@@ -18,7 +18,7 @@ from jobs.process import BaseTrainProcess
 from toolkit.image_utils import show_tensors
 from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
 from toolkit.data_loader import ImageDataset
-from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation
+from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation, total_variation_deltas
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
@@ -283,10 +283,33 @@ class TrainVAEProcess(BaseTrainProcess):
         else:
             return torch.tensor(0.0, device=self.device)
 
-    def get_ltv_loss(self, latent):
+    def get_ltv_loss(self, latent, images):
         # loss to reduce the latent space variance
         if self.ltv_weight > 0:
-            return total_variation(latent).mean()
+            with torch.no_grad():
+                images = images.to(latent.device, dtype=latent.dtype)
+                # resize down to latent size
+                images = torch.nn.functional.interpolate(images, size=(latent.shape[2], latent.shape[3]), mode='bilinear', align_corners=False)
+
+                # mean the color channels and then expand to the latent channel count
+                images = images.mean(dim=1, keepdim=True)
+                images = images.repeat(1, latent.shape[1], 1, 1)
+
+                # normalize to a mean of 0 and std of 1
+                images_mean = images.mean(dim=(2, 3), keepdim=True)
+                images_std = images.std(dim=(2, 3), keepdim=True)
+                images = (images - images_mean) / (images_std + 1e-6)
+
+            # now we target the image's variation for the latent space so as not to reduce it to 0
+
+            latent_tv = torch.abs(total_variation_deltas(latent))
+            images_tv = torch.abs(total_variation_deltas(images))
+            loss = torch.abs(latent_tv - images_tv)  # keep it spatially aware
+            loss = loss.mean(dim=2, keepdim=True)
+            loss = loss.mean(dim=3, keepdim=True)  # mean over height and width
+            loss = loss.mean(dim=1, keepdim=True)  # mean over channels
+            loss = loss.mean()
+            return loss
         else:
             return torch.tensor(0.0, device=self.device)
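The rewritten `get_ltv_loss` no longer drives the latent's total variation toward zero; instead it matches the latent's per-pixel variation to that of the input image, downsampled to latent resolution, collapsed to grayscale, and normalized to zero mean and unit std, so the latent keeps image-like spatial structure rather than flattening out. A self-contained sketch of the same computation, assuming only PyTorch; `total_variation_deltas` is the helper added to toolkit.losses in this commit, reproduced here so the snippet runs standalone:

import torch
import torch.nn.functional as F

def total_variation_deltas(x):
    # per-pixel horizontal + vertical absolute differences, zero-padded
    # at the right/bottom edges so the output matches the input shape
    dh = torch.zeros_like(x)
    dv = torch.zeros_like(x)
    dh[:, :, :, :-1] = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
    dv[:, :, :-1, :] = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
    return dh + dv

def ltv_loss(latent, images):
    # build a grayscale, latent-resolution, normalized target from the image
    with torch.no_grad():
        images = images.to(latent.device, dtype=latent.dtype)
        images = F.interpolate(images, size=latent.shape[2:], mode='bilinear', align_corners=False)
        images = images.mean(dim=1, keepdim=True).repeat(1, latent.shape[1], 1, 1)
        mean = images.mean(dim=(2, 3), keepdim=True)
        std = images.std(dim=(2, 3), keepdim=True)
        images = (images - mean) / (std + 1e-6)
    # penalize the spatial mismatch between latent and image variation
    return torch.abs(total_variation_deltas(latent) - total_variation_deltas(images)).mean()

# quick smoke test with random tensors
latent = torch.randn(2, 4, 32, 32, requires_grad=True)
images = torch.rand(2, 3, 256, 256)
print(ltv_loss(latent, images))  # scalar loss, differentiable w.r.t. latent
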
@@ -733,7 +756,7 @@ class TrainVAEProcess(BaseTrainProcess):
             mv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
 
         if self.ltv_weight > 0:
-            ltv_loss = self.get_ltv_loss(latents) * self.ltv_weight
+            ltv_loss = self.get_ltv_loss(latents, batch) * self.ltv_weight
         else:
             ltv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
@@ -7,7 +7,7 @@ import re
 import sys
 from typing import List, Optional, Dict, Type, Union
 import torch
-from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel
+from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel, WanTransformer3DModel
 from transformers import CLIPTextModel
 from toolkit.models.lokr import LokrModule
@@ -522,6 +522,14 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
 
             transformer.pos_embed = self.transformer_pos_embed
             transformer.proj_out = self.transformer_proj_out
 
+        elif base_model is not None and base_model.arch == "wan21":
+            transformer: WanTransformer3DModel = unet
+            self.transformer_pos_embed = copy.deepcopy(transformer.patch_embedding)
+            self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
+
+            transformer.patch_embedding = self.transformer_pos_embed
+            transformer.proj_out = self.transformer_proj_out
+
         else:
             unet: UNet2DConditionModel = unet
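The wan21 branch above mirrors the PixArt path a few lines earlier: the transformer's input patch embedding and output projection are deep-copied onto the network and then reattached, so the network registers (and can save and train) the exact module instances the transformer runs its forward pass through. A minimal sketch of that copy-and-reattach pattern, with `Holder` and `ToyTransformer` as illustrative stand-ins for LoRASpecialNetwork and WanTransformer3DModel:

import copy
import torch.nn as nn

class ToyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embedding = nn.Conv2d(3, 8, kernel_size=2, stride=2)
        self.proj_out = nn.Linear(8, 3)

class Holder(nn.Module):
    """Stand-in for the network object that adopts the in/out modules."""
    def __init__(self, tf):
        super().__init__()
        # deep-copy so the copies are registered as this module's parameters
        self.transformer_pos_embed = copy.deepcopy(tf.patch_embedding)
        self.transformer_proj_out = copy.deepcopy(tf.proj_out)
        # reattach so the transformer's forward uses the holder's copies
        tf.patch_embedding = self.transformer_pos_embed
        tf.proj_out = self.transformer_proj_out

tf = ToyTransformer()
holder = Holder(tf)
assert tf.patch_embedding is holder.transformer_pos_embed  # shared instance
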
@@ -539,7 +547,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
 
         if self.full_train_in_out:
-            if self.is_pixart or self.is_auraflow or self.is_flux:
+            base_model = self.base_model_ref() if self.base_model_ref is not None else None
+            if self.is_pixart or self.is_auraflow or self.is_flux or (base_model is not None and base_model.arch == "wan21"):
                 all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
                 all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
             else:
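With `full_train_in_out` set, those adopted modules also need optimizer entries of their own; the hunk above extends the existing PixArt/AuraFlow/Flux condition to wan21 by first resolving the base model from its weakref. A short runnable sketch of the resulting parameter groups, with `net` and `unet_lr` as assumed stand-ins for the trainer's state:

import torch
import torch.nn as nn

# assumed stand-ins for the network's adopted in/out modules
net = nn.Module()
net.transformer_pos_embed = nn.Conv2d(3, 8, kernel_size=2, stride=2)
net.transformer_proj_out = nn.Linear(8, 3)
unet_lr = 1e-4

# each adopted module gets its own param group at the unet learning rate
all_params = [
    {"lr": unet_lr, "params": list(net.transformer_pos_embed.parameters())},
    {"lr": unet_lr, "params": list(net.transformer_proj_out.parameters())},
]
optimizer = torch.optim.AdamW(all_params)
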
@@ -13,6 +13,22 @@ def total_variation(image):
     n_elements = image.shape[1] * image.shape[2] * image.shape[3]
     return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
             torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
 
+
+def total_variation_deltas(image):
+    """
+    Compute per-pixel total variation deltas.
+    Input:
+    - image: Tensor of shape (N, C, H, W)
+    Returns:
+    - Tensor of shape (N, C, H, W), padded to match the input shape
+    """
+    dh = torch.zeros_like(image)
+    dv = torch.zeros_like(image)
+
+    dh[:, :, :, :-1] = torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1])
+    dv[:, :, :-1, :] = torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :])
+
+    return dh + dv
+
 
 class ComparativeTotalVariation(torch.nn.Module):
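Unlike `total_variation`, which reduces the whole batch to a single scalar, `total_variation_deltas` keeps the full `(N, C, H, W)` map, which is what lets the LTV loss above compare latent and image variation location by location. Summing the map back down recovers the scalar form; a quick consistency check, assuming both functions are imported from `toolkit.losses` as in the diff:

import torch
from toolkit.losses import total_variation, total_variation_deltas

x = torch.randn(2, 4, 16, 16)
deltas = total_variation_deltas(x)        # same shape as x: per-pixel |dh| + |dv|
n = x.shape[1] * x.shape[2] * x.shape[3]  # per-sample element count, as in total_variation
print(torch.allclose(deltas.sum() / n, total_variation(x)))  # True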