Scale target VAE latent before using it as the training target

Jaret Burkett committed 2025-07-17 07:12:21 -06:00
parent e5ed450dc7
commit 3916e67455


@@ -411,6 +411,17 @@ class TrainVAEProcess(BaseTrainProcess):
             img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
             img = img
             # latent = self.vae.encode(img).latent_dist.sample()
+            target_latent = None
+            if self.target_latent_vae is not None:
+                target_input_scale = self.target_vae_scale_factor / self.vae_scale_factor
+                target_input_size = (int(img.shape[2] * target_input_scale), int(img.shape[3] * target_input_scale))
+                # resize to target input size
+                target_input_batch = Resize(target_input_size)(img).to(self.device, dtype=torch.float32)
+                target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach()
+                # shift and scale into the normalized latent space
+                shift = self.target_latent_vae.config['shift_factor'] if self.target_latent_vae.config['shift_factor'] is not None else 0
+                target_latent = self.target_latent_vae.config['scaling_factor'] * (target_latent - shift)
+                target_latent = target_latent.to(self.device, dtype=self.torch_dtype)
             latent = self.vae.encode(img, return_dict=False)[0]
             latent_img = latent.clone()
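
For reference, the added normalization follows the diffusers convention for mapping raw posterior samples into a model's latent space. A minimal sketch of the transform, assuming a diffusers-style AutoencoderKL whose config exposes scaling_factor and an optional shift_factor (the helper name encode_normalized is hypothetical):

    import torch

    def encode_normalized(vae, img: torch.Tensor) -> torch.Tensor:
        # Sample a raw latent from the encoder's posterior, then apply
        # z = scaling_factor * (z_raw - shift_factor), treating shift_factor
        # as 0 when the config does not define it.
        latent = vae.encode(img).latent_dist.sample().detach()
        shift = vae.config['shift_factor'] if vae.config['shift_factor'] is not None else 0
        return vae.config['scaling_factor'] * (latent - shift)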
@@ -443,6 +454,9 @@ class TrainVAEProcess(BaseTrainProcess):
             # convert to pillow image
             latent_img = Image.fromarray(latent_img)
+            if target_latent is not None:
+                # preview the target VAE's latent through this VAE's decoder
+                latent = target_latent.to(latent.device, dtype=latent.dtype)
             decoded = self.vae.decode(latent).sample
             decoded = (decoded / 2 + 0.5).clamp(0, 1)
             # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
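
Note that the preview path hands the normalized target latent straight to self.vae.decode. In diffusers pipelines, a latent scaled this way is usually mapped back to the raw space before decoding; a minimal sketch of that inverse, under the same config assumptions as above (decode_normalized is hypothetical):

    def decode_normalized(vae, latent: torch.Tensor) -> torch.Tensor:
        # Invert z = scaling_factor * (z_raw - shift_factor) before decoding.
        shift = vae.config['shift_factor'] if vae.config['shift_factor'] is not None else 0
        raw = latent / vae.config['scaling_factor'] + shift
        return vae.decode(raw).sample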
@@ -589,7 +603,7 @@ class TrainVAEProcess(BaseTrainProcess):
         if self.lpips_weight > 0 and self.lpips_loss is None:
             # self.lpips_loss = lpips.LPIPS(net='vgg')
-            self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=self.torch_dtype)
+            self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=torch.bfloat16)
         optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
                                   optimizer_params=self.optimizer_params)
@@ -670,6 +684,9 @@ class TrainVAEProcess(BaseTrainProcess):
                 # resize to target input size
                 target_input_batch = Resize(target_input_size)(batch).to(self.device, dtype=torch.float32)
                 target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach()
+                # shift and scale into the normalized latent space
+                shift = self.target_latent_vae.config['shift_factor'] if self.target_latent_vae.config['shift_factor'] is not None else 0
+                target_latent = self.target_latent_vae.config['scaling_factor'] * (target_latent - shift)
                 target_latent = target_latent.to(self.device, dtype=self.torch_dtype)
@@ -781,9 +798,9 @@ class TrainVAEProcess(BaseTrainProcess):
                 mae_loss = self.get_mae_loss(pred, batch) * self.mae_weight
                 if self.lpips_weight > 0:
                     lpips_loss = self.lpips_loss(
-                        pred.clamp(-1, 1),
-                        batch.clamp(-1, 1)
-                    ).mean() * self.lpips_weight
+                        pred.clamp(-1, 1).to(self.device, dtype=torch.bfloat16),
+                        batch.clamp(-1, 1).to(self.device, dtype=torch.bfloat16)
+                    ).float().mean() * self.lpips_weight
                 else:
                     lpips_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
                 tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
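
A standalone sketch of the bfloat16 LPIPS pattern introduced above: run the VGG-based metric in bfloat16 to cut memory, then reduce in float32 for a stable scalar loss. Assumes the lpips package and a CUDA device; perceptual_loss and weight are illustrative names:

    import torch
    import lpips

    lpips_fn = lpips.LPIPS(net='vgg').to('cuda', dtype=torch.bfloat16)

    def perceptual_loss(pred: torch.Tensor, target: torch.Tensor, weight: float) -> torch.Tensor:
        # Clamp inputs to LPIPS's expected [-1, 1] range, evaluate the
        # network in bfloat16, then accumulate the mean in float32.
        d = lpips_fn(
            pred.clamp(-1, 1).to('cuda', dtype=torch.bfloat16),
            target.clamp(-1, 1).to('cuda', dtype=torch.bfloat16),
        )
        return d.float().mean() * weight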