Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Update VAE trainer to handle a fixed latent target; also minor bug fixes and improvements
@@ -59,12 +59,15 @@ class TrainVAEProcess(BaseTrainProcess):
         super().__init__(process_id, job, config)
         self.data_loader = None
         self.vae = None
+        self.target_latent_vae = None
         self.device = self.get_conf('device', self.job.device)
         self.vae_path = self.get_conf('vae_path', None)
+        self.target_latent_vae_path = self.get_conf('target_latent_vae_path', None)
         self.eq_vae = self.get_conf('eq_vae', False)
         self.datasets_objects = self.get_conf('datasets', required=True)
         self.batch_size = self.get_conf('batch_size', 1, as_type=int)
         self.resolution = self.get_conf('resolution', 256, as_type=int)
+        self.sample_resolution = self.get_conf('sample_resolution', self.resolution, as_type=int)
         self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float)
         self.sample_every = self.get_conf('sample_every', None)
         self.optimizer_type = self.get_conf('optimizer', 'adam')
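The three new reads above add the config keys for this feature: `target_latent_vae_path` selects a second, frozen VAE whose latents become the training target, and `sample_resolution` decouples the sample-sheet size from the training resolution. A hypothetical sketch of how they might appear in a process config (only the key names come from the `get_conf()` calls; paths and values are illustrative, not repo defaults):

```python
# Hypothetical process-config sketch for the new keys read above.
process_config = {
    'vae_path': '/path/to/vae',                       # VAE being trained
    'target_latent_vae_path': '/path/to/target/vae',  # new; default None disables the feature
    'resolution': 256,                                 # training resolution
    'sample_resolution': 512,                          # new; defaults to `resolution`
}
```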
@@ -78,6 +81,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.content_weight = self.get_conf('content_weight', 0, as_type=float)
         self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
+        self.mae_weight = self.get_conf('mae_weight', 0, as_type=float)
         self.mv_loss_weight = self.get_conf('mv_loss_weight', 0, as_type=float)
         self.tv_weight = self.get_conf('tv_weight', 0, as_type=float)
         self.ltv_weight = self.get_conf('ltv_weight', 0, as_type=float)
@@ -106,6 +110,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.lpips_loss:lpips.LPIPS = None

         self.vae_scale_factor = 8
+        self.target_vae_scale_factor = 8

         self.step_num = 0
         self.epoch_num = 0
@@ -244,6 +249,14 @@ class TrainVAEProcess(BaseTrainProcess):
             return loss
         else:
             return torch.tensor(0.0, device=self.device)

+    def get_mae_loss(self, pred, target):
+        if self.mae_weight > 0:
+            loss_fn = nn.L1Loss()
+            loss = loss_fn(pred, target)
+            return loss
+        else:
+            return torch.tensor(0.0, device=self.device)
+
     def get_kld_loss(self, mu, log_var):
         if self.kld_weight > 0:
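The new `get_mae_loss` mirrors `get_mse_loss`: an `nn.L1Loss` when `mae_weight` is set, otherwise a zero tensor on `self.device` so the later loss sum stays valid without branching. A standalone sanity check of both paths (minimal sketch, outside the trainer class):

```python
import torch
import torch.nn as nn

pred = torch.randn(2, 3, 8, 8)
target = torch.randn(2, 3, 8, 8)

# Active path: L1Loss is the mean absolute error.
mae = nn.L1Loss()(pred, target)
assert torch.isclose(mae, (pred - target).abs().mean())

# Disabled path (mae_weight == 0): a zero tensor, so
# `loss = ... + mae_loss` still works and contributes nothing.
mae_off = torch.tensor(0.0)
print(mae.item(), mae_off.item())
```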
@@ -389,7 +402,7 @@ class TrainVAEProcess(BaseTrainProcess):
             min_dim = min(img.width, img.height)
             img = img.crop((0, 0, min_dim, min_dim))
             # resize
-            img = img.resize((self.resolution, self.resolution))
+            img = img.resize((self.sample_resolution, self.sample_resolution))

             input_img = img
             img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
@@ -420,7 +433,7 @@ class TrainVAEProcess(BaseTrainProcess):
             latent_img = (latent_img / 2 + 0.5).clamp(0, 1)

             # resize to 256x256
-            latent_img = torch.nn.functional.interpolate(latent_img, size=(self.resolution, self.resolution), mode='nearest')
+            latent_img = torch.nn.functional.interpolate(latent_img, size=(self.sample_resolution, self.sample_resolution), mode='nearest')
             latent_img = latent_img.squeeze(0).cpu().permute(1, 2, 0).float().numpy()
             latent_img = (latent_img * 255).astype(np.uint8)
             # convert to pillow image
@@ -435,17 +448,19 @@ class TrainVAEProcess(BaseTrainProcess):
             decoded = Image.fromarray((decoded * 255).astype(np.uint8))

             # stack input image and decoded image
-            input_img = input_img.resize((self.resolution, self.resolution))
-            decoded = decoded.resize((self.resolution, self.resolution))
+            input_img = input_img.resize((self.sample_resolution, self.sample_resolution))
+            decoded = decoded.resize((self.sample_resolution, self.sample_resolution))

-            output_img = Image.new('RGB', (self.resolution * 3, self.resolution))
+            output_img = Image.new('RGB', (self.sample_resolution * 3, self.sample_resolution))
             output_img.paste(input_img, (0, 0))
-            output_img.paste(decoded, (self.resolution, 0))
-            output_img.paste(latent_img, (self.resolution * 2, 0))
+            output_img.paste(decoded, (self.sample_resolution, 0))
+            output_img.paste(latent_img, (self.sample_resolution * 2, 0))

             scale_up = 2
             if output_img.height <= 300:
                 scale_up = 4
+            if output_img.height >= 1000:
+                scale_up = 1

             # scale up using nearest neighbor
             output_img = output_img.resize((output_img.width * scale_up, output_img.height * scale_up), Image.NEAREST)
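The sample sheet is three square panels of `sample_resolution` pixels laid side by side (input | reconstruction | latent preview), then upscaled with nearest-neighbor so small previews stay legible. The `scale_up` thresholds above reduce to a small lookup; a quick check:

```python
def pick_scale_up(height: int) -> int:
    # Mirrors the threshold logic in the hunk above.
    scale_up = 2
    if height <= 300:
        scale_up = 4
    if height >= 1000:
        scale_up = 1
    return scale_up

# 256px panels get 4x, 512px get 2x, 1024px stay at native size.
assert [pick_scale_up(h) for h in (256, 512, 1024)] == [4, 2, 1]
```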
@@ -492,6 +507,16 @@ class TrainVAEProcess(BaseTrainProcess):
             self.vae.eval()
             self.vae.decoder.train()
         self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)
+
+        if self.target_latent_vae_path is not None:
+            self.print(f"Loading target latent VAE from {self.target_latent_vae_path}")
+            self.target_latent_vae = AutoencoderKL.from_pretrained(self.target_latent_vae_path)
+            self.target_latent_vae.to(self.device, dtype=self.torch_dtype)
+            self.target_latent_vae.eval()
+            self.target_vae_scale_factor = 2 ** (len(self.target_latent_vae.config['block_out_channels']) - 1)
+        else:
+            self.target_latent_vae = None
+            self.target_vae_scale_factor = self.vae_scale_factor

     def run(self):
         super().run()
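Both scale factors come straight from the VAE config: each downsampling block halves the spatial resolution, so `n` entries in `block_out_channels` give `2**(n-1)` total downsampling. Worked example with the standard SD-style KL-VAE config (shown for illustration):

```python
# Standard SD-style AutoencoderKL config has four block levels -> 8x downsampling.
block_out_channels = [128, 256, 512, 512]
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8  # a 256px image encodes to a 32x32 latent grid
```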
@@ -571,7 +596,7 @@ class TrainVAEProcess(BaseTrainProcess):
             optimizer,
             total_iters=num_steps,
             factor=1,
-            verbose=False
+            # verbose=False
         )

         # setup tqdm progress bar
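Commenting out `verbose` silences the constructor warning on recent PyTorch, where the scheduler `verbose` kwarg is deprecated. Judging by the `factor`/`total_iters` kwargs this looks like a constant-factor scheduler; a minimal sketch using torch's `ConstantLR` (the class name is an assumption, since the hunk doesn't show it):

```python
import torch
from torch.optim.lr_scheduler import ConstantLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-6)

# factor=1 multiplies the LR by 1 for total_iters steps, i.e. constant LR;
# no `verbose` argument needed.
scheduler = ConstantLR(optimizer, factor=1, total_iters=1000)
optimizer.step()
scheduler.step()
```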
@@ -589,6 +614,8 @@ class TrainVAEProcess(BaseTrainProcess):
             "style": [],
             "content": [],
             "mse": [],
+            "mae": [],
+            "lat_mse": [],
             "mvl": [],
             "ltv": [],
             "lpm": [],
@@ -630,6 +657,16 @@ class TrainVAEProcess(BaseTrainProcess):
            if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
                batch = Resize((batch.shape[2] // self.vae_scale_factor * self.vae_scale_factor,
                                batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)
+
+           target_latent = None
+           lat_mse_loss = torch.tensor(0.0, device=self.device)
+           if self.target_latent_vae is not None:
+               target_input_scale = self.target_vae_scale_factor / self.vae_scale_factor
+               target_input_size = (int(batch.shape[2] * target_input_scale), int(batch.shape[3] * target_input_scale))
+               # resize to target input size
+               target_input_batch = Resize(target_input_size)(batch)
+               target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach()
+

            # forward pass
            # grad only if eq_vae
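The rescale keeps the two latent grids the same spatial size: pixels going into the target VAE are scaled by `target_vae_scale_factor / vae_scale_factor`, so after each model's own downsampling both latents match and the latent MSE is well defined. Worked example, assuming a 16x target VAE paired with the usual 8x VAE:

```python
batch_h = batch_w = 256
vae_scale_factor = 8          # VAE being trained
target_vae_scale_factor = 16  # assumed deeper target VAE

target_input_scale = target_vae_scale_factor / vae_scale_factor       # 2.0
target_input_size = (int(batch_h * target_input_scale),
                     int(batch_w * target_input_scale))               # (512, 512)

# Both encoders now emit a 32x32 latent grid.
assert batch_h // vae_scale_factor == target_input_size[0] // target_vae_scale_factor == 32
```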
@@ -638,7 +675,13 @@ class TrainVAEProcess(BaseTrainProcess):
            mu, logvar = dgd.mean, dgd.logvar
            latents = dgd.sample()

-           if self.eq_vae:
+           if target_latent is not None:
+               # forward_latents = target_latent.detach()
+               lat_mse_loss = self.get_mse_loss(target_latent, latents)
+               latents = target_latent.detach()
+               forward_latents = target_latent.detach()
+
+           elif self.eq_vae:
                # process flips, rotate, scale
                latent_chunks = list(latents.chunk(latents.shape[0], dim=0))
                batch_chunks = list(batch.chunk(batch.shape[0], dim=0))
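This is the core of the commit: when a target latent exists, the encoder's own sample is used only to compute `lat_mse_loss` against the frozen target, and the decoder is fed the target VAE's latents, so the decoder learns to decode the fixed latent space while the latent MSE pulls the trainable encoder toward it. A self-contained restatement (`pick_forward_latents` is a hypothetical helper, not in the repo; the `eq_vae` path is omitted):

```python
import torch

def pick_forward_latents(latents, target_latent):
    """Return (forward_latents, lat_mse_loss) as in the branch above."""
    if target_latent is not None:
        # Encoder objective: match the frozen target VAE's latents.
        lat_mse_loss = torch.nn.MSELoss()(target_latent, latents)
        # Decoder objective: reconstruct from the *target* latents.
        return target_latent.detach(), lat_mse_loss
    return latents, torch.tensor(0.0)

# Toy check: encoder gradients flow only through lat_mse_loss.
z = torch.randn(1, 4, 32, 32, requires_grad=True)
z_target = torch.randn(1, 4, 32, 32)
fwd, lat_mse = pick_forward_latents(z, z_target)
assert not fwd.requires_grad and lat_mse.requires_grad
```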
@@ -698,7 +741,7 @@ class TrainVAEProcess(BaseTrainProcess):
                batch = torch.cat(batch_chunks, dim=0)

            else:
-               latents.detach().requires_grad_(True)
+               # latents.detach().requires_grad_(True)
                forward_latents = latents

            forward_latents = forward_latents.to(self.device, dtype=self.torch_dtype)
@@ -725,6 +768,7 @@ class TrainVAEProcess(BaseTrainProcess):
            content_loss = self.get_content_loss() * self.content_weight
            kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
            mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
+           mae_loss = self.get_mae_loss(pred, batch) * self.mae_weight
            if self.lpips_weight > 0:
                lpips_loss = self.lpips_loss(
                    pred.clamp(-1, 1),
@@ -765,7 +809,7 @@ class TrainVAEProcess(BaseTrainProcess):
            else:
                lpm_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

-           loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss
+           loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + mae_loss + lat_mse_loss

            # check if loss is NaN or Inf
            if torch.isnan(loss) or torch.isinf(loss):
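The summed objective gains the two new terms. Each visible component is already scaled by its weight where it is computed (`mse_loss` by `mse_weight`, `mae_loss` by `mae_weight`, and so on), while `lat_mse_loss` enters unscaled and is simply zero when no target latent VAE is configured. In math form, with x̂ the reconstruction, x the input, z the encoder sample, and z_t the target latent:

$$
\mathcal{L}_{\text{total}}
= \underbrace{\lambda_{\text{mae}}\,\mathrm{MAE}(\hat{x}, x)}_{\text{new: mae\_loss}}
+ \underbrace{\mathrm{MSE}(z_t, z)}_{\text{new: lat\_mse\_loss}}
+ \text{(existing weighted terms)}
$$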
@@ -774,6 +818,8 @@ class TrainVAEProcess(BaseTrainProcess):
                self.print(f" - Content loss: {content_loss.item()}")
                self.print(f" - KLD loss: {kld_loss.item()}")
                self.print(f" - MSE loss: {mse_loss.item()}")
+               self.print(f" - MAE loss: {mae_loss.item()}")
+               self.print(f" - Latent MSE loss: {lat_mse_loss.item()}")
                self.print(f" - LPIPS loss: {lpips_loss.item()}")
                self.print(f" - TV loss: {tv_loss.item()}")
                self.print(f" - Pattern loss: {pattern_loss.item()}")
@@ -806,6 +852,10 @@ class TrainVAEProcess(BaseTrainProcess):
                loss_string += f" kld: {kld_loss.item():.2e}"
            if self.mse_weight > 0:
                loss_string += f" mse: {mse_loss.item():.2e}"
+           if self.mae_weight > 0:
+               loss_string += f" mae: {mae_loss.item():.2e}"
+           if self.target_latent_vae:
+               loss_string += f" lat_mse: {lat_mse_loss.item():.2e}"
            if self.tv_weight > 0:
                loss_string += f" tv: {tv_loss.item():.2e}"
            if self.pattern_weight > 0:
@@ -847,6 +897,8 @@ class TrainVAEProcess(BaseTrainProcess):
                epoch_losses["style"].append(style_loss.item())
                epoch_losses["content"].append(content_loss.item())
                epoch_losses["mse"].append(mse_loss.item())
+               epoch_losses["mae"].append(mae_loss.item())
+               epoch_losses["lat_mse"].append(lat_mse_loss.item())
                epoch_losses["kl"].append(kld_loss.item())
                epoch_losses["tv"].append(tv_loss.item())
                epoch_losses["ptn"].append(pattern_loss.item())
@@ -861,6 +913,8 @@ class TrainVAEProcess(BaseTrainProcess):
                log_losses["style"].append(style_loss.item())
                log_losses["content"].append(content_loss.item())
                log_losses["mse"].append(mse_loss.item())
+               log_losses["mae"].append(mae_loss.item())
+               log_losses["lat_mse"].append(lat_mse_loss.item())
                log_losses["kl"].append(kld_loss.item())
                log_losses["tv"].append(tv_loss.item())
                log_losses["ptn"].append(pattern_loss.item())
@@ -152,7 +152,7 @@ class Critic:
            self.optimizer,
            total_iters=self.process.max_steps * self.num_critic_per_gen,
            factor=1,
-           verbose=False,
+           # verbose=False,
        )

    def load_weights(self):