diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index 3ec22f78..c324603e 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -424,6 +424,9 @@ class TrainVAEProcess(BaseTrainProcess):
                 target_latent = target_latent.to(self.device, dtype=self.torch_dtype)
 
             latent = self.vae.encode(img, return_dict=False)[0]
+            shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
+            latent = self.vae.config['scaling_factor'] * (latent - shift)
+
             latent_img = latent.clone()
             bs, ch, h, w = latent_img.shape
             grid_size = math.ceil(math.sqrt(ch))
@@ -456,6 +459,9 @@ class TrainVAEProcess(BaseTrainProcess):
 
             if target_latent is not None:
                 latent = target_latent.to(latent.device, dtype=latent.dtype)
+
+            shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
+            latent = latent / self.vae.config['scaling_factor'] + shift
 
             decoded = self.vae.decode(latent).sample
             decoded = (decoded / 2 + 0.5).clamp(0, 1)
@@ -702,6 +708,10 @@ class TrainVAEProcess(BaseTrainProcess):
             mu, logvar = dgd.mean, dgd.logvar
             latents = dgd.sample()
 
+            # scale shift latent to config
+            shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
+            latents = self.vae.config['scaling_factor'] * (latents - shift)
+
             if target_latent is not None:
                 # forward_latents = target_latent.detach()
                 lat_mse_loss = self.get_mse_loss(target_latent.float(), latents.float())
@@ -776,6 +786,10 @@ class TrainVAEProcess(BaseTrainProcess):
 
             if not self.train_encoder:
                 # detach latents if not training encoder
                 forward_latents = forward_latents.detach()
+
+            # shift latents to match vae config
+            shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
+            forward_latents = forward_latents / self.vae.config['scaling_factor'] + shift
 
             pred = self.vae.decode(forward_latents).sample
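
Taken together, the hunks apply the same pair of transforms everywhere latents cross the VAE boundary: after encoding, latents are normalized with scaling_factor * (latent - shift_factor); before decoding, the inverse latent / scaling_factor + shift_factor is applied. The sketch below restates that round trip as two standalone helpers, assuming a diffusers-style VAE whose config exposes 'scaling_factor' and an optional 'shift_factor'; the helper names are illustrative and do not exist in the patched file.

import torch

def normalize_latents(latents: torch.Tensor, vae) -> torch.Tensor:
    # Hypothetical helper mirroring the encode-side change above:
    # subtract the optional shift, then scale into the model's latent space.
    shift = vae.config['shift_factor'] if vae.config['shift_factor'] is not None else 0
    return vae.config['scaling_factor'] * (latents - shift)

def denormalize_latents(latents: torch.Tensor, vae) -> torch.Tensor:
    # Hypothetical helper mirroring the decode-side change above:
    # undo the scaling, then add the shift back before calling vae.decode().
    shift = vae.config['shift_factor'] if vae.config['shift_factor'] is not None else 0
    return latents / vae.config['scaling_factor'] + shift

When shift_factor is None (older SD 1.x / SDXL-style VAE configs typically leave it unset), the shift term collapses to 0 and only the scaling is applied, which is why the None check appears in every hunk.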