Use scale shift in vae latent space for vae trainer

Jaret Burkett
2025-07-17 08:14:07 -06:00
parent f500b9f240
commit e25d2feddf


@@ -424,6 +424,9 @@ class TrainVAEProcess(BaseTrainProcess):
         target_latent = target_latent.to(self.device, dtype=self.torch_dtype)
         latent = self.vae.encode(img, return_dict=False)[0]
+        shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
+        latent = self.vae.config['scaling_factor'] * (latent - shift)
         latent_img = latent.clone()
         bs, ch, h, w = latent_img.shape
         grid_size = math.ceil(math.sqrt(ch))
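
The encode-side addition maps raw posterior latents into the normalized space that downstream diffusion training expects. A minimal standalone sketch of the same transform, assuming a diffusers-style AutoencoderKL whose config may define scaling_factor and shift_factor (shift_factor is None for most SD-era VAEs and set for newer ones such as the Flux VAE); the helper name is illustrative, not part of this repo:

    import torch
    from diffusers import AutoencoderKL

    def to_normalized_latent(vae: AutoencoderKL, latent: torch.Tensor) -> torch.Tensor:
        # A None (or absent) shift_factor is treated as a shift of 0, mirroring the diff above.
        shift = vae.config.shift_factor if getattr(vae.config, 'shift_factor', None) is not None else 0
        return vae.config.scaling_factor * (latent - shift)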
@@ -456,6 +459,9 @@ class TrainVAEProcess(BaseTrainProcess):
         if target_latent is not None:
             latent = target_latent.to(latent.device, dtype=latent.dtype)
+        shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
+        latent = latent / self.vae.config['scaling_factor'] + shift
         decoded = self.vae.decode(latent).sample
         decoded = (decoded / 2 + 0.5).clamp(0, 1)
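
The decode path applies the exact inverse, so a latent normalized on encode round-trips cleanly back to the raw space the decoder was trained on. A sketch under the same assumptions as above:

    def from_normalized_latent(vae: AutoencoderKL, latent: torch.Tensor) -> torch.Tensor:
        # Inverse of to_normalized_latent: divide out the scale, then re-apply the shift.
        shift = vae.config.shift_factor if getattr(vae.config, 'shift_factor', None) is not None else 0
        return latent / vae.config.scaling_factor + shift

    # Sanity check: from_normalized_latent(vae, to_normalized_latent(vae, z)) ≈ z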
@@ -702,6 +708,10 @@ class TrainVAEProcess(BaseTrainProcess):
         mu, logvar = dgd.mean, dgd.logvar
         latents = dgd.sample()
+        # scale/shift latents to match the vae config
+        shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
+        latents = self.vae.config['scaling_factor'] * (latents - shift)
         if target_latent is not None:
             # forward_latents = target_latent.detach()
             lat_mse_loss = self.get_mse_loss(target_latent.float(), latents.float())
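
Scaling the sampled latents here matters because the cached target latents already live in the normalized space; without the transform, the latent MSE would compare tensors on different scales. A hedged sketch of that step, with plain torch standing in for the trainer's get_mse_loss helper:

    import torch
    import torch.nn.functional as F

    def sampled_latent_mse(mu, logvar, target_latent, scaling_factor, shift=0.0):
        # Reparameterized sample from the encoder posterior, as dgd.sample() does.
        latents = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        # Map into the normalized space before comparing against the cached target.
        latents = scaling_factor * (latents - shift)
        return F.mse_loss(latents.float(), target_latent.float())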
@@ -776,6 +786,10 @@ class TrainVAEProcess(BaseTrainProcess):
         if not self.train_encoder:
             # detach latents if not training encoder
             forward_latents = forward_latents.detach()
+        # undo the scale/shift so the latents match what vae.decode expects
+        shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
+        forward_latents = forward_latents / self.vae.config['scaling_factor'] + shift
         pred = self.vae.decode(forward_latents).sample
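
Putting the decode-side pieces together: whether or not the encoder is being trained, the training latents are un-normalized immediately before being handed to vae.decode. A sketch under the same assumptions as the earlier helpers (the function name is illustrative):

    def decode_training_latents(vae, forward_latents, train_encoder: bool):
        if not train_encoder:
            # keep gradients from flowing back into the encoder
            forward_latents = forward_latents.detach()
        # undo the scale/shift so decode sees raw-space latents
        shift = vae.config.shift_factor if getattr(vae.config, 'shift_factor', None) is not None else 0
        forward_latents = forward_latents / vae.config.scaling_factor + shift
        return vae.decode(forward_latents).sample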