mirror of https://github.com/ostris/ai-toolkit.git
Use scale shift in vae latent space for vae trainer
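The hunks below apply one pair of transforms in four places: latents leaving the encoder are normalized with the VAE config's scaling_factor and shift_factor, and latents entering the decoder have that normalization undone. A minimal standalone sketch of both directions, assuming a plain dict config with those two keys (the key names come from the diff; the helper names are illustrative, not part of the commit):

    import torch

    def to_scaled_latent(latent: torch.Tensor, config: dict) -> torch.Tensor:
        # encoder side: subtract the shift (if configured), then scale
        shift = config['shift_factor'] if config['shift_factor'] is not None else 0
        return config['scaling_factor'] * (latent - shift)

    def from_scaled_latent(latent: torch.Tensor, config: dict) -> torch.Tensor:
        # decoder side: undo the scale, then add the shift back
        shift = config['shift_factor'] if config['shift_factor'] is not None else 0
        return latent / config['scaling_factor'] + shift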
@@ -424,6 +424,9 @@ class TrainVAEProcess(BaseTrainProcess):
         target_latent = target_latent.to(self.device, dtype=self.torch_dtype)
         latent = self.vae.encode(img, return_dict=False)[0]
 
+        shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
+        latent = self.vae.config['scaling_factor'] * (latent - shift)
+
         latent_img = latent.clone()
         bs, ch, h, w = latent_img.shape
         grid_size = math.ceil(math.sqrt(ch))
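A hedged sketch of what this encode-side hunk does when building preview images, with a dummy tensor standing in for self.vae.encode(img, return_dict=False)[0] and illustrative config values (not taken from any particular model):

    import math
    import torch

    config = {'scaling_factor': 1.5, 'shift_factor': 0.06}  # illustrative values only
    latent = torch.randn(1, 16, 32, 32)                      # stand-in for an encoded latent

    shift = config['shift_factor'] if config['shift_factor'] is not None else 0
    latent = config['scaling_factor'] * (latent - shift)

    # the surrounding code then tiles the channels into a square grid for preview images
    bs, ch, h, w = latent.shape
    grid_size = math.ceil(math.sqrt(ch))
    print(grid_size)  # 4 for a 16-channel latent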
@@ -456,6 +459,9 @@ class TrainVAEProcess(BaseTrainProcess):
 
         if target_latent is not None:
             latent = target_latent.to(latent.device, dtype=latent.dtype)
 
+        shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
+        latent = latent / self.vae.config['scaling_factor'] + shift
+
         decoded = self.vae.decode(latent).sample
         decoded = (decoded / 2 + 0.5).clamp(0, 1)
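The decode-side counterpart undoes the same normalization before the latent reaches the decoder; the [-1, 1] to [0, 1] mapping afterwards is unchanged. A sketch with the same illustrative config, using a dummy decoder output in place of self.vae.decode(latent).sample:

    import torch

    config = {'scaling_factor': 1.5, 'shift_factor': 0.06}   # illustrative values only
    latent = torch.randn(1, 16, 32, 32)                       # latent in scaled space

    shift = config['shift_factor'] if config['shift_factor'] is not None else 0
    latent = latent / config['scaling_factor'] + shift        # back to raw VAE latent space

    decoded = torch.tanh(torch.randn(1, 3, 256, 256))         # stand-in for vae.decode(latent).sample
    decoded = (decoded / 2 + 0.5).clamp(0, 1)                 # map [-1, 1] output to [0, 1]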
@@ -702,6 +708,10 @@ class TrainVAEProcess(BaseTrainProcess):
         mu, logvar = dgd.mean, dgd.logvar
         latents = dgd.sample()
 
+        # scale shift latent to config
+        shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
+        latents = self.vae.config['scaling_factor'] * (latents - shift)
+
         if target_latent is not None:
             # forward_latents = target_latent.detach()
             lat_mse_loss = self.get_mse_loss(target_latent.float(), latents.float())
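In the training step, the freshly sampled latents are moved into the scaled space before the latent MSE against target_latent, which the diff suggests is already kept in that space; comparing latents across spaces would make the loss meaningless. A tiny illustration under those assumptions, with F.mse_loss standing in for self.get_mse_loss:

    import torch
    import torch.nn.functional as F

    config = {'scaling_factor': 1.5, 'shift_factor': 0.06}   # illustrative values only
    target_latent = torch.randn(2, 16, 32, 32)                # assumed to already be in scaled space
    latents = torch.randn(2, 16, 32, 32)                      # raw sample from the posterior

    shift = config['shift_factor'] if config['shift_factor'] is not None else 0
    latents = config['scaling_factor'] * (latents - shift)

    lat_mse_loss = F.mse_loss(target_latent.float(), latents.float())  # stand-in for self.get_mse_loss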
@@ -776,6 +786,10 @@ class TrainVAEProcess(BaseTrainProcess):
         if not self.train_encoder:
             # detach latents if not training encoder
             forward_latents = forward_latents.detach()
 
+        # shift latents to match vae config
+        shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
+        forward_latents = forward_latents / self.vae.config['scaling_factor'] + shift
+
         pred = self.vae.decode(forward_latents).sample
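Because the training path scales latents on the way out of the encoder and unscales them on the way into the decoder, the two transforms should cancel exactly. A quick round-trip check under the same illustrative assumptions:

    import torch

    config = {'scaling_factor': 1.5, 'shift_factor': 0.06}  # illustrative values only
    forward_latents = torch.randn(2, 16, 32, 32)

    shift = config['shift_factor'] if config['shift_factor'] is not None else 0
    scaled = config['scaling_factor'] * (forward_latents - shift)
    restored = scaled / config['scaling_factor'] + shift

    assert torch.allclose(restored, forward_latents, atol=1e-6)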