Add sharpening before losses with a split loss on vae training

This commit is contained in:
Jaret Burkett
2025-07-27 15:11:56 -06:00
parent 77dc38a574
commit 3f518d9951

View File

@@ -339,16 +339,28 @@ class TrainVAEProcess(BaseTrainProcess):
def get_mse_loss(self, pred, target):
if self.mse_weight > 0:
loss_fn = nn.MSELoss()
loss = loss_fn(pred, target)
return loss
loss_normal = loss_fn(pred, target)
pred_sharp = sharpen_image(pred)
with torch.no_grad():
target_sharp = sharpen_image(target)
loss_sharp = loss_fn(pred_sharp, target_sharp)
return (loss_sharp + loss_normal) / 2
else:
return torch.tensor(0.0, device=self.device)
def get_mae_loss(self, pred, target):
if self.mae_weight > 0:
loss_fn = nn.L1Loss()
loss = loss_fn(pred, target)
return loss
loss_normal = loss_fn(pred, target)
pred_sharp = sharpen_image(pred)
with torch.no_grad():
target_sharp = sharpen_image(target)
loss_sharp = loss_fn(pred_sharp, target_sharp)
return (loss_sharp + loss_normal) / 2
else:
return torch.tensor(0.0, device=self.device)
@@ -820,9 +832,9 @@ class TrainVAEProcess(BaseTrainProcess):
shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
latents = self.vae.config['scaling_factor'] * (latents - shift)
if target_latent is not None:
if target_latent is not None and self.train_encoder:
# forward_latents = target_latent.detach()
lat_mse_loss = self.get_mse_loss(target_latent.float(), latents.float())
lat_mse_loss = torch.nn.MSELoss()(target_latent.float(), latents.float())
latents = target_latent.detach()
forward_latents = target_latent.detach()