From 2e84b3d5b18080086d8d26b0ee54b3f3dfb83721 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 12 Jul 2025 16:55:15 -0600
Subject: [PATCH] Update VAE trainer to handle fixed latent target. Also minor
 bug fixes and improvements

---
 jobs/process/TrainVAEProcess.py | 76 ++++++++++++++++++++++++++++-----
 jobs/process/models/critic.py   |  2 +-
 2 files changed, 66 insertions(+), 12 deletions(-)

diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index 4530b7cd..19c72472 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -59,12 +59,15 @@ class TrainVAEProcess(BaseTrainProcess):
         super().__init__(process_id, job, config)
         self.data_loader = None
         self.vae = None
+        self.target_latent_vae = None
         self.device = self.get_conf('device', self.job.device)
         self.vae_path = self.get_conf('vae_path', None)
+        self.target_latent_vae_path = self.get_conf('target_latent_vae_path', None)
         self.eq_vae = self.get_conf('eq_vae', False)
         self.datasets_objects = self.get_conf('datasets', required=True)
         self.batch_size = self.get_conf('batch_size', 1, as_type=int)
         self.resolution = self.get_conf('resolution', 256, as_type=int)
+        self.sample_resolution = self.get_conf('sample_resolution', self.resolution, as_type=int)
         self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float)
         self.sample_every = self.get_conf('sample_every', None)
         self.optimizer_type = self.get_conf('optimizer', 'adam')
@@ -78,6 +81,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.content_weight = self.get_conf('content_weight', 0, as_type=float)
         self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
+        self.mae_weight = self.get_conf('mae_weight', 0, as_type=float)
         self.mv_loss_weight = self.get_conf('mv_loss_weight', 0, as_type=float)
         self.tv_weight = self.get_conf('tv_weight', 0, as_type=float)
         self.ltv_weight = self.get_conf('ltv_weight', 0, as_type=float)
@@ -106,6 +110,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.lpips_loss:lpips.LPIPS = None

         self.vae_scale_factor = 8
+        self.target_vae_scale_factor = 8
         self.step_num = 0
         self.epoch_num = 0

@@ -244,6 +249,14 @@ class TrainVAEProcess(BaseTrainProcess):
             return loss
         else:
             return torch.tensor(0.0, device=self.device)
+
+    def get_mae_loss(self, pred, target):
+        if self.mae_weight > 0:
+            loss_fn = nn.L1Loss()
+            loss = loss_fn(pred, target)
+            return loss
+        else:
+            return torch.tensor(0.0, device=self.device)

     def get_kld_loss(self, mu, log_var):
         if self.kld_weight > 0:
@@ -389,7 +402,7 @@ class TrainVAEProcess(BaseTrainProcess):
                 min_dim = min(img.width, img.height)
                 img = img.crop((0, 0, min_dim, min_dim))
                 # resize
-                img = img.resize((self.resolution, self.resolution))
+                img = img.resize((self.sample_resolution, self.sample_resolution))

                 input_img = img
                 img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
@@ -420,7 +433,7 @@ class TrainVAEProcess(BaseTrainProcess):
                     latent_img = (latent_img / 2 + 0.5).clamp(0, 1)

                     # resize to 256x256
-                    latent_img = torch.nn.functional.interpolate(latent_img, size=(self.resolution, self.resolution), mode='nearest')
+                    latent_img = torch.nn.functional.interpolate(latent_img, size=(self.sample_resolution, self.sample_resolution), mode='nearest')
                     latent_img = latent_img.squeeze(0).cpu().permute(1, 2, 0).float().numpy()
                     latent_img = (latent_img * 255).astype(np.uint8)
                     # convert to pillow image
@@ -435,17 +448,19 @@ class TrainVAEProcess(BaseTrainProcess):
                     decoded = Image.fromarray((decoded * 255).astype(np.uint8))

                     # stack input image and decoded image
-                    input_img = input_img.resize((self.resolution, self.resolution))
-                    decoded = decoded.resize((self.resolution, self.resolution))
+                    input_img = input_img.resize((self.sample_resolution, self.sample_resolution))
+                    decoded = decoded.resize((self.sample_resolution, self.sample_resolution))

-                    output_img = Image.new('RGB', (self.resolution * 3, self.resolution))
+                    output_img = Image.new('RGB', (self.sample_resolution * 3, self.sample_resolution))
                     output_img.paste(input_img, (0, 0))
-                    output_img.paste(decoded, (self.resolution, 0))
-                    output_img.paste(latent_img, (self.resolution * 2, 0))
+                    output_img.paste(decoded, (self.sample_resolution, 0))
+                    output_img.paste(latent_img, (self.sample_resolution * 2, 0))

                     scale_up = 2
                     if output_img.height <= 300:
                         scale_up = 4
+                    if output_img.height >= 1000:
+                        scale_up = 1

                     # scale up using nearest neighbor
                     output_img = output_img.resize((output_img.width * scale_up, output_img.height * scale_up), Image.NEAREST)
@@ -492,6 +507,16 @@ class TrainVAEProcess(BaseTrainProcess):
             self.vae.eval()
             self.vae.decoder.train()
         self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)
+
+        if self.target_latent_vae_path is not None:
+            self.print(f"Loading target latent VAE from {self.target_latent_vae_path}")
+            self.target_latent_vae = AutoencoderKL.from_pretrained(self.target_latent_vae_path)
+            self.target_latent_vae.to(self.device, dtype=self.torch_dtype)
+            self.target_latent_vae.eval()
+            self.target_vae_scale_factor = 2 ** (len(self.target_latent_vae.config['block_out_channels']) - 1)
+        else:
+            self.target_latent_vae = None
+            self.target_vae_scale_factor = self.vae_scale_factor

     def run(self):
         super().run()
@@ -571,7 +596,7 @@ class TrainVAEProcess(BaseTrainProcess):
             optimizer,
             total_iters=num_steps,
             factor=1,
-            verbose=False
+            # verbose=False
         )

         # setup tqdm progress bar
@@ -589,6 +614,8 @@ class TrainVAEProcess(BaseTrainProcess):
             "style": [],
             "content": [],
             "mse": [],
+            "mae": [],
+            "lat_mse": [],
             "mvl": [],
             "ltv": [],
             "lpm": [],
@@ -630,6 +657,16 @@ class TrainVAEProcess(BaseTrainProcess):
                 if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
                     batch = Resize((batch.shape[2] // self.vae_scale_factor * self.vae_scale_factor,
                                     batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)
+
+                target_latent = None
+                lat_mse_loss = torch.tensor(0.0, device=self.device)
+                if self.target_latent_vae is not None:
+                    target_input_scale = self.target_vae_scale_factor / self.vae_scale_factor
+                    target_input_size = (int(batch.shape[2] * target_input_scale), int(batch.shape[3] * target_input_scale))
+                    # resize to target input size
+                    target_input_batch = Resize(target_input_size)(batch)
+                    target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach()
+

                 # forward pass
                 # grad only if eq_vae
@@ -638,7 +675,13 @@ class TrainVAEProcess(BaseTrainProcess):
                 mu, logvar = dgd.mean, dgd.logvar
                 latents = dgd.sample()

-                if self.eq_vae:
+                if target_latent is not None:
+                    # forward_latents = target_latent.detach()
+                    lat_mse_loss = self.get_mse_loss(target_latent, latents)
+                    latents = target_latent.detach()
+                    forward_latents = target_latent.detach()
+
+                elif self.eq_vae:
                     # process flips, rotate, scale
                     latent_chunks = list(latents.chunk(latents.shape[0], dim=0))
                     batch_chunks = list(batch.chunk(batch.shape[0], dim=0))
@@ -698,7 +741,7 @@ class TrainVAEProcess(BaseTrainProcess):
                     batch = torch.cat(batch_chunks, dim=0)

                 else:
-                    latents.detach().requires_grad_(True)
+                    # latents.detach().requires_grad_(True)
                     forward_latents = latents

                 forward_latents = forward_latents.to(self.device, dtype=self.torch_dtype)
@@ -725,6 +768,7 @@ class TrainVAEProcess(BaseTrainProcess):
                 content_loss = self.get_content_loss() * self.content_weight
                 kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
                 mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
+                mae_loss = self.get_mae_loss(pred, batch) * self.mae_weight
                 if self.lpips_weight > 0:
                     lpips_loss = self.lpips_loss(
                         pred.clamp(-1, 1),
@@ -765,7 +809,7 @@ class TrainVAEProcess(BaseTrainProcess):
                 else:
                     lpm_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

-                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss
+                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + mae_loss + lat_mse_loss

                 # check if loss is NaN or Inf
                 if torch.isnan(loss) or torch.isinf(loss):
@@ -774,6 +818,8 @@ class TrainVAEProcess(BaseTrainProcess):
                     self.print(f" - Content loss: {content_loss.item()}")
                     self.print(f" - KLD loss: {kld_loss.item()}")
                     self.print(f" - MSE loss: {mse_loss.item()}")
+                    self.print(f" - MAE loss: {mae_loss.item()}")
+                    self.print(f" - Latent MSE loss: {lat_mse_loss.item()}")
                     self.print(f" - LPIPS loss: {lpips_loss.item()}")
                     self.print(f" - TV loss: {tv_loss.item()}")
                     self.print(f" - Pattern loss: {pattern_loss.item()}")
@@ -806,6 +852,10 @@ class TrainVAEProcess(BaseTrainProcess):
                     loss_string += f" kld: {kld_loss.item():.2e}"
                 if self.mse_weight > 0:
                     loss_string += f" mse: {mse_loss.item():.2e}"
+                if self.mae_weight > 0:
+                    loss_string += f" mae: {mae_loss.item():.2e}"
+                if self.target_latent_vae:
+                    loss_string += f" lat_mse: {lat_mse_loss.item():.2e}"
                 if self.tv_weight > 0:
                     loss_string += f" tv: {tv_loss.item():.2e}"
                 if self.pattern_weight > 0:
@@ -847,6 +897,8 @@ class TrainVAEProcess(BaseTrainProcess):
                 epoch_losses["style"].append(style_loss.item())
                 epoch_losses["content"].append(content_loss.item())
                 epoch_losses["mse"].append(mse_loss.item())
+                epoch_losses["mae"].append(mae_loss.item())
+                epoch_losses["lat_mse"].append(lat_mse_loss.item())
                 epoch_losses["kl"].append(kld_loss.item())
                 epoch_losses["tv"].append(tv_loss.item())
                 epoch_losses["ptn"].append(pattern_loss.item())
@@ -861,6 +913,8 @@ class TrainVAEProcess(BaseTrainProcess):
                 log_losses["style"].append(style_loss.item())
                 log_losses["content"].append(content_loss.item())
                 log_losses["mse"].append(mse_loss.item())
+                log_losses["mae"].append(mae_loss.item())
+                log_losses["lat_mse"].append(lat_mse_loss.item())
                 log_losses["kl"].append(kld_loss.item())
                 log_losses["tv"].append(tv_loss.item())
                 log_losses["ptn"].append(pattern_loss.item())
diff --git a/jobs/process/models/critic.py b/jobs/process/models/critic.py
index c792a9be..118db5a0 100644
--- a/jobs/process/models/critic.py
+++ b/jobs/process/models/critic.py
@@ -152,7 +152,7 @@ class Critic:
             self.optimizer,
             total_iters=self.process.max_steps * self.num_critic_per_gen,
             factor=1,
-            verbose=False,
+            # verbose=False,
         )

     def load_weights(self):
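
---

Notes on the fixed-latent-target path: when target_latent_vae_path is set, load_vae() brings up a second, frozen AutoencoderKL; the training loop pre-scales each batch by target_vae_scale_factor / vae_scale_factor so both encoders produce latents with the same spatial size; an MSE penalty (lat_mse_loss, added to the total loss without a weight of its own) pulls the trainee's latents toward the frozen ones; and the detached target latents are what the decoder then reconstructs. The two VAEs must share a latent channel count for that MSE to be well-formed. (The commented-out verbose=False arguments elsewhere in the patch sidestep the deprecation of that kwarg in recent torch.optim.lr_scheduler releases.) A minimal, self-contained sketch of the size-alignment geometry, using tiny randomly initialized VAEs; the block counts, channel widths, and shapes here are illustrative only, not the models used in training:

import torch
from diffusers import AutoencoderKL
from torchvision.transforms import Resize


def make_vae(n_blocks: int) -> AutoencoderKL:
    # Tiny randomly initialized VAE: n_blocks encoder blocks give a
    # 2 ** (n_blocks - 1) downsample factor, matching the patch's
    # vae_scale_factor = 2 ** (len(config['block_out_channels']) - 1).
    return AutoencoderKL(
        down_block_types=("DownEncoderBlock2D",) * n_blocks,
        up_block_types=("UpDecoderBlock2D",) * n_blocks,
        block_out_channels=(32,) * n_blocks,
        latent_channels=4,
    )


trainee = make_vae(5)  # 16x downsample, stands in for the VAE being trained
target = make_vae(4)   # 8x downsample, stands in for the frozen target VAE
target.eval()

vae_scale_factor = 2 ** (len(trainee.config.block_out_channels) - 1)        # 16
target_vae_scale_factor = 2 ** (len(target.config.block_out_channels) - 1)  # 8

batch = torch.randn(1, 3, 512, 512)  # images already normalized to [-1, 1]

# Pre-scale the input so both encoders emit latents of equal spatial size:
# 512 / 16 = 32 for the trainee, (512 * 0.5) / 8 = 32 for the target.
target_input_scale = target_vae_scale_factor / vae_scale_factor
target_input_size = (int(batch.shape[2] * target_input_scale),
                     int(batch.shape[3] * target_input_scale))
target_input_batch = Resize(target_input_size)(batch)

with torch.no_grad():  # target VAE stays frozen; its latents act as a fixed target
    target_latent = target.encode(target_input_batch).latent_dist.sample()

latents = trainee.encode(batch).latent_dist.sample()
assert latents.shape == target_latent.shape  # both (1, 4, 32, 32)
lat_mse_loss = torch.nn.functional.mse_loss(target_latent, latents)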
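
All of the new knobs are read through self.get_conf(), so they ride along in the same process config as the existing options. A hypothetical fragment showing how they could be set; the key names come from the patch, but the dict layout and every value are illustrative, not the repo's actual config schema:

# Hypothetical process-config fragment for TrainVAEProcess.
vae_process_config = {
    "device": "cuda:0",
    "vae_path": "path/to/vae_in_training",
    # frozen reference VAE whose latents become the fixed target
    "target_latent_vae_path": "path/to/reference_vae",
    "resolution": 256,         # training resolution
    "sample_resolution": 512,  # preview grids may now differ from training res
    "mse_weight": 1.0,         # must stay > 0: get_mse_loss() also gates lat_mse
    "mae_weight": 0.1,         # enables the new L1 reconstruction term
}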