Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Small updates and bug fixes for various things
@@ -764,6 +764,7 @@ class SDTrainer(BaseSDTrainProcess):
         conditional_embeds: Union[PromptEmbeds, None] = None,
         unconditional_embeds: Union[PromptEmbeds, None] = None,
         batch: Optional['DataLoaderBatchDTO'] = None,
+        is_primary_pred: bool = False,
         **kwargs,
     ):
         dtype = get_torch_dtype(self.train_config.dtype)
@@ -1553,6 +1554,7 @@ class SDTrainer(BaseSDTrainProcess):
                 conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
                 unconditional_embeds=unconditional_embeds,
                 batch=batch,
+                is_primary_pred=True,
                 **pred_kwargs
             )
             self.after_unet_predict()
@@ -18,7 +18,7 @@ from jobs.process import BaseTrainProcess
 from toolkit.image_utils import show_tensors
 from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
 from toolkit.data_loader import ImageDataset
-from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation
+from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation, total_variation_deltas
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
@@ -283,10 +283,33 @@ class TrainVAEProcess(BaseTrainProcess):
         else:
             return torch.tensor(0.0, device=self.device)
 
-    def get_ltv_loss(self, latent):
+    def get_ltv_loss(self, latent, images):
         # loss to reduce the latent space variance
         if self.ltv_weight > 0:
-            return total_variation(latent).mean()
+            with torch.no_grad():
+                images = images.to(latent.device, dtype=latent.dtype)
+                # resize down to latent size
+                images = torch.nn.functional.interpolate(images, size=(latent.shape[2], latent.shape[3]), mode='bilinear', align_corners=False)
+
+                # mean the color channel and then expand to latent size
+                images = images.mean(dim=1, keepdim=True)
+                images = images.repeat(1, latent.shape[1], 1, 1)
+
+                # normalize to a mean of 0 and std of 1
+                images_mean = images.mean(dim=(2, 3), keepdim=True)
+                images_std = images.std(dim=(2, 3), keepdim=True)
+                images = (images - images_mean) / (images_std + 1e-6)
+
+            # now we target the same std of the image for the latent space as to not reduce to 0
+
+            latent_tv = torch.abs(total_variation_deltas(latent))
+            images_tv = torch.abs(total_variation_deltas(images))
+            loss = torch.abs(latent_tv - images_tv)  # keep it spatially aware
+            loss = loss.mean(dim=2, keepdim=True)
+            loss = loss.mean(dim=3, keepdim=True)  # mean over height and width
+            loss = loss.mean(dim=1, keepdim=True)  # mean over channels
+            loss = loss.mean()
+            return loss
         else:
             return torch.tensor(0.0, device=self.device)
 
@@ -733,7 +756,7 @@ class TrainVAEProcess(BaseTrainProcess):
                 mv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
 
                 if self.ltv_weight > 0:
-                    ltv_loss = self.get_ltv_loss(latents) * self.ltv_weight
+                    ltv_loss = self.get_ltv_loss(latents, batch) * self.ltv_weight
                 else:
                     ltv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
 
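Note (not part of the commit): the reworked get_ltv_loss no longer just shrinks the latent's total variation toward zero; it matches the latent's per-pixel TV deltas against those of the source images after they are resized to the latent resolution, collapsed to grayscale, broadcast across the latent channels, and normalized. A minimal standalone sketch of the same idea follows; tv_deltas, ltv_loss_sketch, and the example shapes are illustrative assumptions, not the toolkit's API.

import torch
import torch.nn.functional as F

def tv_deltas(x):
    # per-pixel horizontal + vertical absolute differences, zero-padded to x's shape
    d = torch.zeros_like(x)
    d[:, :, :, :-1] += (x[:, :, :, 1:] - x[:, :, :, :-1]).abs()
    d[:, :, :-1, :] += (x[:, :, 1:, :] - x[:, :, :-1, :]).abs()
    return d

def ltv_loss_sketch(latent, images):
    # latent: (N, C, h, w); images: (N, 3, H, W) pixel batch
    with torch.no_grad():
        target = F.interpolate(images.to(latent.device, latent.dtype),
                               size=latent.shape[2:], mode='bilinear', align_corners=False)
        target = target.mean(dim=1, keepdim=True).repeat(1, latent.shape[1], 1, 1)
        target = (target - target.mean(dim=(2, 3), keepdim=True)) / (target.std(dim=(2, 3), keepdim=True) + 1e-6)
    # penalize the gap between latent detail and image detail, not the detail itself
    return (tv_deltas(latent) - tv_deltas(target)).abs().mean()

latent = torch.randn(2, 4, 32, 32, requires_grad=True)
images = torch.rand(2, 3, 256, 256)
print(ltv_loss_sketch(latent, images))  # scalar loss, differentiable w.r.t. latent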
@@ -7,7 +7,7 @@ import re
 import sys
 from typing import List, Optional, Dict, Type, Union
 import torch
-from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel
+from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel, WanTransformer3DModel
 from transformers import CLIPTextModel
 from toolkit.models.lokr import LokrModule
 
@@ -523,6 +523,14 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                 transformer.pos_embed = self.transformer_pos_embed
                 transformer.proj_out = self.transformer_proj_out
 
+            elif base_model is not None and base_model.arch == "wan21":
+                transformer: WanTransformer3DModel = unet
+                self.transformer_pos_embed = copy.deepcopy(transformer.patch_embedding)
+                self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
+
+                transformer.patch_embedding = self.transformer_pos_embed
+                transformer.proj_out = self.transformer_proj_out
+
             else:
                 unet: UNet2DConditionModel = unet
                 unet_conv_in: torch.nn.Conv2d = unet.conv_in
@@ -539,7 +547,8 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
 
         if self.full_train_in_out:
-            if self.is_pixart or self.is_auraflow or self.is_flux:
+            base_model = self.base_model_ref() if self.base_model_ref is not None else None
+            if self.is_pixart or self.is_auraflow or self.is_flux or (base_model is not None and base_model.arch == "wan21"):
                 all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
                 all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
             else:
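Note (not part of the commit): for the wan21 arch, full_train_in_out follows the same pattern already used for PixArt/AuraFlow/Flux, with patch_embedding standing in for pos_embed: the in/out layers are deep-copied, swapped back into the transformer, and added to the optimizer as full-rank parameter groups at the unet learning rate. A minimal sketch of that pattern is below; TinyTransformer and its layer shapes are hypothetical stand-ins, not the Wan 2.1 definitions.

import copy
import torch

class TinyTransformer(torch.nn.Module):
    # hypothetical stand-in exposing the two layers the network wants to fully train
    def __init__(self):
        super().__init__()
        self.patch_embedding = torch.nn.Conv3d(16, 64, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.proj_out = torch.nn.Linear(64, 64)

transformer = TinyTransformer()

# deep-copy the in/out layers and swap the copies into the model so the forward pass uses them
pos_embed = copy.deepcopy(transformer.patch_embedding)
proj_out = copy.deepcopy(transformer.proj_out)
transformer.patch_embedding = pos_embed
transformer.proj_out = proj_out

# hand them to the optimizer as full-rank groups at the unet learning rate,
# alongside whatever LoRA parameter groups already exist
unet_lr = 1e-4
param_groups = [
    {"lr": unet_lr, "params": list(pos_embed.parameters())},
    {"lr": unet_lr, "params": list(proj_out.parameters())},
]
optimizer = torch.optim.AdamW(param_groups)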
@@ -14,6 +14,22 @@ def total_variation(image):
     return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
              torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
 
+
+def total_variation_deltas(image):
+    """
+    Compute per-pixel total variation deltas.
+    Input:
+    - image: Tensor of shape (N, C, H, W)
+    Returns:
+    - Tensor with shape (N, C, H, W), padded to match input shape
+    """
+    dh = torch.zeros_like(image)
+    dv = torch.zeros_like(image)
+
+    dh[:, :, :, :-1] = torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1])
+    dv[:, :, :-1, :] = torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :])
+
+    return dh + dv
+
 
 class ComparativeTotalVariation(torch.nn.Module):
     """
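Note (not part of the commit): total_variation_deltas keeps the same horizontal and vertical absolute differences that total_variation sums, but per pixel, which is what lets the reworked get_ltv_loss above compare detail maps spatially. A quick self-contained check of that relationship; the function body mirrors the code added in this commit and the test tensor shape is arbitrary.

import torch

def total_variation_deltas(image):
    # per-pixel TV deltas, zero-padded back to the (N, C, H, W) input shape
    dh = torch.zeros_like(image)
    dv = torch.zeros_like(image)
    dh[:, :, :, :-1] = torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1])
    dv[:, :, :-1, :] = torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :])
    return dh + dv

image = torch.rand(2, 4, 8, 8)

# summing the per-pixel deltas recovers the numerator that total_variation() divides by n_elements
numerator = (torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
             torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :])))
assert torch.allclose(total_variation_deltas(image).sum(), numerator)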