mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Add sharpening before losses with a split loss on vae training
This commit is contained in:
@@ -339,16 +339,28 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
def get_mse_loss(self, pred, target):
|
||||
if self.mse_weight > 0:
|
||||
loss_fn = nn.MSELoss()
|
||||
loss = loss_fn(pred, target)
|
||||
return loss
|
||||
loss_normal = loss_fn(pred, target)
|
||||
|
||||
pred_sharp = sharpen_image(pred)
|
||||
with torch.no_grad():
|
||||
target_sharp = sharpen_image(target)
|
||||
|
||||
loss_sharp = loss_fn(pred_sharp, target_sharp)
|
||||
|
||||
return (loss_sharp + loss_normal) / 2
|
||||
else:
|
||||
return torch.tensor(0.0, device=self.device)
|
||||
|
||||
def get_mae_loss(self, pred, target):
|
||||
if self.mae_weight > 0:
|
||||
loss_fn = nn.L1Loss()
|
||||
loss = loss_fn(pred, target)
|
||||
return loss
|
||||
loss_normal = loss_fn(pred, target)
|
||||
|
||||
pred_sharp = sharpen_image(pred)
|
||||
with torch.no_grad():
|
||||
target_sharp = sharpen_image(target)
|
||||
loss_sharp = loss_fn(pred_sharp, target_sharp)
|
||||
return (loss_sharp + loss_normal) / 2
|
||||
else:
|
||||
return torch.tensor(0.0, device=self.device)
|
||||
|
||||
@@ -820,9 +832,9 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
|
||||
latents = self.vae.config['scaling_factor'] * (latents - shift)
|
||||
|
||||
if target_latent is not None:
|
||||
if target_latent is not None and self.train_encoder:
|
||||
# forward_latents = target_latent.detach()
|
||||
lat_mse_loss = self.get_mse_loss(target_latent.float(), latents.float())
|
||||
lat_mse_loss = torch.nn.MSELoss()(target_latent.float(), latents.float())
|
||||
latents = target_latent.detach()
|
||||
forward_latents = target_latent.detach()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user