Small updates and bug fixes for various things

Jaret Burkett
2025-06-03 20:08:35 -06:00
parent b6d25fcd10
commit adc31ec77d
4 changed files with 56 additions and 6 deletions


@@ -18,7 +18,7 @@ from jobs.process import BaseTrainProcess
 from toolkit.image_utils import show_tensors
 from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
 from toolkit.data_loader import ImageDataset
-from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation
+from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation, total_variation_deltas
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
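
The newly imported total_variation_deltas is not defined in this diff. Judging from how it is used below (wrapped in torch.abs and then reduced over channels, height, and width), it presumably returns per-position neighbour differences rather than a reduced scalar like total_variation. A minimal sketch of such a function, assuming that behaviour; the zero-padding choice is a guess, not the repository's implementation:

import torch
import torch.nn.functional as F

def total_variation_deltas(x: torch.Tensor) -> torch.Tensor:
    # x: (B, C, H, W). Horizontal and vertical neighbour differences,
    # zero-padded back to (B, C, H, W) so downstream reductions keep the shape.
    dh = F.pad(x[:, :, 1:, :] - x[:, :, :-1, :], (0, 0, 0, 1))  # pad bottom row
    dw = F.pad(x[:, :, :, 1:] - x[:, :, :, :-1], (0, 1, 0, 0))  # pad right column
    return dh + dw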
@@ -283,10 +283,33 @@ class TrainVAEProcess(BaseTrainProcess):
         else:
             return torch.tensor(0.0, device=self.device)

-    def get_ltv_loss(self, latent):
+    def get_ltv_loss(self, latent, images):
         # loss to reduce the latent space variance
         if self.ltv_weight > 0:
-            return total_variation(latent).mean()
+            with torch.no_grad():
+                images = images.to(latent.device, dtype=latent.dtype)
+                # resize down to latent size
+                images = torch.nn.functional.interpolate(images, size=(latent.shape[2], latent.shape[3]), mode='bilinear', align_corners=False)
+                # mean the color channels and then expand to the latent's channel count
+                images = images.mean(dim=1, keepdim=True)
+                images = images.repeat(1, latent.shape[1], 1, 1)
+                # normalize to a mean of 0 and std of 1
+                images_mean = images.mean(dim=(2, 3), keepdim=True)
+                images_std = images.std(dim=(2, 3), keepdim=True)
+                images = (images - images_mean) / (images_std + 1e-6)
+
+            # target the same total variation as the image so the latent's variation is not pushed toward 0
+            latent_tv = torch.abs(total_variation_deltas(latent))
+            images_tv = torch.abs(total_variation_deltas(images))
+            loss = torch.abs(latent_tv - images_tv)  # keep it spatially aware
+            loss = loss.mean(dim=2, keepdim=True)
+            loss = loss.mean(dim=3, keepdim=True)  # mean over height and width
+            loss = loss.mean(dim=1, keepdim=True)  # mean over channels
+            loss = loss.mean()
+            return loss
         else:
             return torch.tensor(0.0, device=self.device)
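
Note that the chain of keepdim means over dims 2, 3 and 1 followed by the final .mean() collapses to a single .mean() over the whole tensor; a quick check with dummy tensors (shapes are illustrative only):

import torch

latent_tv = torch.rand(2, 4, 32, 32)
images_tv = torch.rand(2, 4, 32, 32)
diff = torch.abs(latent_tv - images_tv)

# reduce step by step as in the method, then compare against a single grand mean
chained = diff.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=1, keepdim=True).mean()
single = diff.mean()
assert torch.allclose(chained, single)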
@@ -733,7 +756,7 @@ class TrainVAEProcess(BaseTrainProcess):
             mv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

             if self.ltv_weight > 0:
-                ltv_loss = self.get_ltv_loss(latents) * self.ltv_weight
+                ltv_loss = self.get_ltv_loss(latents, batch) * self.ltv_weight
             else:
                 ltv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
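
For illustration, a self-contained version of the new loss wired the same way as the changed call site above, using the total_variation_deltas sketch from earlier; ltv_match_loss is a hypothetical stand-in for self.get_ltv_loss, and the shapes and weight are arbitrary placeholders:

import torch
import torch.nn.functional as F

def ltv_match_loss(latent, images, eps=1e-6):
    # build a latent-sized, grey-scale, per-sample-normalized target from the images
    with torch.no_grad():
        images = images.to(latent.device, dtype=latent.dtype)
        images = F.interpolate(images, size=latent.shape[2:], mode='bilinear', align_corners=False)
        images = images.mean(dim=1, keepdim=True).repeat(1, latent.shape[1], 1, 1)
        images = (images - images.mean(dim=(2, 3), keepdim=True)) / (images.std(dim=(2, 3), keepdim=True) + eps)
    # match the latent's local variation to the image's local variation
    # total_variation_deltas is the assumed sketch shown after the import hunk
    return torch.abs(torch.abs(total_variation_deltas(latent)) - torch.abs(total_variation_deltas(images))).mean()

latents = torch.randn(2, 4, 32, 32)   # placeholder latent batch
batch = torch.rand(2, 3, 256, 256)    # placeholder image batch
ltv_weight = 0.1                      # placeholder weight
ltv_loss = ltv_match_loss(latents, batch) * ltv_weight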