Added Critic support to VAE training. Still tweaking and working on it. Many other fixes

This commit is contained in:
Jaret Burkett
2023-07-19 15:57:32 -06:00
parent 6ada328d8d
commit 557732e7ff
9 changed files with 415 additions and 59 deletions

View File

@@ -6,7 +6,7 @@ from collections import OrderedDict
from PIL import Image
from PIL.ImageOps import exif_transpose
from safetensors.torch import save_file
from safetensors.torch import save_file, load_file
from torch.utils.data import DataLoader, ConcatDataset
import torch
from torch import nn
@@ -15,8 +15,9 @@ from torchvision.transforms import transforms
from jobs.process import BaseTrainProcess
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
from toolkit.data_loader import ImageDataset
from toolkit.losses import ComparativeTotalVariation
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty
from toolkit.metadata import get_meta_for_safetensors
from toolkit.optimizer import get_optimizer
from toolkit.style import get_style_model_and_losses
from toolkit.train_tools import get_torch_dtype
from diffusers import AutoencoderKL
@@ -36,6 +37,139 @@ def unnormalize(tensor):
return (tensor / 2 + 0.5).clamp(0, 1)
class Critic:
process: 'TrainVAEProcess'
def __init__(
self,
learning_rate=1e-5,
device='cpu',
optimizer='adam',
num_critic_per_gen=1,
dtype='float32',
lambda_gp=10,
start_step=0,
warmup_steps=1000,
process=None
):
self.learning_rate = learning_rate
self.device = device
self.optimizer_type = optimizer
self.num_critic_per_gen = num_critic_per_gen
self.dtype = dtype
self.torch_dtype = get_torch_dtype(self.dtype)
self.process = process
self.model = None
self.optimizer = None
self.scheduler = None
self.warmup_steps = warmup_steps
self.start_step = start_step
self.lambda_gp = lambda_gp
self.print = self.process.print
print(f" Critic config: {self.__dict__}")
def setup(self):
from .models.vgg19_critic import Vgg19Critic
self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype)
self.load_weights()
self.model.train()
self.model.requires_grad_(True)
params = self.model.parameters()
self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate)
self.scheduler = torch.optim.lr_scheduler.ConstantLR(
self.optimizer,
total_iters=self.process.max_steps * self.num_critic_per_gen,
factor=1,
verbose=False
)
def load_weights(self):
path_to_load = None
self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}")
files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors"))
if files and len(files) > 0:
latest_file = max(files, key=os.path.getmtime)
print(f" - Latest checkpoint is: {latest_file}")
path_to_load = latest_file
else:
self.print(f" - No checkpoint found, starting from scratch")
if path_to_load:
self.model.load_state_dict(load_file(path_to_load))
def save(self, step=None):
self.process.update_training_metadata()
save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name)
step_num = ''
if step is not None:
# zeropad 9 digits
step_num = f"_{str(step).zfill(9)}"
save_path = os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors")
save_file(self.model.state_dict(), save_path, save_meta)
self.print(f"Saved critic to {save_path}")
def get_critic_loss(self, vgg_output):
if self.start_step > self.process.step_num:
return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device)
warmup_scaler = 1.0
# we need a warmup when we come on of 1000 steps
# we want to scale the loss by 0.0 at self.start_step steps and 1.0 at self.start_step + warmup_steps
if self.process.step_num < self.start_step + self.warmup_steps:
warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps
# set model to not train for generator loss
self.model.eval()
self.model.requires_grad_(False)
vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0)
# run model
stacked_output = self.model(vgg_pred)
return (-torch.mean(stacked_output)) * warmup_scaler
def step(self, vgg_output):
# train critic here
self.model.train()
self.model.requires_grad_(True)
critic_losses = []
for i in range(self.num_critic_per_gen):
inputs = vgg_output.detach()
inputs = inputs.to(self.device, dtype=self.torch_dtype)
self.optimizer.zero_grad()
vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
stacked_output = self.model(inputs)
out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
# Compute gradient penalty
gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
# Compute WGAN-GP critic loss
critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
critic_loss.backward()
self.optimizer.zero_grad()
self.optimizer.step()
self.scheduler.step()
critic_losses.append(critic_loss.item())
# avg loss
loss = np.mean(critic_losses)
return loss
def get_lr(self):
if self.optimizer_type.startswith('dadaptation'):
learning_rate = (
self.optimizer.param_groups[0]["d"] *
self.optimizer.param_groups[0]["lr"]
)
else:
learning_rate = self.optimizer.param_groups[0]['lr']
return learning_rate
class TrainVAEProcess(BaseTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
@@ -61,6 +195,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
self.writer = self.job.writer
@@ -68,6 +203,22 @@ class TrainVAEProcess(BaseTrainProcess):
self.save_root = os.path.join(self.training_folder, self.job.name)
self.vgg_19 = None
self.progress_bar = None
self.style_weight_scalers = []
self.content_weight_scalers = []
self.step_num = 0
self.epoch_num = 0
self.use_critic = self.get_conf('use_critic', False, as_type=bool)
self.critic = None
if self.use_critic:
self.critic = Critic(
device=self.device,
dtype=self.dtype,
process=self,
**self.get_conf('critic', {}) # pass any other params
)
if self.sample_every is not None and self.sample_sources is None:
raise ValueError('sample_every is specified but sample_sources is not')
@@ -89,6 +240,16 @@ class TrainVAEProcess(BaseTrainProcess):
if not os.path.exists(self.save_root):
os.makedirs(self.save_root, exist_ok=True)
def update_training_metadata(self):
self.add_meta(OrderedDict({"training_info": self.get_training_info()}))
def get_training_info(self):
info = OrderedDict({
'step': self.step_num,
'epoch': self.epoch_num,
})
return info
def print(self, message, **kwargs):
if self.progress_bar is not None:
self.progress_bar.write(message, **kwargs)
@@ -117,19 +278,46 @@ class TrainVAEProcess(BaseTrainProcess):
def setup_vgg19(self):
if self.vgg_19 is None:
self.vgg_19, self.style_losses, self.content_losses, output = get_style_model_and_losses(
single_target=True, device=self.device)
self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses(
single_target=True,
device=self.device,
output_layer_name='pool_4',
dtype=self.torch_dtype
)
self.vgg_19.to(self.device, dtype=self.torch_dtype)
self.vgg_19.requires_grad_(False)
# we run random noise through first to get layer scalers to normalize the loss per layer
# bs of 2 because we run pred and target through stacked
noise = torch.randn((2, 3, self.resolution, self.resolution), device=self.device, dtype=self.torch_dtype)
self.vgg_19(noise)
for style_loss in self.style_losses:
# get a scaler to normalize to 1
scaler = 1 / torch.mean(style_loss.loss).item()
self.style_weight_scalers.append(scaler)
for content_loss in self.content_losses:
# get a scaler to normalize to 1
scaler = 1 / torch.mean(content_loss.loss).item()
self.content_weight_scalers.append(scaler)
self.print(f"Style weight scalers: {self.style_weight_scalers}")
self.print(f"Content weight scalers: {self.content_weight_scalers}")
def get_style_loss(self):
if self.style_weight > 0:
return torch.sum(torch.stack([loss.loss for loss in self.style_losses]))
# scale all losses with loss scalers
loss = torch.sum(
torch.stack([loss.loss * scaler for loss, scaler in zip(self.style_losses, self.style_weight_scalers)]))
return loss
else:
return torch.tensor(0.0, device=self.device)
def get_content_loss(self):
if self.content_weight > 0:
return torch.sum(torch.stack([loss.loss for loss in self.content_losses]))
# scale all losses with loss scalers
loss = torch.sum(torch.stack(
[loss.loss * scaler for loss, scaler in zip(self.content_losses, self.content_weight_scalers)]))
return loss
else:
return torch.tensor(0.0, device=self.device)
@@ -160,7 +348,6 @@ class TrainVAEProcess(BaseTrainProcess):
else:
return torch.tensor(0.0, device=self.device)
def save(self, step=None):
if not os.path.exists(self.save_root):
os.makedirs(self.save_root, exist_ok=True)
@@ -170,6 +357,7 @@ class TrainVAEProcess(BaseTrainProcess):
# zeropad 9 digits
step_num = f"_{str(step).zfill(9)}"
self.update_training_metadata()
filename = f'{self.job.name}{step_num}.safetensors'
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
@@ -184,7 +372,10 @@ class TrainVAEProcess(BaseTrainProcess):
# having issues with meta
save_file(state_dict, os.path.join(self.save_root, filename), save_meta)
print(f"Saved to {os.path.join(self.save_root, filename)}")
self.print(f"Saved to {os.path.join(self.save_root, filename)}")
if self.use_critic:
self.critic.save(step)
def sample(self, step=None):
sample_folder = os.path.join(self.save_root, 'samples')
@@ -268,6 +459,9 @@ class TrainVAEProcess(BaseTrainProcess):
num_steps = self.max_steps
if num_steps is None or num_steps > max_epoch_steps:
num_steps = max_epoch_steps
self.max_steps = num_steps
self.epochs = num_epochs
start_step = self.step_num
self.print(f"Training VAE")
self.print(f" - Training folder: {self.training_folder}")
@@ -304,18 +498,14 @@ class TrainVAEProcess(BaseTrainProcess):
params += list(self.vae.decoder.conv_out.parameters())
self.vae.decoder.conv_out.requires_grad_(True)
if self.style_weight > 0 or self.content_weight > 0:
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
self.setup_vgg19()
self.vgg_19.requires_grad_(False)
self.vgg_19.eval()
if self.use_critic:
self.critic.setup()
# todo allow other optimizers
if self.optimizer_type == 'dadaptation':
import dadaptation
print("Using DAdaptAdam optimizer")
optimizer = dadaptation.DAdaptAdam(params, lr=1)
else:
optimizer = torch.optim.Adam(params, lr=float(self.learning_rate))
optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate)
# setup scheduler
# todo allow other schedulers
@@ -333,7 +523,6 @@ class TrainVAEProcess(BaseTrainProcess):
leave=True
)
step = 0
# sample first
self.sample()
blank_losses = OrderedDict({
@@ -343,15 +532,17 @@ class TrainVAEProcess(BaseTrainProcess):
"mse": [],
"kl": [],
"tv": [],
"crD": [],
"crG": [],
})
epoch_losses = copy.deepcopy(blank_losses)
log_losses = copy.deepcopy(blank_losses)
for epoch in range(num_epochs):
if step >= num_steps:
# range start at self.epoch_num go to self.epochs
for epoch in range(self.epoch_num, self.epochs, 1):
if self.step_num >= self.max_steps:
break
for batch in self.data_loader:
if step >= num_steps:
if self.step_num >= self.max_steps:
break
batch = batch.to(self.device, dtype=self.torch_dtype)
@@ -365,18 +556,27 @@ class TrainVAEProcess(BaseTrainProcess):
pred = self.vae.decode(latents).sample
# Run through VGG19
if self.style_weight > 0 or self.content_weight > 0:
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
stacked = torch.cat([pred, batch], dim=0)
stacked = (stacked / 2 + 0.5).clamp(0, 1)
self.vgg_19(stacked)
if self.use_critic:
critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
else:
critic_d_loss = 0.0
style_loss = self.get_style_loss() * self.style_weight
content_loss = self.get_content_loss() * self.content_weight
kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
if self.use_critic:
critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
else:
critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss
# Backward pass and optimization
optimizer.zero_grad()
@@ -398,6 +598,10 @@ class TrainVAEProcess(BaseTrainProcess):
loss_string += f" mse: {mse_loss.item():.2e}"
if self.tv_weight > 0:
loss_string += f" tv: {tv_loss.item():.2e}"
if self.use_critic and self.critic_weight > 0:
loss_string += f" crG: {critic_gen_loss.item():.2e}"
if self.use_critic:
loss_string += f" crD: {critic_d_loss:.2e}"
if self.optimizer_type.startswith('dadaptation'):
learning_rate = (
@@ -406,7 +610,13 @@ class TrainVAEProcess(BaseTrainProcess):
)
else:
learning_rate = optimizer.param_groups[0]['lr']
self.progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
lr_critic_string = ''
if self.use_critic:
lr_critic = self.critic.get_lr()
lr_critic_string = f" lrC: {lr_critic:.1e}"
self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e}{lr_critic_string} {loss_string}")
self.progress_bar.set_description(f"E: {epoch}")
self.progress_bar.update(1)
@@ -416,6 +626,8 @@ class TrainVAEProcess(BaseTrainProcess):
epoch_losses["mse"].append(mse_loss.item())
epoch_losses["kl"].append(kld_loss.item())
epoch_losses["tv"].append(tv_loss.item())
epoch_losses["crG"].append(critic_gen_loss.item())
epoch_losses["crD"].append(critic_d_loss)
log_losses["total"].append(loss_value)
log_losses["style"].append(style_loss.item())
@@ -423,30 +635,33 @@ class TrainVAEProcess(BaseTrainProcess):
log_losses["mse"].append(mse_loss.item())
log_losses["kl"].append(kld_loss.item())
log_losses["tv"].append(tv_loss.item())
log_losses["crG"].append(critic_gen_loss.item())
log_losses["crD"].append(critic_d_loss)
if step != 0:
if self.sample_every and step % self.sample_every == 0:
# don't do on first step
if self.step_num != start_step:
if self.sample_every and self.step_num % self.sample_every == 0:
# print above the progress bar
self.print(f"Sampling at step {step}")
self.sample(step)
self.print(f"Sampling at step {self.step_num}")
self.sample(self.step_num)
if self.save_every and step % self.save_every == 0:
if self.save_every and self.step_num % self.save_every == 0:
# print above the progress bar
self.print(f"Saving at step {step}")
self.save(step)
self.print(f"Saving at step {self.step_num}")
self.save(self.step_num)
if self.log_every and step % self.log_every == 0:
if self.log_every and self.step_num % self.log_every == 0:
# log to tensorboard
if self.writer is not None:
# get avg loss
for key in log_losses:
log_losses[key] = sum(log_losses[key]) / len(log_losses[key])
if log_losses[key] > 0:
self.writer.add_scalar(f"loss/{key}", log_losses[key], step)
# if log_losses[key] > 0:
self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num)
# reset log losses
log_losses = copy.deepcopy(blank_losses)
step += 1
self.step_num += 1
# end epoch
if self.writer is not None:
# get avg loss