From 3916e67455ded57ec43aceaf84e732f95ee92612 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 17 Jul 2025 07:12:21 -0600
Subject: [PATCH] Scale target vae latent before targeting it

---
 jobs/process/TrainVAEProcess.py | 25 +++++++++++++++++++++----
 1 file changed, 21 insertions(+), 4 deletions(-)

diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index 1492b19f..3ec22f78 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -411,6 +411,17 @@ class TrainVAEProcess(BaseTrainProcess):
             img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
             img = img
             # latent = self.vae.encode(img).latent_dist.sample()
+
+            target_latent = None
+            if self.target_latent_vae is not None:
+                target_input_scale = self.target_vae_scale_factor / self.vae_scale_factor
+                target_input_size = (int(img.shape[2] * target_input_scale), int(img.shape[3] * target_input_scale))
+                # resize to target input size
+                target_input_batch = Resize(target_input_size)(img).to(self.device, dtype=torch.float32)
+                target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach()
+                shift = self.target_latent_vae.config['shift_factor'] if self.target_latent_vae.config['shift_factor'] is not None else 0
+                target_latent = self.target_latent_vae.config['scaling_factor'] * (target_latent - shift)
+                target_latent = target_latent.to(self.device, dtype=self.torch_dtype)
             latent = self.vae.encode(img, return_dict=False)[0]
 
             latent_img = latent.clone()
@@ -443,6 +454,9 @@ class TrainVAEProcess(BaseTrainProcess):
             # convert to pillow image
             latent_img = Image.fromarray(latent_img)
 
+            if target_latent is not None:
+                latent = target_latent.to(latent.device, dtype=latent.dtype)
+
             decoded = self.vae.decode(latent).sample
             decoded = (decoded / 2 + 0.5).clamp(0, 1)
             # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -589,7 +603,7 @@ class TrainVAEProcess(BaseTrainProcess):
 
         if self.lpips_weight > 0 and self.lpips_loss is None:
             # self.lpips_loss = lpips.LPIPS(net='vgg')
-            self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=self.torch_dtype)
+            self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=torch.bfloat16)
 
         optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
                                   optimizer_params=self.optimizer_params)
@@ -670,6 +684,9 @@ class TrainVAEProcess(BaseTrainProcess):
                     # resize to target input size
                     target_input_batch = Resize(target_input_size)(batch).to(self.device, dtype=torch.float32)
                     target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach()
+                    # shift scale it
+                    shift = self.target_latent_vae.config['shift_factor'] if self.target_latent_vae.config['shift_factor'] is not None else 0
+                    target_latent = self.target_latent_vae.config['scaling_factor'] * (target_latent - shift)
                    target_latent = target_latent.to(self.device, dtype=self.torch_dtype)
 
 
@@ -781,9 +798,9 @@ class TrainVAEProcess(BaseTrainProcess):
                 mae_loss = self.get_mae_loss(pred, batch) * self.mae_weight
                 if self.lpips_weight > 0:
                     lpips_loss = self.lpips_loss(
-                        pred.clamp(-1, 1),
-                        batch.clamp(-1, 1)
-                    ).mean() * self.lpips_weight
+                        pred.clamp(-1, 1).to(self.device, dtype=torch.bfloat16),
+                        batch.clamp(-1, 1).to(self.device, dtype=torch.bfloat16)
+                    ).float().mean() * self.lpips_weight
                 else:
                     lpips_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
                 tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
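
Note on the added shift/scale lines: they follow the diffusers latent-normalization convention, mapping a raw posterior sample into model space as scaling_factor * (raw - shift_factor), with shift_factor treated as 0 when the config leaves it unset (SD-style VAEs define only scaling_factor; SD3/Flux-style VAEs define both). Below is a minimal sketch of that step, assuming a diffusers-style AutoencoderKL; the helper name normalize_latent is hypothetical and not part of the patched file.

    # Hypothetical sketch of the normalization this patch applies; not part of
    # TrainVAEProcess.py. Assumes a diffusers-style AutoencoderKL config.
    import torch
    from diffusers import AutoencoderKL

    def normalize_latent(vae: AutoencoderKL, pixels: torch.Tensor) -> torch.Tensor:
        # Sample a raw latent from the encoder posterior without tracking gradients.
        with torch.no_grad():
            raw = vae.encode(pixels).latent_dist.sample()
        # Fall back to 0 when shift_factor is missing or None in the VAE config.
        shift = vae.config.get('shift_factor', None)
        if shift is None:
            shift = 0
        return vae.config['scaling_factor'] * (raw - shift)

The inverse map, raw = latent / scaling_factor + shift_factor, recovers an encoder-space latent when a decoder expects unnormalized input. Normalizing the teacher latent this way puts the regression target in the same scaled latent space the rest of the training pipeline assumes.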