Improvements to VAE trainer. Allow CLIP loss.

Jaret Burkett
2025-07-24 06:50:56 -06:00
parent ca5cf827a1
commit c5eb763342
2 changed files with 286 additions and 5 deletions
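For reference, a minimal config sketch for the new options. The key names mirror the get_conf() calls in the diff below; the values and the surrounding structure are illustrative assumptions, not part of this commit:

    # Hypothetical excerpt of a TrainVAEProcess config, shown as a Python dict.
    process_config = {
        'vae_type': 'AutoencoderTinyWithPooledExits',  # or 'AutoencoderKL' / 'AutoencoderTiny'
        'clip_weight': 0.5,   # > 0 loads a SigLIP vision tower and adds the CLIP embedding loss
        'lpips_weight': 0,    # default changed from 1e0 to 0 in this commit
        'mse_weight': 1.0,
        'train_encoder': False,
    }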

jobs/process/TrainVAEProcess.py

@@ -24,6 +24,7 @@ from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
 from diffusers import AutoencoderKL, AutoencoderTiny
+from toolkit.models.autoencoder_tiny_with_pooled_exits import AutoencoderTinyWithPooledExits
 from tqdm import tqdm
 import math
 import torchvision.utils
@@ -34,6 +35,7 @@ from torchvision.transforms import Resize
 import lpips
 import random
 import traceback
+from transformers import SiglipImageProcessor, SiglipVisionModel

 IMAGE_TRANSFORMS = transforms.Compose(
     [
@@ -80,13 +82,14 @@ class TrainVAEProcess(BaseTrainProcess):
         self.style_weight = self.get_conf('style_weight', 0, as_type=float)
         self.content_weight = self.get_conf('content_weight', 0, as_type=float)
         self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
+        self.clip_weight = self.get_conf('clip_weight', 0, as_type=float)
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
         self.mae_weight = self.get_conf('mae_weight', 0, as_type=float)
         self.mv_loss_weight = self.get_conf('mv_loss_weight', 0, as_type=float)
         self.tv_weight = self.get_conf('tv_weight', 0, as_type=float)
         self.ltv_weight = self.get_conf('ltv_weight', 0, as_type=float)
         self.lpm_weight = self.get_conf('lpm_weight', 0, as_type=float)  # latent pixel matching
-        self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float)
+        self.lpips_weight = self.get_conf('lpips_weight', 0, as_type=float)
         self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
         self.pattern_weight = self.get_conf('pattern_weight', 0, as_type=float)
         self.optimizer_params = self.get_conf('optimizer_params', {})
@@ -96,7 +99,13 @@ class TrainVAEProcess(BaseTrainProcess):
         self.random_scaling = self.get_conf('random_scaling', False, as_type=bool)
         self.vae_type = self.get_conf('vae_type', 'AutoencoderKL', as_type=str)  # AutoencoderKL or AutoencoderTiny
-        self.VaeClass = AutoencoderKL if self.vae_type == 'AutoencoderKL' else AutoencoderTiny
+        self.do_pooled_exits = False
+        self.VaeClass = AutoencoderKL
+        if self.vae_type == 'AutoencoderTiny':
+            self.VaeClass = AutoencoderTiny
+        if self.vae_type == 'AutoencoderTinyWithPooledExits':
+            self.VaeClass = AutoencoderTinyWithPooledExits
+            self.do_pooled_exits = True

         if not self.train_encoder:
             # remove losses that only target encoder
@@ -108,6 +117,9 @@ class TrainVAEProcess(BaseTrainProcess):
         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
         self.torch_dtype = get_torch_dtype(self.dtype)
         self.vgg_19 = None
+        self.clip = None
+        self.clip_image_processor = None
+        self.clip_image_size = 256
         self.style_weight_scalers = []
         self.content_weight_scalers = []
         self.lpips_loss: lpips.LPIPS = None
@@ -226,6 +238,64 @@ class TrainVAEProcess(BaseTrainProcess):
         self.print(f"Style weight scalers: {self.style_weight_scalers}")
         self.print(f"Content weight scalers: {self.content_weight_scalers}")

+    def setup_clip(self):
+        ckpt = 'google/siglip2-base-patch16-256'
+        if self.resolution == 512:
+            ckpt = 'google/siglip2-so400m-patch16-512'
+            # ckpt = 'google/siglip2-base-patch16-512'
+            self.clip_image_size = 512
+        self.print(f"Loading CLIP model from {ckpt}")
+        vision_encoder = SiglipVisionModel.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16).eval()
+        processor = SiglipImageProcessor.from_pretrained(ckpt)
+        self.clip = vision_encoder
+        self.clip_image_processor = processor
+
+    def get_clip_embeddings(self, image_n1p1):
+        tensors_0_1 = (image_n1p1 + 1) / 2
+        tensors_0_1 = tensors_0_1.clamp(0, 1)
+        # resize if needed
+        if tensors_0_1.shape[-2:] != (self.clip_image_size, self.clip_image_size):
+            tensors_0_1 = torch.nn.functional.interpolate(tensors_0_1, size=(self.clip_image_size, self.clip_image_size), mode='bilinear', align_corners=False)
+        mean = torch.tensor([0.5, 0.5, 0.5]).to(
+            tensors_0_1.device, dtype=tensors_0_1.dtype
+        ).view([1, 3, 1, 1]).detach()
+        std = torch.tensor([0.5, 0.5, 0.5]).to(
+            tensors_0_1.device, dtype=tensors_0_1.dtype
+        ).view([1, 3, 1, 1]).detach()
+        # tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
+        clip_image = (tensors_0_1 - mean) / std
+        id_embeds = self.clip(
+            clip_image.to(self.clip.device, dtype=torch.bfloat16),
+            output_hidden_states=True,
+        )
+        last_hidden_state = id_embeds['last_hidden_state']
+        return last_hidden_state
+
+    def get_clip_loss(self, pred, target):
+        # pred and target come in as -1 to 1
+        # only the target embeddings are detached; gradients must flow through pred
+        with torch.no_grad():
+            target_embeddings = self.get_clip_embeddings(target).float()
+        pred_embeddings = self.get_clip_embeddings(pred).float()
+        return torch.nn.functional.mse_loss(pred_embeddings, target_embeddings)
+
+    def get_pooled_output_loss(self, pooled_outputs, target):
+        if pooled_outputs is None:
+            return torch.tensor(0.0, device=self.device)
+        # pooled_outputs is a list of tensors, each with shape (batch_size, 3, h, w)
+        # target is a tensor with shape (batch_size, 3, h, w)
+        loss = 0.0
+        for pooled_output in pooled_outputs:
+            with torch.no_grad():
+                # resize target to match pooled_output size
+                target_resized = torch.nn.functional.interpolate(target, size=pooled_output.shape[2:], mode='bilinear', align_corners=False)
+            loss += torch.nn.functional.mse_loss(pooled_output.float(), target_resized.float())
+        return loss / len(pooled_outputs) if len(pooled_outputs) > 0 else torch.tensor(0.0, device=self.device)
+
     def get_style_loss(self):
         if self.style_weight > 0:
@@ -607,6 +677,9 @@ class TrainVAEProcess(BaseTrainProcess):
         if self.use_critic:
             self.critic.setup()

+        if self.clip_weight > 0:
+            self.setup_clip()
+
         if self.lpips_weight > 0 and self.lpips_loss is None:
             # self.lpips_loss = lpips.LPIPS(net='vgg')
             self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=torch.bfloat16)
@@ -645,6 +718,8 @@ class TrainVAEProcess(BaseTrainProcess):
"lpm": [],
"kl": [],
"tv": [],
"clip": [],
"pool": [],
"ptn": [],
"crD": [],
"crG": [],
@@ -699,7 +774,7 @@ class TrainVAEProcess(BaseTrainProcess):
             # forward pass
             # grad only if eq_vae
             with torch.set_grad_enabled(self.train_encoder):
-                if self.vae_type == 'AutoencoderTiny':
+                if self.vae_type != 'AutoencoderKL':
                     # AutoencoderTiny cannot do latent distribution sampling
                     latents = self.vae.encode(batch, return_dict=False)[0]
                     mu, logvar = None, None
@@ -791,7 +866,11 @@ class TrainVAEProcess(BaseTrainProcess):
                 shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
                 forward_latents = forward_latents / self.vae.config['scaling_factor'] + shift

-            pred = self.vae.decode(forward_latents).sample
+            pooled_outputs = None
+            if self.do_pooled_exits:
+                pred, pooled_outputs = self.vae.decode_with_pooled_exits(forward_latents)
+            else:
+                pred = self.vae.decode(forward_latents).sample

             # Run through VGG19
             if self.style_weight > 0 or self.content_weight > 0:
@@ -810,6 +889,11 @@ class TrainVAEProcess(BaseTrainProcess):
             kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
             mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
             mae_loss = self.get_mae_loss(pred, batch) * self.mae_weight
+            pool_loss = self.get_pooled_output_loss(pooled_outputs, batch)
+            if self.clip_weight > 0:
+                clip_loss = self.get_clip_loss(pred, batch) * self.clip_weight
+            else:
+                clip_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

             if self.lpips_weight > 0:
                 lpips_loss = self.lpips_loss(
                     pred.clamp(-1, 1).to(self.device, dtype=torch.bfloat16),
@@ -850,7 +934,7 @@ class TrainVAEProcess(BaseTrainProcess):
             else:
                 lpm_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

-            loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + mae_loss + lat_mse_loss
+            loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + mae_loss + lat_mse_loss + clip_loss + pool_loss

             # check if loss is NaN or Inf
             if torch.isnan(loss) or torch.isinf(loss):
@@ -864,6 +948,8 @@ class TrainVAEProcess(BaseTrainProcess):
                 self.print(f" - LPIPS loss: {lpips_loss.item()}")
                 self.print(f" - TV loss: {tv_loss.item()}")
                 self.print(f" - Pattern loss: {pattern_loss.item()}")
+                self.print(f" - CLIP loss: {clip_loss.item()}")
+                self.print(f" - Pooled output loss: {pool_loss.item()}")
                 self.print(f" - Critic gen loss: {critic_gen_loss.item()}")
                 self.print(f" - Critic D loss: {critic_d_loss}")
                 self.print(f" - Mean variance loss: {mv_loss.item()}")
@@ -901,6 +987,10 @@ class TrainVAEProcess(BaseTrainProcess):
                 loss_string += f" tv: {tv_loss.item():.2e}"
             if self.pattern_weight > 0:
                 loss_string += f" ptn: {pattern_loss.item():.2e}"
+            if self.clip_weight > 0:
+                loss_string += f" clip: {clip_loss.item():.2e}"
+            if self.do_pooled_exits:
+                loss_string += f" pool: {pool_loss.item():.2e}"
             if self.use_critic and self.critic_weight > 0:
                 loss_string += f" crG: {critic_gen_loss.item():.2e}"
             if self.use_critic:
@@ -943,6 +1033,8 @@ class TrainVAEProcess(BaseTrainProcess):
             epoch_losses["kl"].append(kld_loss.item())
             epoch_losses["tv"].append(tv_loss.item())
             epoch_losses["ptn"].append(pattern_loss.item())
+            epoch_losses["clip"].append(clip_loss.item())
+            epoch_losses["pool"].append(pool_loss.item())
             epoch_losses["crG"].append(critic_gen_loss.item())
             epoch_losses["crD"].append(critic_d_loss)
             epoch_losses["mvl"].append(mv_loss.item())
@@ -959,6 +1051,8 @@ class TrainVAEProcess(BaseTrainProcess):
                 log_losses["kl"].append(kld_loss.item())
                 log_losses["tv"].append(tv_loss.item())
                 log_losses["ptn"].append(pattern_loss.item())
+                log_losses["clip"].append(clip_loss.item())
+                log_losses["pool"].append(pool_loss.item())
                 log_losses["crG"].append(critic_gen_loss.item())
                 log_losses["crD"].append(critic_d_loss)
                 log_losses["mvl"].append(mv_loss.item())

toolkit/models/autoencoder_tiny_with_pooled_exits.py (new file)

@@ -0,0 +1,187 @@
from typing import Optional, Tuple, Union

from diffusers import AutoencoderTiny
from diffusers.models.autoencoders.vae import (
    EncoderTiny,
    get_activation,
    AutoencoderTinyBlock,
    DecoderOutput
)
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.configuration_utils import register_to_config
import torch
import torch.nn as nn
class DecoderTinyWithPooledExits(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: Tuple[int, ...],
        block_out_channels: Tuple[int, ...],
        upsampling_scaling_factor: int,
        act_fn: str,
        upsample_fn: str,
    ):
        super().__init__()

        layers = []
        self.ordered_layers = []

        l = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
        self.ordered_layers.append(l)
        layers.append(l)

        l = get_activation(act_fn)
        self.ordered_layers.append(l)
        layers.append(l)

        pooled_exits = []

        for i, num_block in enumerate(num_blocks):
            is_final_block = i == (len(num_blocks) - 1)
            num_channels = block_out_channels[i]

            for _ in range(num_block):
                l = AutoencoderTinyBlock(num_channels, num_channels, act_fn)
                layers.append(l)
                self.ordered_layers.append(l)

            if not is_final_block:
                l = nn.Upsample(
                    scale_factor=upsampling_scaling_factor, mode=upsample_fn
                )
                layers.append(l)
                self.ordered_layers.append(l)

            conv_out_channel = num_channels if not is_final_block else out_channels
            l = nn.Conv2d(
                num_channels,
                conv_out_channel,
                kernel_size=3,
                padding=1,
                bias=is_final_block,
            )
            layers.append(l)
            self.ordered_layers.append(l)

            if not is_final_block:
                # early-exit head: predicts a low-resolution RGB image from this stage
                p = nn.Conv2d(
                    conv_out_channel,
                    out_channels=3,
                    kernel_size=3,
                    padding=1,
                    bias=True,
                )
                p._is_pooled_exit = True
                pooled_exits.append(p)
                self.ordered_layers.append(p)

        self.layers = nn.ModuleList(layers)
        self.pooled_exits = nn.ModuleList(pooled_exits)
        self.gradient_checkpointing = False
    def forward(self, x: torch.Tensor, pooled_outputs=False) -> torch.Tensor:
        r"""The forward method of the `DecoderTinyWithPooledExits` class."""
        # Clamp.
        x = torch.tanh(x / 3) * 3

        pooled_output_list = []
        for layer in self.ordered_layers:
            if hasattr(layer, '_is_pooled_exit') and layer._is_pooled_exit:
                # pooled exits branch off; they never modify the main activation x
                if pooled_outputs:
                    pooled_output = layer(x)
                    pooled_output_list.append(pooled_output)
            else:
                if torch.is_grad_enabled() and self.gradient_checkpointing:
                    x = self._gradient_checkpointing_func(layer, x)
                else:
                    x = layer(x)

        # scale image from [0, 1] to [-1, 1] to match diffusers convention
        x = x.mul(2).sub(1)

        if pooled_outputs:
            return x, pooled_output_list
        return x
class AutoencoderTinyWithPooledExits(AutoencoderTiny):
    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
        decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
        act_fn: str = "relu",
        upsample_fn: str = "nearest",
        latent_channels: int = 4,
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
        num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
        latent_magnitude: int = 3,
        latent_shift: float = 0.5,
        force_upcast: bool = False,
        scaling_factor: float = 1.0,
        shift_factor: float = 0.0,
    ):
        # skip AutoencoderTiny.__init__ so the pooled-exit decoder can be swapped in
        super(AutoencoderTiny, self).__init__()

        if len(encoder_block_out_channels) != len(num_encoder_blocks):
            raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
        if len(decoder_block_out_channels) != len(num_decoder_blocks):
            raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")

        self.encoder = EncoderTiny(
            in_channels=in_channels,
            out_channels=latent_channels,
            num_blocks=num_encoder_blocks,
            block_out_channels=encoder_block_out_channels,
            act_fn=act_fn,
        )

        self.decoder = DecoderTinyWithPooledExits(
            in_channels=latent_channels,
            out_channels=out_channels,
            num_blocks=num_decoder_blocks,
            block_out_channels=decoder_block_out_channels,
            upsampling_scaling_factor=upsampling_scaling_factor,
            act_fn=act_fn,
            upsample_fn=upsample_fn,
        )

        self.latent_magnitude = latent_magnitude
        self.latent_shift = latent_shift
        self.scaling_factor = scaling_factor

        self.use_slicing = False
        self.use_tiling = False

        # only relevant if vae tiling is enabled
        self.spatial_scale_factor = 2**out_channels
        self.tile_overlap_factor = 0.125
        self.tile_sample_min_size = 512
        self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor

        self.register_to_config(block_out_channels=decoder_block_out_channels)
        self.register_to_config(force_upcast=False)
    @apply_forward_hook
    def decode_with_pooled_exits(
        self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = False
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        output, pooled_outputs = self.decoder(x, pooled_outputs=True)

        if not return_dict:
            return (output, pooled_outputs)

        return DecoderOutput(sample=output)
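A usage sketch of the new decode path. A randomly initialized model with the default config is assumed purely for illustration; real use would load trained weights:

    import torch

    # Default config: 4 latent channels, three upsampling stages (8x total).
    vae = AutoencoderTinyWithPooledExits().eval()
    latents = torch.randn(1, 4, 32, 32)

    with torch.no_grad():
        image, pooled = vae.decode_with_pooled_exits(latents)

    print(image.shape)                       # torch.Size([1, 3, 256, 256])
    print([tuple(p.shape) for p in pooled])  # one RGB early exit per non-final stage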