Bug fixes and little improvements here and there.

Jaret Burkett
2024-06-08 06:24:20 -06:00
parent 833c833f28
commit 3f3636b788
12 changed files with 358 additions and 117 deletions


@@ -1,6 +1,7 @@
 import copy
 import glob
 import os
+import shutil
 import time
 from collections import OrderedDict
@@ -13,6 +14,7 @@ from torch import nn
 from torchvision.transforms import transforms
 from jobs.process import BaseTrainProcess
+from toolkit.image_utils import show_tensors
 from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
 from toolkit.data_loader import ImageDataset
 from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
@@ -25,6 +27,8 @@ from tqdm import tqdm
 import time
 import numpy as np
 from .models.vgg19_critic import Critic
+from torchvision.transforms import Resize
+import lpips

 IMAGE_TRANSFORMS = transforms.Compose(
     [
@@ -62,6 +66,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
         self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
+        self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float)
         self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
         self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
         self.optimizer_params = self.get_conf('optimizer_params', {})
@@ -71,6 +76,9 @@ class TrainVAEProcess(BaseTrainProcess):
         self.vgg_19 = None
         self.style_weight_scalers = []
         self.content_weight_scalers = []
+        self.lpips_loss:lpips.LPIPS = None
+        self.vae_scale_factor = 8
+
         self.step_num = 0
         self.epoch_num = 0
@@ -137,6 +145,15 @@ class TrainVAEProcess(BaseTrainProcess):
             num_workers=6
         )

+    def remove_oldest_checkpoint(self):
+        max_to_keep = 4
+        folders = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers"))
+        if len(folders) > max_to_keep:
+            folders.sort(key=os.path.getmtime)
+            for folder in folders[:-max_to_keep]:
+                print(f"Removing {folder}")
+                shutil.rmtree(folder)
+
     def setup_vgg19(self):
         if self.vgg_19 is None:
             self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses(
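The new remove_oldest_checkpoint above keeps only the most recent diffusers-format checkpoints, sorting candidate folders by filesystem modification time and deleting everything older than the last max_to_keep. A minimal standalone sketch of the same rotation logic (the function name and directory layout here are illustrative, not from the commit):

```python
import glob
import os
import shutil

def prune_checkpoints(save_root: str, job_name: str, max_to_keep: int = 4) -> None:
    # collect this job's diffusers-format checkpoint folders
    folders = glob.glob(os.path.join(save_root, f"{job_name}*_diffusers"))
    if len(folders) > max_to_keep:
        # sort oldest-first; everything before the last max_to_keep gets removed
        folders.sort(key=os.path.getmtime)
        for folder in folders[:-max_to_keep]:
            shutil.rmtree(folder)
```

Sorting by mtime rather than parsing step numbers out of folder names keeps the pruning robust to any filename scheme, at the cost of trusting the filesystem clock.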
@@ -211,7 +228,7 @@ class TrainVAEProcess(BaseTrainProcess):
     def get_pattern_loss(self, pred, target):
         if self._pattern_loss is None:
-            self._pattern_loss = PatternLoss(pattern_size=8, dtype=self.torch_dtype).to(self.device,
+            self._pattern_loss = PatternLoss(pattern_size=16, dtype=self.torch_dtype).to(self.device,
                                                                                          dtype=self.torch_dtype)
         loss = torch.mean(self._pattern_loss(pred, target))
         return loss
@@ -226,25 +243,21 @@ class TrainVAEProcess(BaseTrainProcess):
             step_num = f"_{str(step).zfill(9)}"

         self.update_training_metadata()
-        filename = f'{self.job.name}{step_num}.safetensors'
-        # prepare meta
-        save_meta = get_meta_for_safetensors(self.meta, self.job.name)
+        filename = f'{self.job.name}{step_num}_diffusers'
-        state_dict = convert_diffusers_back_to_ldm(self.vae)
-        for key in list(state_dict.keys()):
-            v = state_dict[key]
-            v = v.detach().clone().to("cpu").to(torch.float32)
-            state_dict[key] = v
-        # having issues with meta
-        save_file(state_dict, os.path.join(self.save_root, filename), save_meta)
+        self.vae = self.vae.to("cpu", dtype=torch.float16)
+        self.vae.save_pretrained(
+            save_directory=os.path.join(self.save_root, filename)
+        )
+        self.vae = self.vae.to(self.device, dtype=self.torch_dtype)

         self.print(f"Saved to {os.path.join(self.save_root, filename)}")

         if self.use_critic:
             self.critic.save(step)

+        self.remove_oldest_checkpoint()

     def sample(self, step=None):
         sample_folder = os.path.join(self.save_root, 'samples')
         if not os.path.exists(sample_folder):
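Checkpoints are now written as diffusers folders via save_pretrained (with the VAE briefly moved to CPU/float16 for the write, then restored) instead of a single LDM-format .safetensors file, and every save ends with a call to the new checkpoint rotation. A folder saved this way loads straight back with from_pretrained; a sketch, assuming diffusers is installed and using a hypothetical output path:

```python
import torch
from diffusers import AutoencoderKL

# hypothetical path: any "<job_name>_<step>_diffusers" folder under save_root
vae = AutoencoderKL.from_pretrained("output/my_vae_000001000_diffusers")
vae = vae.to("cuda", dtype=torch.float16)
```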
@@ -280,6 +293,13 @@ class TrainVAEProcess(BaseTrainProcess):
             output_img.paste(input_img, (0, 0))
             output_img.paste(decoded, (self.resolution, 0))

+            scale_up = 2
+            if output_img.height <= 300:
+                scale_up = 4
+
+            # scale up using nearest neighbor
+            output_img = output_img.resize((output_img.width * scale_up, output_img.height * scale_up), Image.NEAREST)

             step_num = ''
             if step is not None:
                 # zero-pad 9 digits
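Sample grids are now enlarged with nearest-neighbor resampling (4x when the grid is 300px tall or shorter, otherwise 2x) so the side-by-side input/reconstruction pair stays pixel-sharp when inspected. The same PIL call in isolation, with illustrative dimensions:

```python
from PIL import Image

img = Image.new("RGB", (512, 256))  # stand-in for the pasted sample grid
scale_up = 4 if img.height <= 300 else 2
# nearest neighbor duplicates pixels instead of smoothing them
img = img.resize((img.width * scale_up, img.height * scale_up), Image.NEAREST)
```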
@@ -294,7 +314,7 @@ class TrainVAEProcess(BaseTrainProcess):
         path_to_load = self.vae_path
         # see if we have a checkpoint in out output to resume from
         self.print(f"Looking for latest checkpoint in {self.save_root}")
-        files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*.safetensors"))
+        files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers"))
         if files and len(files) > 0:
             latest_file = max(files, key=os.path.getmtime)
             print(f" - Latest checkpoint is: {latest_file}")
@@ -306,13 +326,14 @@ class TrainVAEProcess(BaseTrainProcess):
         self.print(f"Loading VAE")
         self.print(f" - Loading VAE: {path_to_load}")
         if self.vae is None:
-            self.vae = load_vae(path_to_load, dtype=self.torch_dtype)
+            self.vae = AutoencoderKL.from_pretrained(path_to_load)

         # set decoder to train
         self.vae.to(self.device, dtype=self.torch_dtype)
         self.vae.requires_grad_(False)
         self.vae.eval()
         self.vae.decoder.train()
+        self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)

     def run(self):
         super().run()
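Loading mirrors the saving change: the kohya-style load_vae helper is replaced with diffusers' AutoencoderKL.from_pretrained, and vae_scale_factor is derived from the model config instead of staying hard-coded at 8. Each downsampling block after the first halves the spatial resolution, so the factor is 2^(len(block_out_channels) - 1); for the standard Stable Diffusion VAE config:

```python
# block_out_channels as found in the standard SD VAE config
block_out_channels = [128, 256, 512, 512]
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
print(vae_scale_factor)  # 8: a 512x512 image encodes to 64x64 latents
```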
@@ -374,6 +395,10 @@ class TrainVAEProcess(BaseTrainProcess):
         if self.use_critic:
             self.critic.setup()

+        if self.lpips_weight > 0 and self.lpips_loss is None:
+            # self.lpips_loss = lpips.LPIPS(net='vgg')
+            self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=self.torch_dtype)
+
         optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
                                   optimizer_params=self.optimizer_params)
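The LPIPS metric is constructed lazily, only when lpips_weight > 0, using the VGG backbone and moved to the training device/dtype. The lpips package expects inputs scaled to [-1, 1], which is why the training loop further down clamps both prediction and target before calling it. A minimal usage sketch:

```python
import lpips
import torch

loss_fn = lpips.LPIPS(net='vgg')  # downloads backbone weights on first use
a = torch.rand(1, 3, 64, 64) * 2 - 1  # inputs expected in [-1, 1]
b = torch.rand(1, 3, 64, 64) * 2 - 1
d = loss_fn(a, b).mean()  # per-image distances, reduced to a scalar
```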
@@ -397,6 +422,7 @@ class TrainVAEProcess(BaseTrainProcess):
             self.sample()

         blank_losses = OrderedDict({
             "total": [],
+            "lpips": [],
             "style": [],
             "content": [],
             "mse": [],
@@ -415,17 +441,29 @@ class TrainVAEProcess(BaseTrainProcess):
             for batch in self.data_loader:
                 if self.step_num >= self.max_steps:
                     break
-                with torch.no_grad():
-                    batch = batch.to(self.device, dtype=self.torch_dtype)
+                batch = batch.to(self.device, dtype=self.torch_dtype)

-                # forward pass
-                dgd = self.vae.encode(batch).latent_dist
-                mu, logvar = dgd.mean, dgd.logvar
-                latents = dgd.sample()
-                latents.requires_grad_(True)
+                # resize so it matches size of vae evenly
+                if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
+                    batch = Resize((batch.shape[2] // self.vae_scale_factor * self.vae_scale_factor,
+                                    batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)

+                # forward pass
+                dgd = self.vae.encode(batch).latent_dist
+                mu, logvar = dgd.mean, dgd.logvar
+                latents = dgd.sample()
+                latents.detach().requires_grad_(True)

                 pred = self.vae.decode(latents).sample

+                with torch.no_grad():
+                    show_tensors(
+                        pred.clamp(-1, 1).clone(),
+                        "combined tensor"
+                    )

                 # Run through VGG19
                 if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
                     stacked = torch.cat([pred, batch], dim=0)
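Two behavioral changes sit in this hunk. First, latents.detach().requires_grad_(True) cuts the autograd graph at the latent, so no gradients flow back into the frozen encoder; only the decoder (set to train mode above) receives a learning signal. Second, batches whose height or width is not a multiple of vae_scale_factor are floored to the nearest multiple before encoding, so the decoded output matches the input shape exactly. The flooring arithmetic in isolation:

```python
vae_scale_factor = 8
h, w = 517, 770  # arbitrary sizes that do not divide evenly by 8
new_h = h // vae_scale_factor * vae_scale_factor  # integer division drops the remainder: 512
new_w = w // vae_scale_factor * vae_scale_factor  # 768
```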
@@ -441,14 +479,31 @@ class TrainVAEProcess(BaseTrainProcess):
                 content_loss = self.get_content_loss() * self.content_weight
                 kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
                 mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
+                if self.lpips_weight > 0:
+                    lpips_loss = self.lpips_loss(
+                        pred.clamp(-1, 1),
+                        batch.clamp(-1, 1)
+                    ).mean() * self.lpips_weight
+                else:
+                    lpips_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
                 tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
                 pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight

                 if self.use_critic:
                     critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
+
+                    # do not let abs critic gen loss be higher than abs lpips * 0.1 if using it
+                    if self.lpips_weight > 0:
+                        max_target = lpips_loss.abs() * 0.1
+                        with torch.no_grad():
+                            crit_g_scaler = 1.0
+                            if critic_gen_loss.abs() > max_target:
+                                crit_g_scaler = max_target / critic_gen_loss.abs()
+                        critic_gen_loss *= crit_g_scaler
                 else:
                     critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

-                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss
+                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss

                 # Backward pass and optimization
                 optimizer.zero_grad()
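When both the LPIPS and critic terms are active, the critic's generator loss is rescaled so its magnitude never exceeds 10% of the LPIPS term; the scale is computed under no_grad, so it acts as a constant in the backward pass. A worked example of the cap, with illustrative numbers:

```python
import torch

lpips_loss = torch.tensor(0.30)
critic_gen_loss = torch.tensor(0.90)

max_target = lpips_loss.abs() * 0.1  # 0.03
with torch.no_grad():
    crit_g_scaler = 1.0
    if critic_gen_loss.abs() > max_target:
        crit_g_scaler = max_target / critic_gen_loss.abs()  # 0.03 / 0.90
critic_gen_loss = critic_gen_loss * crit_g_scaler
print(critic_gen_loss)  # tensor(0.0300): capped at 10% of the LPIPS magnitude
```

This keeps the adversarial signal from overwhelming the perceptual loss, while leaving it untouched once it is already small.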
@@ -460,6 +515,8 @@ class TrainVAEProcess(BaseTrainProcess):
                 loss_value = loss.item()
                 # get exponent like 3.54e-4
                 loss_string = f"loss: {loss_value:.2e}"
+                if self.lpips_weight > 0:
+                    loss_string += f" lpips: {lpips_loss.item():.2e}"
                 if self.content_weight > 0:
                     loss_string += f" cnt: {content_loss.item():.2e}"
                 if self.style_weight > 0:
@@ -496,6 +553,7 @@ class TrainVAEProcess(BaseTrainProcess):
                 self.progress_bar.update(1)

                 epoch_losses["total"].append(loss_value)
+                epoch_losses["lpips"].append(lpips_loss.item())
                 epoch_losses["style"].append(style_loss.item())
                 epoch_losses["content"].append(content_loss.item())
                 epoch_losses["mse"].append(mse_loss.item())
@@ -506,6 +564,7 @@ class TrainVAEProcess(BaseTrainProcess):
                     epoch_losses["crD"].append(critic_d_loss)

                 log_losses["total"].append(loss_value)
+                log_losses["lpips"].append(lpips_loss.item())
                 log_losses["style"].append(style_loss.item())
                 log_losses["content"].append(content_loss.item())
                 log_losses["mse"].append(mse_loss.item())