mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Small updates and bug fixes for various things
This commit is contained in:
@@ -18,7 +18,7 @@ from jobs.process import BaseTrainProcess
|
||||
from toolkit.image_utils import show_tensors
|
||||
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
|
||||
from toolkit.data_loader import ImageDataset
|
||||
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation
|
||||
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation, total_variation_deltas
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.optimizer import get_optimizer
|
||||
from toolkit.style import get_style_model_and_losses
|
||||
@@ -283,10 +283,33 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
else:
|
||||
return torch.tensor(0.0, device=self.device)
|
||||
|
||||
def get_ltv_loss(self, latent):
|
||||
def get_ltv_loss(self, latent, images):
    """Return a latent total-variation loss measured against the input images.

    When ``self.ltv_weight`` is not greater than zero, a scalar zero tensor
    on ``self.device`` is returned. Otherwise ``images`` is turned into a
    reference map (moved to the latent's device/dtype, bilinearly resized to
    the latent's spatial size, averaged over dim 1, repeated to the latent's
    channel count and standardized over dims 2 and 3), and the loss is the
    mean absolute difference between the absolute total-variation deltas of
    the latent and of that reference.

    NOTE(review): indexing ``latent.shape[2]`` / ``latent.shape[3]`` implies
    a 4-D latent, presumably (batch, channels, height, width) - confirm.
    NOTE(review): ``total_variation_deltas`` is assumed to return a tensor
    shaped like its input - verify against ``toolkit.losses``.
    """
    # nothing to compute when the loss is switched off
    if not (self.ltv_weight > 0):
        return torch.tensor(0.0, device=self.device)

    # The reference map is only a target, so it is built without gradient
    # tracking.
    # NOTE(review): the extent of this no_grad block was reconstructed from a
    # flattened listing; it is assumed to cover the image preprocessing only,
    # so that gradients still reach the latent - confirm against the file.
    with torch.no_grad():
        reference = images.to(latent.device, dtype=latent.dtype)

        # shrink the images to the latent's spatial size
        latent_hw = (latent.shape[2], latent.shape[3])
        reference = torch.nn.functional.interpolate(
            reference,
            size=latent_hw,
            mode='bilinear',
            align_corners=False,
        )

        # collapse the channel dim, then copy it once per latent channel
        reference = reference.mean(dim=1, keepdim=True)
        reference = reference.repeat(1, latent.shape[1], 1, 1)

        # standardize each map over its spatial dims; the epsilon guards
        # against a zero std
        center = reference.mean(dim=(2, 3), keepdim=True)
        spread = reference.std(dim=(2, 3), keepdim=True)
        reference = (reference - center) / (spread + 1e-6)

    # Match the latent's variation to the reference's variation instead of
    # pushing the latent's variation towards zero.
    latent_tv = torch.abs(total_variation_deltas(latent))
    reference_tv = torch.abs(total_variation_deltas(reference))

    # element-wise difference keeps the comparison spatially aware
    loss = torch.abs(latent_tv - reference_tv)

    # reduce over height, width, then channels, and finally everything left
    for axis in (2, 3, 1):
        loss = loss.mean(dim=axis, keepdim=True)
    return loss.mean()
|
||||
|
||||
@@ -733,7 +756,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
mv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
if self.ltv_weight > 0:
|
||||
ltv_loss = self.get_ltv_loss(latents) * self.ltv_weight
|
||||
ltv_loss = self.get_ltv_loss(latents, batch) * self.ltv_weight
|
||||
else:
|
||||
ltv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user