diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index c324603e..27daefab 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -24,6 +24,7 @@ from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
 from diffusers import AutoencoderKL, AutoencoderTiny
+from toolkit.models.autoencoder_tiny_with_pooled_exits import AutoencoderTinyWithPooledExits
 from tqdm import tqdm
 import math
 import torchvision.utils
@@ -34,6 +35,7 @@ from torchvision.transforms import Resize
 import lpips
 import random
 import traceback
+from transformers import SiglipImageProcessor, SiglipVisionModel
 
 IMAGE_TRANSFORMS = transforms.Compose(
     [
@@ -80,13 +82,14 @@ class TrainVAEProcess(BaseTrainProcess):
         self.style_weight = self.get_conf('style_weight', 0, as_type=float)
         self.content_weight = self.get_conf('content_weight', 0, as_type=float)
         self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
+        self.clip_weight = self.get_conf('clip_weight', 0, as_type=float)
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
         self.mae_weight = self.get_conf('mae_weight', 0, as_type=float)
         self.mv_loss_weight = self.get_conf('mv_loss_weight', 0, as_type=float)
         self.tv_weight = self.get_conf('tv_weight', 0, as_type=float)
         self.ltv_weight = self.get_conf('ltv_weight', 0, as_type=float)
         self.lpm_weight = self.get_conf('lpm_weight', 0, as_type=float) # latent pixel matching
-        self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float)
+        self.lpips_weight = self.get_conf('lpips_weight', 0, as_type=float)
         self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
         self.pattern_weight = self.get_conf('pattern_weight', 0, as_type=float)
         self.optimizer_params = self.get_conf('optimizer_params', {})
@@ -96,7 +99,13 @@ class TrainVAEProcess(BaseTrainProcess):
         self.random_scaling = self.get_conf('random_scaling', False, as_type=bool)
 
         self.vae_type = self.get_conf('vae_type', 'AutoencoderKL', as_type=str) # AutoencoderKL or AutoencoderTiny
-        self.VaeClass = AutoencoderKL if self.vae_type == 'AutoencoderKL' else AutoencoderTiny
+        self.do_pooled_exits = False
+        self.VaeClass = AutoencoderKL
+        if self.vae_type == 'AutoencoderTiny':
+            self.VaeClass = AutoencoderTiny
+        if self.vae_type == 'AutoencoderTinyWithPooledExits':
+            self.VaeClass = AutoencoderTinyWithPooledExits
+            self.do_pooled_exits = True
 
         if not self.train_encoder:
             # remove losses that only target encoder
@@ -108,6 +117,9 @@ class TrainVAEProcess(BaseTrainProcess):
         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
         self.torch_dtype = get_torch_dtype(self.dtype)
         self.vgg_19 = None
+        self.clip = None
+        self.clip_image_processor = None
+        self.clip_image_size = 256
         self.style_weight_scalers = []
         self.content_weight_scalers = []
         self.lpips_loss:lpips.LPIPS = None
@@ -226,6 +238,64 @@ class TrainVAEProcess(BaseTrainProcess):
         self.print(f"Style weight scalers: {self.style_weight_scalers}")
         self.print(f"Content weight scalers: {self.content_weight_scalers}")
 
+
+    def setup_clip(self):
+        ckpt = 'google/siglip2-base-patch16-256'
+        if self.resolution == 512:
+            ckpt = 'google/siglip2-so400m-patch16-512'
+            # ckpt = 'google/siglip2-base-patch16-512'
+            self.clip_image_size = 512
+        self.print(f"Loading CLIP model from {ckpt}")
+        vision_encoder = SiglipVisionModel.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16).eval()
+        processor = SiglipImageProcessor.from_pretrained(ckpt)
+        self.clip = vision_encoder
+        self.clip_image_processor = processor
+
+    def get_clip_embeddings(self, image_n1p1):
+        tensors_0_1 = (image_n1p1 + 1) / 2
+        tensors_0_1 = tensors_0_1.clamp(0, 1)
+
+        # resize if needed
+        if tensors_0_1.shape[-2:] != (self.clip_image_size, self.clip_image_size):
+            tensors_0_1 = torch.nn.functional.interpolate(tensors_0_1, size=(self.clip_image_size, self.clip_image_size), mode='bilinear', align_corners=False)
+
+        mean = torch.tensor([0.5, 0.5, 0.5]).to(
+            tensors_0_1.device, dtype=tensors_0_1.dtype
+        ).view([1, 3, 1, 1]).detach()
+        std = torch.tensor([0.5, 0.5, 0.5]).to(
+            tensors_0_1.device, dtype=tensors_0_1.dtype
+        ).view([1, 3, 1, 1]).detach()
+
+        # tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
+        clip_image = (tensors_0_1 - mean) / std
+
+        id_embeds = self.clip(
+            clip_image.to(self.clip.device, dtype=torch.bfloat16),
+            output_hidden_states=True,
+        )
+        last_hidden_state = id_embeds['last_hidden_state']
+        return last_hidden_state
+
+    def get_clip_loss(self, pred, target):
+        # pred and target come in as -1 to 1.
+        with torch.no_grad():
+            target_embeddings = self.get_clip_embeddings(target).float()
+        pred_embeddings = self.get_clip_embeddings(pred).float()
+        return torch.nn.functional.mse_loss(pred_embeddings, target_embeddings)
+
+    def get_pooled_output_loss(self, pooled_outputs, target):
+        if pooled_outputs is None:
+            return torch.tensor(0.0, device=self.device)
+
+        # pooled_outputs is a list of tensors, each with shape (batch_size, 3, h, w)
+        # target is a tensor with shape (batch_size, 3, h, w)
+        loss = 0.0
+        for pooled_output in pooled_outputs:
+            with torch.no_grad():
+                # resize target to match pooled_output size
+                target_resized = torch.nn.functional.interpolate(target, size=pooled_output.shape[2:], mode='bilinear', align_corners=False)
+            loss += torch.nn.functional.mse_loss(pooled_output.float(), target_resized.float())
+        return loss / len(pooled_outputs) if len(pooled_outputs) > 0 else torch.tensor(0.0, device=self.device)
 
     def get_style_loss(self):
         if self.style_weight > 0:
@@ -607,6 +677,9 @@ class TrainVAEProcess(BaseTrainProcess):
         if self.use_critic:
             self.critic.setup()
 
+        if self.clip_weight > 0:
+            self.setup_clip()
+
         if self.lpips_weight > 0 and self.lpips_loss is None:
             # self.lpips_loss = lpips.LPIPS(net='vgg')
             self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=torch.bfloat16)
@@ -645,6 +718,8 @@ class TrainVAEProcess(BaseTrainProcess):
             "lpm": [],
             "kl": [],
             "tv": [],
+            "clip": [],
+            "pool": [],
             "ptn": [],
             "crD": [],
             "crG": [],
@@ -699,7 +774,7 @@ class TrainVAEProcess(BaseTrainProcess):
                 # forward pass
                 # grad only if eq_vae
                 with torch.set_grad_enabled(self.train_encoder):
-                    if self.vae_type == 'AutoencoderTiny':
+                    if self.vae_type != 'AutoencoderKL':
                         # AutoencoderTiny cannot do latent distribution sampling
                         latents = self.vae.encode(batch, return_dict=False)[0]
                         mu, logvar = None, None
@@ -791,7 +866,11 @@ class TrainVAEProcess(BaseTrainProcess):
                     shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
                     forward_latents = forward_latents / self.vae.config['scaling_factor'] + shift
 
-                pred = self.vae.decode(forward_latents).sample
+                pooled_outputs = None
+                if self.do_pooled_exits:
+                    pred, pooled_outputs = self.vae.decode_with_pooled_exits(forward_latents)
+                else:
+                    pred = self.vae.decode(forward_latents).sample
 
                 # Run through VGG19
                 if self.style_weight > 0 or self.content_weight > 0:
@@ -810,6 +889,11 @@ class TrainVAEProcess(BaseTrainProcess):
                 kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
                 mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
                 mae_loss = self.get_mae_loss(pred, batch) * self.mae_weight
+                pool_loss = self.get_pooled_output_loss(pooled_outputs, batch)
+                if self.clip_weight > 0:
+                    clip_loss = self.get_clip_loss(pred, batch) * self.clip_weight
+                else:
+                    clip_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
                 if self.lpips_weight > 0:
                     lpips_loss = self.lpips_loss(
                         pred.clamp(-1, 1).to(self.device, dtype=torch.bfloat16),
@@ -850,7 +934,7 @@ class TrainVAEProcess(BaseTrainProcess):
                 else:
                     lpm_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
 
-                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + mae_loss + lat_mse_loss
+                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + mae_loss + lat_mse_loss + clip_loss + pool_loss
 
                 # check if loss is NaN or Inf
                 if torch.isnan(loss) or torch.isinf(loss):
@@ -864,6 +948,8 @@ class TrainVAEProcess(BaseTrainProcess):
                     self.print(f" - LPIPS loss: {lpips_loss.item()}")
                     self.print(f" - TV loss: {tv_loss.item()}")
                     self.print(f" - Pattern loss: {pattern_loss.item()}")
+                    self.print(f" - CLIP loss: {clip_loss.item()}")
+                    self.print(f" - Pooled output loss: {pool_loss.item()}")
                    self.print(f" - Critic gen loss: {critic_gen_loss.item()}")
                     self.print(f" - Critic D loss: {critic_d_loss}")
                     self.print(f" - Mean variance loss: {mv_loss.item()}")
@@ -901,6 +987,10 @@ class TrainVAEProcess(BaseTrainProcess):
                     loss_string += f" tv: {tv_loss.item():.2e}"
                 if self.pattern_weight > 0:
                     loss_string += f" ptn: {pattern_loss.item():.2e}"
+                if self.clip_weight > 0:
+                    loss_string += f" clip: {clip_loss.item():.2e}"
+                if self.do_pooled_exits:
+                    loss_string += f" pool: {pool_loss.item():.2e}"
                 if self.use_critic and self.critic_weight > 0:
                     loss_string += f" crG: {critic_gen_loss.item():.2e}"
                 if self.use_critic:
@@ -943,6 +1033,8 @@ class TrainVAEProcess(BaseTrainProcess):
                 epoch_losses["kl"].append(kld_loss.item())
                 epoch_losses["tv"].append(tv_loss.item())
                 epoch_losses["ptn"].append(pattern_loss.item())
+                epoch_losses["clip"].append(clip_loss.item())
+                epoch_losses["pool"].append(pool_loss.item())
                 epoch_losses["crG"].append(critic_gen_loss.item())
                 epoch_losses["crD"].append(critic_d_loss)
                 epoch_losses["mvl"].append(mv_loss.item())
@@ -959,6 +1051,8 @@ class TrainVAEProcess(BaseTrainProcess):
                 log_losses["kl"].append(kld_loss.item())
                 log_losses["tv"].append(tv_loss.item())
                 log_losses["ptn"].append(pattern_loss.item())
+                log_losses["clip"].append(clip_loss.item())
+                log_losses["pool"].append(pool_loss.item())
                 log_losses["crG"].append(critic_gen_loss.item())
                 log_losses["crD"].append(critic_d_loss)
                 log_losses["mvl"].append(mv_loss.item())
diff --git a/toolkit/models/autoencoder_tiny_with_pooled_exits.py b/toolkit/models/autoencoder_tiny_with_pooled_exits.py
new file mode 100644
index 00000000..5771955d
--- /dev/null
+++ b/toolkit/models/autoencoder_tiny_with_pooled_exits.py
@@ -0,0 +1,187 @@
+from typing import Optional, Tuple, Union
+from diffusers import AutoencoderTiny
+from diffusers.models.autoencoders.vae import (
+    EncoderTiny,
+    get_activation,
+    AutoencoderTinyBlock,
+    DecoderOutput
+)
+from diffusers.utils.accelerate_utils import apply_forward_hook
+from diffusers.configuration_utils import register_to_config
+import torch
+import torch.nn as nn
+
+
+class DecoderTinyWithPooledExits(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_blocks: Tuple[int, ...],
+        block_out_channels: Tuple[int, ...],
+        upsampling_scaling_factor: int,
+        act_fn: str,
+        upsample_fn: str,
+    ):
+        super().__init__()
+        layers = []
+        self.ordered_layers = []
+        l = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
+        self.ordered_layers.append(l)
+        layers.append(l)
+        l = get_activation(act_fn)
+        self.ordered_layers.append(l)
+        layers.append(l)
+
+        pooled_exits = []
+
+
+        for i, num_block in enumerate(num_blocks):
+            is_final_block = i == (len(num_blocks) - 1)
+            num_channels = block_out_channels[i]
+
+            for _ in range(num_block):
+                l = AutoencoderTinyBlock(num_channels, num_channels, act_fn)
+                layers.append(l)
+                self.ordered_layers.append(l)
+
+            if not is_final_block:
+                l = nn.Upsample(
+                    scale_factor=upsampling_scaling_factor, mode=upsample_fn
+                )
+                layers.append(l)
+                self.ordered_layers.append(l)
+
+            conv_out_channel = num_channels if not is_final_block else out_channels
+            l = nn.Conv2d(
+                num_channels,
+                conv_out_channel,
+                kernel_size=3,
+                padding=1,
+                bias=is_final_block,
+            )
+            layers.append(l)
+            self.ordered_layers.append(l)
+
+            if not is_final_block:
+                p = nn.Conv2d(
+                    conv_out_channel,
+                    out_channels=3,
+                    kernel_size=3,
+                    padding=1,
+                    bias=True,
+                )
+                p._is_pooled_exit = True
+                pooled_exits.append(p)
+                self.ordered_layers.append(p)
+
+        self.layers = nn.ModuleList(layers)
+        self.pooled_exits = nn.ModuleList(pooled_exits)
+        self.gradient_checkpointing = False
+
+    def forward(self, x: torch.Tensor, pooled_outputs=False) -> torch.Tensor:
+        r"""The forward method of the `DecoderTiny` class."""
+        # Clamp.
+        x = torch.tanh(x / 3) * 3
+
+        pooled_output_list = []
+
+        for layer in self.ordered_layers:
+            # see if is pooled exit
+            try:
+                if hasattr(layer, '_is_pooled_exit') and layer._is_pooled_exit:
+                    if pooled_outputs:
+                        pooled_output = layer(x)
+                        pooled_output_list.append(pooled_output)
+                else:
+                    if torch.is_grad_enabled() and self.gradient_checkpointing:
+                        x = self._gradient_checkpointing_func(layer, x)
+                    else:
+                        x = layer(x)
+            except RuntimeError as e:
+                raise e
+
+        # scale image from [0, 1] to [-1, 1] to match diffusers convention
+        x = x.mul(2).sub(1)
+
+        if pooled_outputs:
+            return x, pooled_output_list
+        return x
+
+
+class AutoencoderTinyWithPooledExits(AutoencoderTiny):
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
+        decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
+        act_fn: str = "relu",
+        upsample_fn: str = "nearest",
+        latent_channels: int = 4,
+        upsampling_scaling_factor: int = 2,
+        num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
+        num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
+        latent_magnitude: int = 3,
+        latent_shift: float = 0.5,
+        force_upcast: bool = False,
+        scaling_factor: float = 1.0,
+        shift_factor: float = 0.0,
+    ):
+        super(AutoencoderTiny, self).__init__()
+
+        if len(encoder_block_out_channels) != len(num_encoder_blocks):
+            raise ValueError(
+                "`encoder_block_out_channels` should have the same length as `num_encoder_blocks`."
+            )
+        if len(decoder_block_out_channels) != len(num_decoder_blocks):
+            raise ValueError(
+                "`decoder_block_out_channels` should have the same length as `num_decoder_blocks`."
+            )
+
+        self.encoder = EncoderTiny(
+            in_channels=in_channels,
+            out_channels=latent_channels,
+            num_blocks=num_encoder_blocks,
+            block_out_channels=encoder_block_out_channels,
+            act_fn=act_fn,
+        )
+
+        self.decoder = DecoderTinyWithPooledExits(
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            num_blocks=num_decoder_blocks,
+            block_out_channels=decoder_block_out_channels,
+            upsampling_scaling_factor=upsampling_scaling_factor,
+            act_fn=act_fn,
+            upsample_fn=upsample_fn,
+        )
+
+        self.latent_magnitude = latent_magnitude
+        self.latent_shift = latent_shift
+        self.scaling_factor = scaling_factor
+
+        self.use_slicing = False
+        self.use_tiling = False
+
+        # only relevant if vae tiling is enabled
+        self.spatial_scale_factor = 2**out_channels
+        self.tile_overlap_factor = 0.125
+        self.tile_sample_min_size = 512
+        self.tile_latent_min_size = (
+            self.tile_sample_min_size // self.spatial_scale_factor
+        )
+
+        self.register_to_config(block_out_channels=decoder_block_out_channels)
+        self.register_to_config(force_upcast=False)
+
+    @apply_forward_hook
+    def decode_with_pooled_exits(
+        self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = False
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
+        output, pooled_outputs = self.decoder(x, pooled_outputs=True)
+
+        if not return_dict:
+            return (output, pooled_outputs)
+
+        return DecoderOutput(sample=output)
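
Usage sketch (not part of the patch above): a minimal, self-contained way to exercise the new pooled-exit decode path. Constructing the model from its default config, the dummy latent shape, and the printed sizes are illustrative assumptions about a locally initialized, untrained model, not behavior guaranteed for any particular checkpoint.

import torch

from toolkit.models.autoencoder_tiny_with_pooled_exits import AutoencoderTinyWithPooledExits

# Default config: 4 latent channels and three non-final decoder stages, so the
# decoder exposes three intermediate RGB exits in addition to the final sample.
vae = AutoencoderTinyWithPooledExits().eval()

# Dummy latents; each non-final stage upsamples by 2x, so a 32x32 latent
# decodes to a 256x256 image.
latents = torch.randn(1, 4, 32, 32)

with torch.no_grad():
    # return_dict defaults to False, so this returns (sample, pooled_exit_list).
    sample, pooled = vae.decode_with_pooled_exits(latents)

print(sample.shape)                      # torch.Size([1, 3, 256, 256])
print([tuple(p.shape) for p in pooled])  # (1, 3, 64, 64), (1, 3, 128, 128), (1, 3, 256, 256)

The trainer consumes these outputs the same way: when vae_type is set to AutoencoderTinyWithPooledExits, TrainVAEProcess calls decode_with_pooled_exits and, in get_pooled_output_loss, averages an MSE between each exit and the target image resized to that exit's resolution.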