mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Bug fixes and little improvements here and there.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
@@ -13,6 +14,7 @@ from torch import nn
|
||||
from torchvision.transforms import transforms
|
||||
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.image_utils import show_tensors
|
||||
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
|
||||
from toolkit.data_loader import ImageDataset
|
||||
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
|
||||
@@ -25,6 +27,8 @@ from tqdm import tqdm
|
||||
import time
|
||||
import numpy as np
|
||||
from .models.vgg19_critic import Critic
|
||||
from torchvision.transforms import Resize
|
||||
import lpips
|
||||
|
||||
IMAGE_TRANSFORMS = transforms.Compose(
|
||||
[
|
||||
@@ -62,6 +66,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
|
||||
self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
|
||||
self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
|
||||
self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float)
|
||||
self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
|
||||
self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
|
||||
self.optimizer_params = self.get_conf('optimizer_params', {})
|
||||
@@ -71,6 +76,9 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
self.vgg_19 = None
|
||||
self.style_weight_scalers = []
|
||||
self.content_weight_scalers = []
|
||||
self.lpips_loss:lpips.LPIPS = None
|
||||
|
||||
self.vae_scale_factor = 8
|
||||
|
||||
self.step_num = 0
|
||||
self.epoch_num = 0
|
||||
@@ -137,6 +145,15 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
num_workers=6
|
||||
)
|
||||
|
||||
def remove_oldest_checkpoint(self, max_to_keep=4):
    """Prune saved diffusers checkpoint folders, keeping only the newest ones.

    Scans ``self.save_root`` for folders matching ``{job.name}*_diffusers``
    (the naming used by the save step) and deletes all but the
    ``max_to_keep`` most recently modified folders.

    Args:
        max_to_keep: number of newest checkpoint folders to retain
            (default 4, matching the previous hard-coded value).
    """
    folders = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers"))
    if len(folders) > max_to_keep:
        # sort oldest-first so everything before the last max_to_keep
        # entries is the stale set to delete
        folders.sort(key=os.path.getmtime)
        for folder in folders[:-max_to_keep]:
            # use self.print for consistency with the rest of this class
            self.print(f"Removing {folder}")
            shutil.rmtree(folder)
|
||||
|
||||
def setup_vgg19(self):
|
||||
if self.vgg_19 is None:
|
||||
self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses(
|
||||
@@ -211,7 +228,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
|
||||
def get_pattern_loss(self, pred, target):
    """Return the mean pattern loss between *pred* and *target*.

    The ``PatternLoss`` module is built lazily on first call and cached
    on the instance (``self._pattern_loss``) so subsequent calls reuse
    the same module already moved to ``self.device``/``self.torch_dtype``.

    Args:
        pred: decoded/reconstructed image tensor.
        target: ground-truth image tensor of the same shape.

    Returns:
        Scalar tensor: mean of the per-element pattern loss.
    """
    if self._pattern_loss is None:
        # The extracted diff showed both the old (pattern_size=8) and new
        # lines; the hunk header (-211,7 +228,7) changed exactly one line,
        # so the post-change value pattern_size=16 is kept here.
        self._pattern_loss = PatternLoss(
            pattern_size=16, dtype=self.torch_dtype
        ).to(self.device, dtype=self.torch_dtype)
    loss = torch.mean(self._pattern_loss(pred, target))
    return loss
|
||||
@@ -226,25 +243,21 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
step_num = f"_{str(step).zfill(9)}"
|
||||
|
||||
self.update_training_metadata()
|
||||
filename = f'{self.job.name}{step_num}.safetensors'
|
||||
# prepare meta
|
||||
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
|
||||
filename = f'{self.job.name}{step_num}_diffusers'
|
||||
|
||||
state_dict = convert_diffusers_back_to_ldm(self.vae)
|
||||
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(torch.float32)
|
||||
state_dict[key] = v
|
||||
|
||||
# having issues with meta
|
||||
save_file(state_dict, os.path.join(self.save_root, filename), save_meta)
|
||||
self.vae = self.vae.to("cpu", dtype=torch.float16)
|
||||
self.vae.save_pretrained(
|
||||
save_directory=os.path.join(self.save_root, filename)
|
||||
)
|
||||
self.vae = self.vae.to(self.device, dtype=self.torch_dtype)
|
||||
|
||||
self.print(f"Saved to {os.path.join(self.save_root, filename)}")
|
||||
|
||||
if self.use_critic:
|
||||
self.critic.save(step)
|
||||
|
||||
self.remove_oldest_checkpoint()
|
||||
|
||||
def sample(self, step=None):
|
||||
sample_folder = os.path.join(self.save_root, 'samples')
|
||||
if not os.path.exists(sample_folder):
|
||||
@@ -280,6 +293,13 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
output_img.paste(input_img, (0, 0))
|
||||
output_img.paste(decoded, (self.resolution, 0))
|
||||
|
||||
scale_up = 2
|
||||
if output_img.height <= 300:
|
||||
scale_up = 4
|
||||
|
||||
# scale up using nearest neighbor
|
||||
output_img = output_img.resize((output_img.width * scale_up, output_img.height * scale_up), Image.NEAREST)
|
||||
|
||||
step_num = ''
|
||||
if step is not None:
|
||||
# zero-pad 9 digits
|
||||
@@ -294,7 +314,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
path_to_load = self.vae_path
|
||||
# see if we have a checkpoint in our output to resume from
|
||||
self.print(f"Looking for latest checkpoint in {self.save_root}")
|
||||
files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*.safetensors"))
|
||||
files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers"))
|
||||
if files and len(files) > 0:
|
||||
latest_file = max(files, key=os.path.getmtime)
|
||||
print(f" - Latest checkpoint is: {latest_file}")
|
||||
@@ -306,13 +326,14 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
self.print(f"Loading VAE")
|
||||
self.print(f" - Loading VAE: {path_to_load}")
|
||||
if self.vae is None:
|
||||
self.vae = load_vae(path_to_load, dtype=self.torch_dtype)
|
||||
self.vae = AutoencoderKL.from_pretrained(path_to_load)
|
||||
|
||||
# set decoder to train
|
||||
self.vae.to(self.device, dtype=self.torch_dtype)
|
||||
self.vae.requires_grad_(False)
|
||||
self.vae.eval()
|
||||
self.vae.decoder.train()
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
@@ -374,6 +395,10 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
if self.use_critic:
|
||||
self.critic.setup()
|
||||
|
||||
if self.lpips_weight > 0 and self.lpips_loss is None:
|
||||
# self.lpips_loss = lpips.LPIPS(net='vgg')
|
||||
self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=self.torch_dtype)
|
||||
|
||||
optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
|
||||
optimizer_params=self.optimizer_params)
|
||||
|
||||
@@ -397,6 +422,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
self.sample()
|
||||
blank_losses = OrderedDict({
|
||||
"total": [],
|
||||
"lpips": [],
|
||||
"style": [],
|
||||
"content": [],
|
||||
"mse": [],
|
||||
@@ -415,17 +441,29 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
for batch in self.data_loader:
|
||||
if self.step_num >= self.max_steps:
|
||||
break
|
||||
with torch.no_grad():
|
||||
|
||||
batch = batch.to(self.device, dtype=self.torch_dtype)
|
||||
batch = batch.to(self.device, dtype=self.torch_dtype)
|
||||
|
||||
# forward pass
|
||||
dgd = self.vae.encode(batch).latent_dist
|
||||
mu, logvar = dgd.mean, dgd.logvar
|
||||
latents = dgd.sample()
|
||||
latents.requires_grad_(True)
|
||||
# resize so it matches size of vae evenly
|
||||
if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
|
||||
batch = Resize((batch.shape[2] // self.vae_scale_factor * self.vae_scale_factor,
|
||||
batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)
|
||||
|
||||
# forward pass
|
||||
dgd = self.vae.encode(batch).latent_dist
|
||||
mu, logvar = dgd.mean, dgd.logvar
|
||||
latents = dgd.sample()
|
||||
latents.detach().requires_grad_(True)
|
||||
|
||||
pred = self.vae.decode(latents).sample
|
||||
|
||||
with torch.no_grad():
|
||||
show_tensors(
|
||||
pred.clamp(-1, 1).clone(),
|
||||
"combined tensor"
|
||||
)
|
||||
|
||||
# Run through VGG19
|
||||
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
|
||||
stacked = torch.cat([pred, batch], dim=0)
|
||||
@@ -441,14 +479,31 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
content_loss = self.get_content_loss() * self.content_weight
|
||||
kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
|
||||
mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
|
||||
if self.lpips_weight > 0:
|
||||
lpips_loss = self.lpips_loss(
|
||||
pred.clamp(-1, 1),
|
||||
batch.clamp(-1, 1)
|
||||
).mean() * self.lpips_weight
|
||||
else:
|
||||
lpips_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
|
||||
tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
|
||||
pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight
|
||||
if self.use_critic:
|
||||
critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
|
||||
|
||||
# do not let abs critic gen loss be higher than abs lpips * 0.1 if using it
|
||||
if self.lpips_weight > 0:
|
||||
max_target = lpips_loss.abs() * 0.1
|
||||
with torch.no_grad():
|
||||
crit_g_scaler = 1.0
|
||||
if critic_gen_loss.abs() > max_target:
|
||||
crit_g_scaler = max_target / critic_gen_loss.abs()
|
||||
|
||||
critic_gen_loss *= crit_g_scaler
|
||||
else:
|
||||
critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss
|
||||
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
@@ -460,6 +515,8 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
loss_value = loss.item()
|
||||
# get exponent like 3.54e-4
|
||||
loss_string = f"loss: {loss_value:.2e}"
|
||||
if self.lpips_weight > 0:
|
||||
loss_string += f" lpips: {lpips_loss.item():.2e}"
|
||||
if self.content_weight > 0:
|
||||
loss_string += f" cnt: {content_loss.item():.2e}"
|
||||
if self.style_weight > 0:
|
||||
@@ -496,6 +553,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
self.progress_bar.update(1)
|
||||
|
||||
epoch_losses["total"].append(loss_value)
|
||||
epoch_losses["lpips"].append(lpips_loss.item())
|
||||
epoch_losses["style"].append(style_loss.item())
|
||||
epoch_losses["content"].append(content_loss.item())
|
||||
epoch_losses["mse"].append(mse_loss.item())
|
||||
@@ -506,6 +564,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
epoch_losses["crD"].append(critic_d_loss)
|
||||
|
||||
log_losses["total"].append(loss_value)
|
||||
log_losses["lpips"].append(lpips_loss.item())
|
||||
log_losses["style"].append(style_loss.item())
|
||||
log_losses["content"].append(content_loss.item())
|
||||
log_losses["mse"].append(mse_loss.item())
|
||||
|
||||
Reference in New Issue
Block a user