mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Improvements to VAE trainer. Allow CLIP loss.
@@ -24,6 +24,7 @@ from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
 from diffusers import AutoencoderKL, AutoencoderTiny
+from toolkit.models.autoencoder_tiny_with_pooled_exits import AutoencoderTinyWithPooledExits
 from tqdm import tqdm
 import math
 import torchvision.utils
@@ -34,6 +35,7 @@ from torchvision.transforms import Resize
 import lpips
 import random
 import traceback
+from transformers import SiglipImageProcessor, SiglipVisionModel

 IMAGE_TRANSFORMS = transforms.Compose(
     [
@@ -80,13 +82,14 @@ class TrainVAEProcess(BaseTrainProcess):
         self.style_weight = self.get_conf('style_weight', 0, as_type=float)
         self.content_weight = self.get_conf('content_weight', 0, as_type=float)
         self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
+        self.clip_weight = self.get_conf('clip_weight', 0, as_type=float)
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
         self.mae_weight = self.get_conf('mae_weight', 0, as_type=float)
         self.mv_loss_weight = self.get_conf('mv_loss_weight', 0, as_type=float)
         self.tv_weight = self.get_conf('tv_weight', 0, as_type=float)
         self.ltv_weight = self.get_conf('ltv_weight', 0, as_type=float)
         self.lpm_weight = self.get_conf('lpm_weight', 0, as_type=float) # latent pixel matching
-        self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float)
+        self.lpips_weight = self.get_conf('lpips_weight', 0, as_type=float)
         self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
         self.pattern_weight = self.get_conf('pattern_weight', 0, as_type=float)
         self.optimizer_params = self.get_conf('optimizer_params', {})
@@ -96,7 +99,13 @@ class TrainVAEProcess(BaseTrainProcess):
         self.random_scaling = self.get_conf('random_scaling', False, as_type=bool)
         self.vae_type = self.get_conf('vae_type', 'AutoencoderKL', as_type=str) # AutoencoderKL or AutoencoderTiny

-        self.VaeClass = AutoencoderKL if self.vae_type == 'AutoencoderKL' else AutoencoderTiny
+        self.do_pooled_exits = False
+        self.VaeClass = AutoencoderKL
+        if self.vae_type == 'AutoencoderTiny':
+            self.VaeClass = AutoencoderTiny
+        if self.vae_type == 'AutoencoderTinyWithPooledExits':
+            self.VaeClass = AutoencoderTinyWithPooledExits
+            self.do_pooled_exits = True

         if not self.train_encoder:
             # remove losses that only target encoder
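For orientation, a minimal sketch of how the new options might be set in a train_vae process config. The key names mirror the get_conf calls above; the surrounding structure and the example values are assumptions, not part of this commit:

# hypothetical process-config excerpt (keys taken from the get_conf calls above)
vae_process_config = {
    'vae_type': 'AutoencoderTinyWithPooledExits',  # or 'AutoencoderKL' / 'AutoencoderTiny'
    'clip_weight': 0.1,    # > 0 enables the new SigLIP-feature (CLIP) loss
    'lpips_weight': 0,
    'mse_weight': 1.0,
    'kld_weight': 0,
}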
@@ -108,6 +117,9 @@ class TrainVAEProcess(BaseTrainProcess):
         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
         self.torch_dtype = get_torch_dtype(self.dtype)
         self.vgg_19 = None
+        self.clip = None
+        self.clip_image_processor = None
+        self.clip_image_size = 256
         self.style_weight_scalers = []
         self.content_weight_scalers = []
         self.lpips_loss:lpips.LPIPS = None
@@ -226,6 +238,64 @@ class TrainVAEProcess(BaseTrainProcess):

         self.print(f"Style weight scalers: {self.style_weight_scalers}")
         self.print(f"Content weight scalers: {self.content_weight_scalers}")

+    def setup_clip(self):
+        ckpt = 'google/siglip2-base-patch16-256'
+        if self.resolution == 512:
+            ckpt = 'google/siglip2-so400m-patch16-512'
+            # ckpt = 'google/siglip2-base-patch16-512'
+            self.clip_image_size = 512
+        self.print(f"Loading CLIP model from {ckpt}")
+        vision_encoder = SiglipVisionModel.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16).eval()
+        processor = SiglipImageProcessor.from_pretrained(ckpt)
+        self.clip = vision_encoder
+        self.clip_image_processor = processor
+
+    def get_clip_embeddings(self, image_n1p1):
+        tensors_0_1 = (image_n1p1 + 1) / 2
+        tensors_0_1 = tensors_0_1.clamp(0, 1)
+
+        # resize if needed
+        if tensors_0_1.shape[-2:] != (self.clip_image_size, self.clip_image_size):
+            tensors_0_1 = torch.nn.functional.interpolate(tensors_0_1, size=(self.clip_image_size, self.clip_image_size), mode='bilinear', align_corners=False)
+
+        mean = torch.tensor([0.5, 0.5, 0.5]).to(
+            tensors_0_1.device, dtype=tensors_0_1.dtype
+        ).view([1, 3, 1, 1]).detach()
+        std = torch.tensor([0.5, 0.5, 0.5]).to(
+            tensors_0_1.device, dtype=tensors_0_1.dtype
+        ).view([1, 3, 1, 1]).detach()
+
+        # tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
+        clip_image = (tensors_0_1 - mean) / std
+
+        id_embeds = self.clip(
+            clip_image.to(self.clip.device, dtype=torch.bfloat16),
+            output_hidden_states=True,
+        )
+        last_hidden_state = id_embeds['last_hidden_state']
+        return last_hidden_state
+
+    def get_clip_loss(self, pred, target):
+        # pred and target come in as -1 to 1.
+        with torch.no_grad():
+            target_embeddings = self.get_clip_embeddings(target).float()
+        pred_embeddings = self.get_clip_embeddings(pred).float()
+        return torch.nn.functional.mse_loss(pred_embeddings, target_embeddings)
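A minimal sketch of exercising the new perceptual term in isolation, assuming an already-constructed TrainVAEProcess instance named proc (setup elided) and toy tensors; not part of the commit:

import torch

proc.setup_clip()                            # loads the SigLIP vision tower and image processor
target = torch.rand(1, 3, 256, 256) * 2 - 1  # ground-truth image in [-1, 1]
pred = (torch.rand(1, 3, 256, 256) * 2 - 1).requires_grad_(True)  # decoder output in [-1, 1]
loss = proc.get_clip_loss(pred, target)      # MSE between SigLIP last_hidden_state features
loss.backward()                              # gradients flow into pred; the target side runs under no_grad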
+
+    def get_pooled_output_loss(self, pooled_outputs, target):
+        if pooled_outputs is None:
+            return torch.tensor(0.0, device=self.device)
+
+        # pooled_outputs is a list of tensors, each with shape (batch_size, 3, h, w)
+        # target is a tensor with shape (batch_size, 3, h, w)
+        loss = 0.0
+        for pooled_output in pooled_outputs:
+            with torch.no_grad():
+                # resize target to match pooled_output size
+                target_resized = torch.nn.functional.interpolate(target, size=pooled_output.shape[2:], mode='bilinear', align_corners=False)
+            loss += torch.nn.functional.mse_loss(pooled_output.float(), target_resized.float())
+        return loss / len(pooled_outputs) if len(pooled_outputs) > 0 else torch.tensor(0.0, device=self.device)
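As an illustration of the pooled-exit supervision just defined, each early-exit RGB preview is compared against a bilinearly resized copy of the same target and the per-scale MSEs are averaged; a hypothetical call with made-up shapes (proc as above), not part of the commit:

import torch

target = torch.rand(2, 3, 256, 256) * 2 - 1                                     # full-resolution target batch
pooled = [torch.rand(2, 3, s, s, requires_grad=True) for s in (64, 128, 256)]   # stand-in exit previews
pool_loss = proc.get_pooled_output_loss(pooled, target)                         # mean of the three per-scale MSEs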
     def get_style_loss(self):
         if self.style_weight > 0:
@@ -607,6 +677,9 @@ class TrainVAEProcess(BaseTrainProcess):
         if self.use_critic:
             self.critic.setup()

+        if self.clip_weight > 0:
+            self.setup_clip()
+
         if self.lpips_weight > 0 and self.lpips_loss is None:
             # self.lpips_loss = lpips.LPIPS(net='vgg')
             self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=torch.bfloat16)
@@ -645,6 +718,8 @@ class TrainVAEProcess(BaseTrainProcess):
             "lpm": [],
             "kl": [],
             "tv": [],
+            "clip": [],
+            "pool": [],
             "ptn": [],
             "crD": [],
             "crG": [],
@@ -699,7 +774,7 @@ class TrainVAEProcess(BaseTrainProcess):
             # forward pass
             # grad only if eq_vae
             with torch.set_grad_enabled(self.train_encoder):
-                if self.vae_type == 'AutoencoderTiny':
+                if self.vae_type != 'AutoencoderKL':
                     # AutoencoderTiny cannot do latent distribution sampling
                     latents = self.vae.encode(batch, return_dict=False)[0]
                     mu, logvar = None, None
@@ -791,7 +866,11 @@ class TrainVAEProcess(BaseTrainProcess):
                 shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
                 forward_latents = forward_latents / self.vae.config['scaling_factor'] + shift

-                pred = self.vae.decode(forward_latents).sample
+                pooled_outputs = None
+                if self.do_pooled_exits:
+                    pred, pooled_outputs = self.vae.decode_with_pooled_exits(forward_latents)
+                else:
+                    pred = self.vae.decode(forward_latents).sample

                 # Run through VGG19
                 if self.style_weight > 0 or self.content_weight > 0:
@@ -810,6 +889,11 @@ class TrainVAEProcess(BaseTrainProcess):
                 kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
                 mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
                 mae_loss = self.get_mae_loss(pred, batch) * self.mae_weight
+                pool_loss = self.get_pooled_output_loss(pooled_outputs, batch)
+                if self.clip_weight > 0:
+                    clip_loss = self.get_clip_loss(pred, batch) * self.clip_weight
+                else:
+                    clip_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
                 if self.lpips_weight > 0:
                     lpips_loss = self.lpips_loss(
                         pred.clamp(-1, 1).to(self.device, dtype=torch.bfloat16),
@@ -850,7 +934,7 @@ class TrainVAEProcess(BaseTrainProcess):
                 else:
                     lpm_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

-                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + mae_loss + lat_mse_loss
+                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + mae_loss + lat_mse_loss + clip_loss + pool_loss

                 # check if loss is NaN or Inf
                 if torch.isnan(loss) or torch.isinf(loss):
@@ -864,6 +948,8 @@ class TrainVAEProcess(BaseTrainProcess):
                     self.print(f" - LPIPS loss: {lpips_loss.item()}")
                     self.print(f" - TV loss: {tv_loss.item()}")
                     self.print(f" - Pattern loss: {pattern_loss.item()}")
+                    self.print(f" - CLIP loss: {clip_loss.item()}")
+                    self.print(f" - Pooled output loss: {pool_loss.item()}")
                     self.print(f" - Critic gen loss: {critic_gen_loss.item()}")
                     self.print(f" - Critic D loss: {critic_d_loss}")
                     self.print(f" - Mean variance loss: {mv_loss.item()}")
@@ -901,6 +987,10 @@ class TrainVAEProcess(BaseTrainProcess):
                     loss_string += f" tv: {tv_loss.item():.2e}"
                 if self.pattern_weight > 0:
                     loss_string += f" ptn: {pattern_loss.item():.2e}"
+                if self.clip_weight > 0:
+                    loss_string += f" clip: {clip_loss.item():.2e}"
+                if self.do_pooled_exits:
+                    loss_string += f" pool: {pool_loss.item():.2e}"
                 if self.use_critic and self.critic_weight > 0:
                     loss_string += f" crG: {critic_gen_loss.item():.2e}"
                 if self.use_critic:
@@ -943,6 +1033,8 @@ class TrainVAEProcess(BaseTrainProcess):
                 epoch_losses["kl"].append(kld_loss.item())
                 epoch_losses["tv"].append(tv_loss.item())
                 epoch_losses["ptn"].append(pattern_loss.item())
+                epoch_losses["clip"].append(clip_loss.item())
+                epoch_losses["pool"].append(pool_loss.item())
                 epoch_losses["crG"].append(critic_gen_loss.item())
                 epoch_losses["crD"].append(critic_d_loss)
                 epoch_losses["mvl"].append(mv_loss.item())
@@ -959,6 +1051,8 @@ class TrainVAEProcess(BaseTrainProcess):
                 log_losses["kl"].append(kld_loss.item())
                 log_losses["tv"].append(tv_loss.item())
                 log_losses["ptn"].append(pattern_loss.item())
+                log_losses["clip"].append(clip_loss.item())
+                log_losses["pool"].append(pool_loss.item())
                 log_losses["crG"].append(critic_gen_loss.item())
                 log_losses["crD"].append(critic_d_loss)
                 log_losses["mvl"].append(mv_loss.item())
toolkit/models/autoencoder_tiny_with_pooled_exits.py (new file, 187 lines)
@@ -0,0 +1,187 @@
from typing import Optional, Tuple, Union

from diffusers import AutoencoderTiny
from diffusers.models.autoencoders.vae import (
    EncoderTiny,
    get_activation,
    AutoencoderTinyBlock,
    DecoderOutput
)
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.configuration_utils import register_to_config
import torch
import torch.nn as nn


class DecoderTinyWithPooledExits(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: Tuple[int, ...],
        block_out_channels: Tuple[int, ...],
        upsampling_scaling_factor: int,
        act_fn: str,
        upsample_fn: str,
    ):
        super().__init__()
        layers = []
        self.ordered_layers = []
        l = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
        self.ordered_layers.append(l)
        layers.append(l)
        l = get_activation(act_fn)
        self.ordered_layers.append(l)
        layers.append(l)

        pooled_exits = []

        for i, num_block in enumerate(num_blocks):
            is_final_block = i == (len(num_blocks) - 1)
            num_channels = block_out_channels[i]

            for _ in range(num_block):
                l = AutoencoderTinyBlock(num_channels, num_channels, act_fn)
                layers.append(l)
                self.ordered_layers.append(l)

            if not is_final_block:
                l = nn.Upsample(
                    scale_factor=upsampling_scaling_factor, mode=upsample_fn
                )
                layers.append(l)
                self.ordered_layers.append(l)

            conv_out_channel = num_channels if not is_final_block else out_channels
            l = nn.Conv2d(
                num_channels,
                conv_out_channel,
                kernel_size=3,
                padding=1,
                bias=is_final_block,
            )
            layers.append(l)
            self.ordered_layers.append(l)

            if not is_final_block:
                p = nn.Conv2d(
                    conv_out_channel,
                    out_channels=3,
                    kernel_size=3,
                    padding=1,
                    bias=True,
                )
                p._is_pooled_exit = True
                pooled_exits.append(p)
                self.ordered_layers.append(p)

        self.layers = nn.ModuleList(layers)
        self.pooled_exits = nn.ModuleList(pooled_exits)
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor, pooled_outputs=False) -> torch.Tensor:
        r"""The forward method of the `DecoderTiny` class."""
        # Clamp.
        x = torch.tanh(x / 3) * 3

        pooled_output_list = []

        for layer in self.ordered_layers:
            # see if is pooled exit
            try:
                if hasattr(layer, '_is_pooled_exit') and layer._is_pooled_exit:
                    if pooled_outputs:
                        pooled_output = layer(x)
                        pooled_output_list.append(pooled_output)
                else:
                    if torch.is_grad_enabled() and self.gradient_checkpointing:
                        x = self._gradient_checkpointing_func(layer, x)
                    else:
                        x = layer(x)
            except RuntimeError as e:
                raise e

        # scale image from [0, 1] to [-1, 1] to match diffusers convention
        x = x.mul(2).sub(1)

        if pooled_outputs:
            return x, pooled_output_list
        return x


class AutoencoderTinyWithPooledExits(AutoencoderTiny):
    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
        decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
        act_fn: str = "relu",
        upsample_fn: str = "nearest",
        latent_channels: int = 4,
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
        num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
        latent_magnitude: int = 3,
        latent_shift: float = 0.5,
        force_upcast: bool = False,
        scaling_factor: float = 1.0,
        shift_factor: float = 0.0,
    ):
        super(AutoencoderTiny, self).__init__()

        if len(encoder_block_out_channels) != len(num_encoder_blocks):
            raise ValueError(
                "`encoder_block_out_channels` should have the same length as `num_encoder_blocks`."
            )
        if len(decoder_block_out_channels) != len(num_decoder_blocks):
            raise ValueError(
                "`decoder_block_out_channels` should have the same length as `num_decoder_blocks`."
            )

        self.encoder = EncoderTiny(
            in_channels=in_channels,
            out_channels=latent_channels,
            num_blocks=num_encoder_blocks,
            block_out_channels=encoder_block_out_channels,
            act_fn=act_fn,
        )

        self.decoder = DecoderTinyWithPooledExits(
            in_channels=latent_channels,
            out_channels=out_channels,
            num_blocks=num_decoder_blocks,
            block_out_channels=decoder_block_out_channels,
            upsampling_scaling_factor=upsampling_scaling_factor,
            act_fn=act_fn,
            upsample_fn=upsample_fn,
        )

        self.latent_magnitude = latent_magnitude
        self.latent_shift = latent_shift
        self.scaling_factor = scaling_factor

        self.use_slicing = False
        self.use_tiling = False

        # only relevant if vae tiling is enabled
        self.spatial_scale_factor = 2**out_channels
        self.tile_overlap_factor = 0.125
        self.tile_sample_min_size = 512
        self.tile_latent_min_size = (
            self.tile_sample_min_size // self.spatial_scale_factor
        )

        self.register_to_config(block_out_channels=decoder_block_out_channels)
        self.register_to_config(force_upcast=False)

    @apply_forward_hook
    def decode_with_pooled_exits(
        self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = False
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        output, pooled_outputs = self.decoder(x, pooled_outputs=True)

        if not return_dict:
            return (output, pooled_outputs)

        return DecoderOutput(sample=output)
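A minimal smoke-test sketch for the new class; not part of the commit. The latent size and the shapes noted below assume the default config above (4 latent channels, 2x upsampling per decoder stage, three early exits):

import torch
from toolkit.models.autoencoder_tiny_with_pooled_exits import AutoencoderTinyWithPooledExits

vae = AutoencoderTinyWithPooledExits()                 # defaults from the __init__ above
latents = torch.randn(1, 4, 32, 32)                    # (batch, latent_channels, h, w)
image, pooled = vae.decode_with_pooled_exits(latents)  # return_dict defaults to False
print(image.shape)                                     # expected: (1, 3, 256, 256) full-resolution output
print([tuple(p.shape) for p in pooled])                # expected: RGB previews at roughly 64, 128, and 256 px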