From 3f518d995156dfd1c0987f2d763d5d90f6b7b0fb Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 27 Jul 2025 15:11:56 -0600 Subject: [PATCH] Add sharpening before losses with a split loss on vae training --- jobs/process/TrainVAEProcess.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py index 04be4a1e..95c38b43 100644 --- a/jobs/process/TrainVAEProcess.py +++ b/jobs/process/TrainVAEProcess.py @@ -339,16 +339,28 @@ class TrainVAEProcess(BaseTrainProcess): def get_mse_loss(self, pred, target): if self.mse_weight > 0: loss_fn = nn.MSELoss() - loss = loss_fn(pred, target) - return loss + loss_normal = loss_fn(pred, target) + + pred_sharp = sharpen_image(pred) + with torch.no_grad(): + target_sharp = sharpen_image(target) + + loss_sharp = loss_fn(pred_sharp, target_sharp) + + return (loss_sharp + loss_normal) / 2 else: return torch.tensor(0.0, device=self.device) def get_mae_loss(self, pred, target): if self.mae_weight > 0: loss_fn = nn.L1Loss() - loss = loss_fn(pred, target) - return loss + loss_normal = loss_fn(pred, target) + + pred_sharp = sharpen_image(pred) + with torch.no_grad(): + target_sharp = sharpen_image(target) + loss_sharp = loss_fn(pred_sharp, target_sharp) + return (loss_sharp + loss_normal) / 2 else: return torch.tensor(0.0, device=self.device) @@ -820,9 +832,9 @@ class TrainVAEProcess(BaseTrainProcess): shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 latents = self.vae.config['scaling_factor'] * (latents - shift) - if target_latent is not None: + if target_latent is not None and self.train_encoder: # forward_latents = target_latent.detach() - lat_mse_loss = self.get_mse_loss(target_latent.float(), latents.float()) + lat_mse_loss = torch.nn.MSELoss()(target_latent.float(), latents.float()) latents = target_latent.detach() forward_latents = target_latent.detach()