https://github.com/ostris/ai-toolkit.git
Update VAE trainer to handle fixed latent target. Also minor bug fixes and improvements
@@ -59,12 +59,15 @@ class TrainVAEProcess(BaseTrainProcess):
         super().__init__(process_id, job, config)
         self.data_loader = None
         self.vae = None
+        self.target_latent_vae = None
         self.device = self.get_conf('device', self.job.device)
         self.vae_path = self.get_conf('vae_path', None)
+        self.target_latent_vae_path = self.get_conf('target_latent_vae_path', None)
         self.eq_vae = self.get_conf('eq_vae', False)
         self.datasets_objects = self.get_conf('datasets', required=True)
         self.batch_size = self.get_conf('batch_size', 1, as_type=int)
         self.resolution = self.get_conf('resolution', 256, as_type=int)
+        self.sample_resolution = self.get_conf('sample_resolution', self.resolution, as_type=int)
         self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float)
         self.sample_every = self.get_conf('sample_every', None)
         self.optimizer_type = self.get_conf('optimizer', 'adam')
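
The three new get_conf reads above correspond to new config keys. A minimal sketch of the relevant config fragment, written as a Python dict: only the key names come from this diff, while the nesting, paths, and example values are assumptions.

# Hypothetical config fragment for the VAE trainer; key names are from the
# diff above, paths and values are illustrative only.
process_config = {
    'vae_path': '/path/to/vae_being_trained',
    'target_latent_vae_path': '/path/to/frozen_target_vae',  # new: fixed latent target
    'resolution': 256,
    'sample_resolution': 512,  # new: sampling-only resolution, defaults to 'resolution'
    'mse_weight': 1.0,
    'mae_weight': 0.1,         # new: L1 reconstruction term, defaults to 0
}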
@@ -78,6 +81,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.content_weight = self.get_conf('content_weight', 0, as_type=float)
         self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
+        self.mae_weight = self.get_conf('mae_weight', 0, as_type=float)
         self.mv_loss_weight = self.get_conf('mv_loss_weight', 0, as_type=float)
         self.tv_weight = self.get_conf('tv_weight', 0, as_type=float)
         self.ltv_weight = self.get_conf('ltv_weight', 0, as_type=float)
@@ -106,6 +110,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.lpips_loss:lpips.LPIPS = None

         self.vae_scale_factor = 8
+        self.target_vae_scale_factor = 8

         self.step_num = 0
         self.epoch_num = 0
@@ -244,6 +249,14 @@ class TrainVAEProcess(BaseTrainProcess):
             return loss
         else:
             return torch.tensor(0.0, device=self.device)

+    def get_mae_loss(self, pred, target):
+        if self.mae_weight > 0:
+            loss_fn = nn.L1Loss()
+            loss = loss_fn(pred, target)
+            return loss
+        else:
+            return torch.tensor(0.0, device=self.device)
+
     def get_kld_loss(self, mu, log_var):
         if self.kld_weight > 0:
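
The new get_mae_loss follows the same weight-gated pattern as the trainer's other losses: when the weight is zero it returns a 0.0 tensor on the right device, so the unconditional sum of all loss terms later in the step remains valid. A self-contained sketch of the pattern (names are stand-ins, not the trainer's API):

import torch
import torch.nn as nn

# Weight-gated loss: a zero tensor when disabled keeps the loss sum a no-op.
def gated_l1(pred, target, weight, device='cpu'):
    if weight > 0:
        return nn.L1Loss()(pred, target)
    return torch.tensor(0.0, device=device)

pred = torch.randn(2, 3, 8, 8)
target = torch.randn(2, 3, 8, 8)
print(gated_l1(pred, target, weight=0.0))  # tensor(0.) -> contributes nothing
print(gated_l1(pred, target, weight=0.1))  # mean absolute error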
@@ -389,7 +402,7 @@ class TrainVAEProcess(BaseTrainProcess):
             min_dim = min(img.width, img.height)
             img = img.crop((0, 0, min_dim, min_dim))
             # resize
-            img = img.resize((self.resolution, self.resolution))
+            img = img.resize((self.sample_resolution, self.sample_resolution))

             input_img = img
             img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
@@ -420,7 +433,7 @@ class TrainVAEProcess(BaseTrainProcess):
             latent_img = (latent_img / 2 + 0.5).clamp(0, 1)

             # resize to 256x256
-            latent_img = torch.nn.functional.interpolate(latent_img, size=(self.resolution, self.resolution), mode='nearest')
+            latent_img = torch.nn.functional.interpolate(latent_img, size=(self.sample_resolution, self.sample_resolution), mode='nearest')
             latent_img = latent_img.squeeze(0).cpu().permute(1, 2, 0).float().numpy()
             latent_img = (latent_img * 255).astype(np.uint8)
             # convert to pillow image
@@ -435,17 +448,19 @@ class TrainVAEProcess(BaseTrainProcess):
             decoded = Image.fromarray((decoded * 255).astype(np.uint8))

             # stack input image and decoded image
-            input_img = input_img.resize((self.resolution, self.resolution))
-            decoded = decoded.resize((self.resolution, self.resolution))
+            input_img = input_img.resize((self.sample_resolution, self.sample_resolution))
+            decoded = decoded.resize((self.sample_resolution, self.sample_resolution))

-            output_img = Image.new('RGB', (self.resolution * 3, self.resolution))
+            output_img = Image.new('RGB', (self.sample_resolution * 3, self.sample_resolution))
             output_img.paste(input_img, (0, 0))
-            output_img.paste(decoded, (self.resolution, 0))
-            output_img.paste(latent_img, (self.resolution * 2, 0))
+            output_img.paste(decoded, (self.sample_resolution, 0))
+            output_img.paste(latent_img, (self.sample_resolution * 2, 0))

             scale_up = 2
             if output_img.height <= 300:
                 scale_up = 4
+            if output_img.height >= 1000:
+                scale_up = 1

             # scale up using nearest neighbor
             output_img = output_img.resize((output_img.width * scale_up, output_img.height * scale_up), Image.NEAREST)
@@ -492,6 +507,16 @@ class TrainVAEProcess(BaseTrainProcess):
         self.vae.eval()
         self.vae.decoder.train()
         self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)

+        if self.target_latent_vae_path is not None:
+            self.print(f"Loading target latent VAE from {self.target_latent_vae_path}")
+            self.target_latent_vae = AutoencoderKL.from_pretrained(self.target_latent_vae_path)
+            self.target_latent_vae.to(self.device, dtype=self.torch_dtype)
+            self.target_latent_vae.eval()
+            self.target_vae_scale_factor = 2 ** (len(self.target_latent_vae.config['block_out_channels']) - 1)
+        else:
+            self.target_latent_vae = None
+            self.target_vae_scale_factor = self.vae_scale_factor
+
    def run(self):
        super().run()
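
The scale-factor formula works because a diffusers AutoencoderKL downsamples by 2 in every encoder block except the last, so the total spatial reduction is 2 ** (len(block_out_channels) - 1). A worked example (the 5-block variant is hypothetical):

# Stable Diffusion's VAE has four blocks -> scale factor 8,
# i.e. a 512px image becomes a 64x64 latent.
sd_vae_blocks = [128, 256, 512, 512]
print(2 ** (len(sd_vae_blocks) - 1))        # 8

# A hypothetical deeper VAE with five blocks -> scale factor 16.
deep_vae_blocks = [128, 256, 512, 512, 512]
print(2 ** (len(deep_vae_blocks) - 1))      # 16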
@@ -571,7 +596,7 @@ class TrainVAEProcess(BaseTrainProcess):
             optimizer,
             total_iters=num_steps,
             factor=1,
-            verbose=False
+            # verbose=False
         )

         # setup tqdm progress bar
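
The total_iters/factor arguments match torch's ConstantLR scheduler (an assumption; the constructor name sits outside this hunk), and the verbose flag is deprecated in recent PyTorch releases, which is presumably why it is commented out rather than removed. A minimal sketch of the construction without the flag:

import torch
from torch.optim.lr_scheduler import ConstantLR

# ConstantLR with factor=1 keeps the learning rate unchanged for
# total_iters steps, i.e. it is effectively a no-op scheduler.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-6)
scheduler = ConstantLR(optimizer, factor=1, total_iters=1000)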
@@ -589,6 +614,8 @@ class TrainVAEProcess(BaseTrainProcess):
             "style": [],
             "content": [],
             "mse": [],
+            "mae": [],
+            "lat_mse": [],
             "mvl": [],
             "ltv": [],
             "lpm": [],
@@ -630,6 +657,16 @@ class TrainVAEProcess(BaseTrainProcess):
             if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
                 batch = Resize((batch.shape[2] // self.vae_scale_factor * self.vae_scale_factor,
                                 batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)

+            target_latent = None
+            lat_mse_loss = torch.tensor(0.0, device=self.device)
+            if self.target_latent_vae is not None:
+                target_input_scale = self.target_vae_scale_factor / self.vae_scale_factor
+                target_input_size = (int(batch.shape[2] * target_input_scale), int(batch.shape[3] * target_input_scale))
+                # resize to target input size
+                target_input_batch = Resize(target_input_size)(batch)
+                target_latent = self.target_latent_vae.encode(target_input_batch).latent_dist.sample().detach()
+
+
             # forward pass
             # grad only if eq_vae
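
The resize above exists so the two latent grids line up: when the target VAE downsamples more aggressively than the trained VAE, the pixel input is scaled up by target_vae_scale_factor / vae_scale_factor, and both encoders then produce latents of identical spatial size, which the latent MSE below requires. A worked example with illustrative scale factors:

# Assume the trained VAE has scale factor 8 and the target VAE 16.
batch_h = batch_w = 256
vae_scale_factor, target_vae_scale_factor = 8, 16
scale = target_vae_scale_factor / vae_scale_factor        # 2.0
target_in = (int(batch_h * scale), int(batch_w * scale))  # (512, 512)
# trained VAE latent:  256 / 8  = 32x32
# target VAE latent:   512 / 16 = 32x32 -> spatial shapes match for the MSE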
@@ -638,7 +675,13 @@ class TrainVAEProcess(BaseTrainProcess):
                 mu, logvar = dgd.mean, dgd.logvar
                 latents = dgd.sample()

-            if self.eq_vae:
+            if target_latent is not None:
+                # forward_latents = target_latent.detach()
+                lat_mse_loss = self.get_mse_loss(target_latent, latents)
+                latents = target_latent.detach()
+                forward_latents = target_latent.detach()
+
+            elif self.eq_vae:
                 # process flips, rotate, scale
                 latent_chunks = list(latents.chunk(latents.shape[0], dim=0))
                 batch_chunks = list(batch.chunk(batch.shape[0], dim=0))
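
This branch is the core of the fixed-latent-target feature: the encoder's latents are penalized toward the frozen target VAE's latents (lat_mse_loss), and the decoder is then fed the detached target latents, so it learns to reconstruct images from the target's latent space. A condensed, self-contained sketch of the idea, assuming both VAEs share a scale factor (the resizing from the previous hunk is omitted, and student/teacher are stand-ins for self.vae / self.target_latent_vae):

import torch.nn.functional as F

def fixed_target_step(student_vae, teacher_vae, batch):
    # Frozen teacher latents are the fixed target.
    target_latent = teacher_vae.encode(batch).latent_dist.sample().detach()
    # Pull the student encoder's latents toward the teacher's...
    latents = student_vae.encode(batch).latent_dist.sample()
    lat_mse_loss = F.mse_loss(target_latent, latents)
    # ...while the decoder learns to reconstruct from the teacher's space.
    pred = student_vae.decode(target_latent.detach()).sample
    mse_loss = F.mse_loss(pred, batch)
    return mse_loss + lat_mse_loss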
@@ -698,7 +741,7 @@ class TrainVAEProcess(BaseTrainProcess):
                 batch = torch.cat(batch_chunks, dim=0)

             else:
-                latents.detach().requires_grad_(True)
+                # latents.detach().requires_grad_(True)
                 forward_latents = latents

             forward_latents = forward_latents.to(self.device, dtype=self.torch_dtype)
@@ -725,6 +768,7 @@ class TrainVAEProcess(BaseTrainProcess):
             content_loss = self.get_content_loss() * self.content_weight
             kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
             mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
+            mae_loss = self.get_mae_loss(pred, batch) * self.mae_weight
             if self.lpips_weight > 0:
                 lpips_loss = self.lpips_loss(
                     pred.clamp(-1, 1),
@@ -765,7 +809,7 @@ class TrainVAEProcess(BaseTrainProcess):
             else:
                 lpm_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

-            loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss
+            loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + mae_loss + lat_mse_loss

             # check if loss is NaN or Inf
             if torch.isnan(loss) or torch.isinf(loss):
@@ -774,6 +818,8 @@ class TrainVAEProcess(BaseTrainProcess):
                 self.print(f" - Content loss: {content_loss.item()}")
                 self.print(f" - KLD loss: {kld_loss.item()}")
                 self.print(f" - MSE loss: {mse_loss.item()}")
+                self.print(f" - MAE loss: {mae_loss.item()}")
+                self.print(f" - Latent MSE loss: {lat_mse_loss.item()}")
                 self.print(f" - LPIPS loss: {lpips_loss.item()}")
                 self.print(f" - TV loss: {tv_loss.item()}")
                 self.print(f" - Pattern loss: {pattern_loss.item()}")
@@ -806,6 +852,10 @@ class TrainVAEProcess(BaseTrainProcess):
                 loss_string += f" kld: {kld_loss.item():.2e}"
             if self.mse_weight > 0:
                 loss_string += f" mse: {mse_loss.item():.2e}"
+            if self.mae_weight > 0:
+                loss_string += f" mae: {mae_loss.item():.2e}"
+            if self.target_latent_vae:
+                loss_string += f" lat_mse: {lat_mse_loss.item():.2e}"
             if self.tv_weight > 0:
                 loss_string += f" tv: {tv_loss.item():.2e}"
             if self.pattern_weight > 0:
@@ -847,6 +897,8 @@ class TrainVAEProcess(BaseTrainProcess):
             epoch_losses["style"].append(style_loss.item())
             epoch_losses["content"].append(content_loss.item())
             epoch_losses["mse"].append(mse_loss.item())
+            epoch_losses["mae"].append(mae_loss.item())
+            epoch_losses["lat_mse"].append(lat_mse_loss.item())
             epoch_losses["kl"].append(kld_loss.item())
             epoch_losses["tv"].append(tv_loss.item())
             epoch_losses["ptn"].append(pattern_loss.item())
@@ -861,6 +913,8 @@ class TrainVAEProcess(BaseTrainProcess):
             log_losses["style"].append(style_loss.item())
             log_losses["content"].append(content_loss.item())
             log_losses["mse"].append(mse_loss.item())
+            log_losses["mae"].append(mae_loss.item())
+            log_losses["lat_mse"].append(lat_mse_loss.item())
             log_losses["kl"].append(kld_loss.item())
             log_losses["tv"].append(tv_loss.item())
             log_losses["ptn"].append(pattern_loss.item())
@@ -152,7 +152,7 @@ class Critic:
             self.optimizer,
             total_iters=self.process.max_steps * self.num_critic_per_gen,
             factor=1,
-            verbose=False,
+            # verbose=False,
         )

    def load_weights(self):