Update VAE trainer to handle fixed latent target. Also minor bug fixes and improvements

This commit is contained in:
Jaret Burkett
2025-07-12 16:55:15 -06:00
parent 7ab44ae0cd
commit 2e84b3d5b1
2 changed files with 66 additions and 12 deletions

View File

@@ -59,12 +59,15 @@ class TrainVAEProcess(BaseTrainProcess):
super().__init__(process_id, job, config)
self.data_loader = None
self.vae = None
self.target_latent_vae = None
self.device = self.get_conf('device', self.job.device)
self.vae_path = self.get_conf('vae_path', None)
self.target_latent_vae_path = self.get_conf('target_latent_vae_path', None)
self.eq_vae = self.get_conf('eq_vae', False)
self.datasets_objects = self.get_conf('datasets', required=True)
self.batch_size = self.get_conf('batch_size', 1, as_type=int)
self.resolution = self.get_conf('resolution', 256, as_type=int)
self.sample_resolution = self.get_conf('sample_resolution', self.resolution, as_type=int)
self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float)
self.sample_every = self.get_conf('sample_every', None)
self.optimizer_type = self.get_conf('optimizer', 'adam')
@@ -78,6 +81,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.content_weight = self.get_conf('content_weight', 0, as_type=float)
self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
self.mae_weight = self.get_conf('mae_weight', 0, as_type=float)
self.mv_loss_weight = self.get_conf('mv_loss_weight', 0, as_type=float)
self.tv_weight = self.get_conf('tv_weight', 0, as_type=float)
self.ltv_weight = self.get_conf('ltv_weight', 0, as_type=float)
@@ -106,6 +110,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.lpips_loss:lpips.LPIPS = None
self.vae_scale_factor = 8
self.target_vae_scale_factor = 8
self.step_num = 0
self.epoch_num = 0
@@ -244,6 +249,14 @@ class TrainVAEProcess(BaseTrainProcess):
return loss
else:
return torch.tensor(0.0, device=self.device)
def get_mae_loss(self, pred, target):
if self.mae_weight > 0:
loss_fn = nn.L1Loss()
loss = loss_fn(pred, target)
return loss
else:
return torch.tensor(0.0, device=self.device)
def get_kld_loss(self, mu, log_var):
if self.kld_weight > 0:
@@ -389,7 +402,7 @@ class TrainVAEProcess(BaseTrainProcess):
min_dim = min(img.width, img.height)
img = img.crop((0, 0, min_dim, min_dim))
# resize
img = img.resize((self.resolution, self.resolution))
img = img.resize((self.sample_resolution, self.sample_resolution))
input_img = img
img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
@@ -420,7 +433,7 @@ class TrainVAEProcess(BaseTrainProcess):
latent_img = (latent_img / 2 + 0.5).clamp(0, 1)
# resize to 256x256
latent_img = torch.nn.functional.interpolate(latent_img, size=(self.resolution, self.resolution), mode='nearest')
latent_img = torch.nn.functional.interpolate(latent_img, size=(self.sample_resolution, self.sample_resolution), mode='nearest')
latent_img = latent_img.squeeze(0).cpu().permute(1, 2, 0).float().numpy()
latent_img = (latent_img * 255).astype(np.uint8)
# convert to pillow image
@@ -435,17 +448,19 @@ class TrainVAEProcess(BaseTrainProcess):
decoded = Image.fromarray((decoded * 255).astype(np.uint8))
# stack input image and decoded image
input_img = input_img.resize((self.resolution, self.resolution))
decoded = decoded.resize((self.resolution, self.resolution))
input_img = input_img.resize((self.sample_resolution, self.sample_resolution))
decoded = decoded.resize((self.sample_resolution, self.sample_resolution))
output_img = Image.new('RGB', (self.resolution * 3, self.resolution))
output_img = Image.new('RGB', (self.sample_resolution * 3, self.sample_resolution))
output_img.paste(input_img, (0, 0))
output_img.paste(decoded, (self.resolution, 0))
output_img.paste(latent_img, (self.resolution * 2, 0))
output_img.paste(decoded, (self.sample_resolution, 0))
output_img.paste(latent_img, (self.sample_resolution * 2, 0))
scale_up = 2
if output_img.height <= 300:
scale_up = 4
if output_img.height >= 1000:
scale_up = 1
# scale up using nearest neighbor
output_img = output_img.resize((output_img.width * scale_up, output_img.height * scale_up), Image.NEAREST)
@@ -492,6 +507,16 @@ class TrainVAEProcess(BaseTrainProcess):
self.vae.eval()
self.vae.decoder.train()
self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)
if self.target_latent_vae_path is not None:
self.print(f"Loading target latent VAE from {self.target_latent_vae_path}")
self.target_latent_vae = AutoencoderKL.from_pretrained(self.target_latent_vae_path)
self.target_latent_vae.to(self.device, dtype=self.torch_dtype)
self.target_latent_vae.eval()
self.target_vae_scale_factor = 2 ** (len(self.target_latent_vae.config['block_out_channels']) - 1)
else:
self.target_latent_vae = None
self.target_vae_scale_factor = self.vae_scale_factor
def run(self):
super().run()
@@ -571,7 +596,7 @@ class TrainVAEProcess(BaseTrainProcess):
optimizer,
total_iters=num_steps,
factor=1,
verbose=False
# verbose=False
)
# setup tqdm progress bar
@@ -589,6 +614,8 @@ class TrainVAEProcess(BaseTrainProcess):
"style": [],
"content": [],
"mse": [],
"mae": [],
"lat_mse": [],
"mvl": [],
"ltv": [],
"lpm": [],
@@ -630,6 +657,16 @@ class TrainVAEProcess(BaseTrainProcess):
if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
batch = Resize((batch.shape[2] // self.vae_scale_factor * self.vae_scale_factor,
batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)
target_latent = None
lat_mse_loss = torch.tensor(0.0, device=self.device)
if self.target_latent_vae is not None:
target_input_scale = self.target_vae_scale_factor / self.vae_scale_factor
target_input_size = (int(batch.shape[2] * target_input_scale), int(batch.shape[3] * target_input_scale))
# resize to target input size
target_input_batch = Resize(target_input_size)(batch)
target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach()
# forward pass
# grad only if eq_vae
@@ -638,7 +675,13 @@ class TrainVAEProcess(BaseTrainProcess):
mu, logvar = dgd.mean, dgd.logvar
latents = dgd.sample()
if self.eq_vae:
if target_latent is not None:
# forward_latents = target_latent.detach()
lat_mse_loss = self.get_mse_loss(target_latent, latents)
latents = target_latent.detach()
forward_latents = target_latent.detach()
elif self.eq_vae:
# process flips, rotate, scale
latent_chunks = list(latents.chunk(latents.shape[0], dim=0))
batch_chunks = list(batch.chunk(batch.shape[0], dim=0))
@@ -698,7 +741,7 @@ class TrainVAEProcess(BaseTrainProcess):
batch = torch.cat(batch_chunks, dim=0)
else:
latents.detach().requires_grad_(True)
# latents.detach().requires_grad_(True)
forward_latents = latents
forward_latents = forward_latents.to(self.device, dtype=self.torch_dtype)
@@ -725,6 +768,7 @@ class TrainVAEProcess(BaseTrainProcess):
content_loss = self.get_content_loss() * self.content_weight
kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
mae_loss = self.get_mae_loss(pred, batch) * self.mae_weight
if self.lpips_weight > 0:
lpips_loss = self.lpips_loss(
pred.clamp(-1, 1),
@@ -765,7 +809,7 @@ class TrainVAEProcess(BaseTrainProcess):
else:
lpm_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + mae_loss + lat_mse_loss
# check if loss is NaN or Inf
if torch.isnan(loss) or torch.isinf(loss):
@@ -774,6 +818,8 @@ class TrainVAEProcess(BaseTrainProcess):
self.print(f" - Content loss: {content_loss.item()}")
self.print(f" - KLD loss: {kld_loss.item()}")
self.print(f" - MSE loss: {mse_loss.item()}")
self.print(f" - MAE loss: {mae_loss.item()}")
self.print(f" - Latent MSE loss: {lat_mse_loss.item()}")
self.print(f" - LPIPS loss: {lpips_loss.item()}")
self.print(f" - TV loss: {tv_loss.item()}")
self.print(f" - Pattern loss: {pattern_loss.item()}")
@@ -806,6 +852,10 @@ class TrainVAEProcess(BaseTrainProcess):
loss_string += f" kld: {kld_loss.item():.2e}"
if self.mse_weight > 0:
loss_string += f" mse: {mse_loss.item():.2e}"
if self.mae_weight > 0:
loss_string += f" mae: {mae_loss.item():.2e}"
if self.target_latent_vae:
loss_string += f" lat_mse: {lat_mse_loss.item():.2e}"
if self.tv_weight > 0:
loss_string += f" tv: {tv_loss.item():.2e}"
if self.pattern_weight > 0:
@@ -847,6 +897,8 @@ class TrainVAEProcess(BaseTrainProcess):
epoch_losses["style"].append(style_loss.item())
epoch_losses["content"].append(content_loss.item())
epoch_losses["mse"].append(mse_loss.item())
epoch_losses["mae"].append(mae_loss.item())
epoch_losses["lat_mse"].append(lat_mse_loss.item())
epoch_losses["kl"].append(kld_loss.item())
epoch_losses["tv"].append(tv_loss.item())
epoch_losses["ptn"].append(pattern_loss.item())
@@ -861,6 +913,8 @@ class TrainVAEProcess(BaseTrainProcess):
log_losses["style"].append(style_loss.item())
log_losses["content"].append(content_loss.item())
log_losses["mse"].append(mse_loss.item())
log_losses["mae"].append(mae_loss.item())
log_losses["lat_mse"].append(lat_mse_loss.item())
log_losses["kl"].append(kld_loss.item())
log_losses["tv"].append(tv_loss.item())
log_losses["ptn"].append(pattern_loss.item())

View File

@@ -152,7 +152,7 @@ class Critic:
self.optimizer,
total_iters=self.process.max_steps * self.num_critic_per_gen,
factor=1,
verbose=False,
# verbose=False,
)
def load_weights(self):