Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-30 03:01:28 +00:00
Added Critic support to VAE training. Still tweaking and working on it. Many other fixes.
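
The new Critic follows the WGAN-GP recipe on shared VGG19 pool_4 features: the critic is trained to widen its score gap between real images and VAE reconstructions, regularized by a gradient penalty, while the VAE adds the negated critic score to its own loss. A toy, runnable illustration of the two objectives as they appear in the diff below; the stand-in linear critic and random features are placeholders, not part of the commit:

import torch
from torch import nn

critic = nn.Linear(8, 1)        # stand-in for the new Vgg19Critic
vgg_pred = torch.randn(4, 8)    # features of VAE reconstructions
vgg_target = torch.randn(4, 8)  # features of the real batch
lambda_gp = 10                  # gradient penalty weight, the Critic default

# critic (discriminator) step: push real scores above reconstruction scores;
# the gradient penalty term is sketched at the end of this page
out_pred, out_target = critic(vgg_pred), critic(vgg_target)
gradient_penalty = torch.tensor(0.0)  # placeholder
critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + lambda_gp * gradient_penalty

# generator (VAE) step: raise the critic's score on reconstructions
critic_gen_loss = -torch.mean(critic(vgg_pred))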
@@ -6,7 +6,7 @@ from collections import OrderedDict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
-from safetensors.torch import save_file
+from safetensors.torch import save_file, load_file
 from torch.utils.data import DataLoader, ConcatDataset
 import torch
 from torch import nn
@@ -15,8 +15,9 @@ from torchvision.transforms import transforms
 from jobs.process import BaseTrainProcess
 from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
 from toolkit.data_loader import ImageDataset
-from toolkit.losses import ComparativeTotalVariation
+from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty
 from toolkit.metadata import get_meta_for_safetensors
+from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
 from diffusers import AutoencoderKL
@@ -36,6 +37,139 @@ def unnormalize(tensor):
     return (tensor / 2 + 0.5).clamp(0, 1)
 
 
+class Critic:
+    process: 'TrainVAEProcess'
+
+    def __init__(
+            self,
+            learning_rate=1e-5,
+            device='cpu',
+            optimizer='adam',
+            num_critic_per_gen=1,
+            dtype='float32',
+            lambda_gp=10,
+            start_step=0,
+            warmup_steps=1000,
+            process=None
+    ):
+        self.learning_rate = learning_rate
+        self.device = device
+        self.optimizer_type = optimizer
+        self.num_critic_per_gen = num_critic_per_gen
+        self.dtype = dtype
+        self.torch_dtype = get_torch_dtype(self.dtype)
+        self.process = process
+        self.model = None
+        self.optimizer = None
+        self.scheduler = None
+        self.warmup_steps = warmup_steps
+        self.start_step = start_step
+        self.lambda_gp = lambda_gp
+        self.print = self.process.print
+        print(f" Critic config: {self.__dict__}")
+
+    def setup(self):
+        from .models.vgg19_critic import Vgg19Critic
+        self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype)
+        self.load_weights()
+        self.model.train()
+        self.model.requires_grad_(True)
+        params = self.model.parameters()
+        self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate)
+        self.scheduler = torch.optim.lr_scheduler.ConstantLR(
+            self.optimizer,
+            total_iters=self.process.max_steps * self.num_critic_per_gen,
+            factor=1,
+            verbose=False
+        )
+
+    def load_weights(self):
+        path_to_load = None
+        self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}")
+        files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors"))
+        if files and len(files) > 0:
+            latest_file = max(files, key=os.path.getmtime)
+            print(f" - Latest checkpoint is: {latest_file}")
+            path_to_load = latest_file
+        else:
+            self.print(" - No checkpoint found, starting from scratch")
+        if path_to_load:
+            self.model.load_state_dict(load_file(path_to_load))
+
+    def save(self, step=None):
+        self.process.update_training_metadata()
+        save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name)
+        step_num = ''
+        if step is not None:
+            # zeropad 9 digits
+            step_num = f"_{str(step).zfill(9)}"
+        save_path = os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors")
+        save_file(self.model.state_dict(), save_path, save_meta)
+        self.print(f"Saved critic to {save_path}")
+
+    def get_critic_loss(self, vgg_output):
+        if self.start_step > self.process.step_num:
+            return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device)
+
+        warmup_scaler = 1.0
+        # warm up the critic loss when it first comes online at start_step:
+        # scale the loss by 0.0 at self.start_step and 1.0 at self.start_step + warmup_steps
+        if self.process.step_num < self.start_step + self.warmup_steps:
+            warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps
+        # set model to not train for generator loss
+        self.model.eval()
+        self.model.requires_grad_(False)
+        vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0)
+
+        # run model
+        stacked_output = self.model(vgg_pred)
+
+        return (-torch.mean(stacked_output)) * warmup_scaler
+
+    def step(self, vgg_output):
+        # train critic here
+        self.model.train()
+        self.model.requires_grad_(True)
+
+        critic_losses = []
+        for i in range(self.num_critic_per_gen):
+            inputs = vgg_output.detach()
+            inputs = inputs.to(self.device, dtype=self.torch_dtype)
+            self.optimizer.zero_grad()
+
+            vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
+
+            stacked_output = self.model(inputs)
+            out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
+
+            # Compute gradient penalty
+            gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
+
+            # Compute WGAN-GP critic loss
+            critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
+            critic_loss.backward()
+            # gradients were already zeroed at the top of the loop; apply the update
+            self.optimizer.step()
+            self.scheduler.step()
+            critic_losses.append(critic_loss.item())
+
+        # avg loss
+        loss = np.mean(critic_losses)
+        return loss
+
+    def get_lr(self):
+        if self.optimizer_type.startswith('dadaptation'):
+            learning_rate = (
+                self.optimizer.param_groups[0]["d"] *
+                self.optimizer.param_groups[0]["lr"]
+            )
+        else:
+            learning_rate = self.optimizer.param_groups[0]['lr']
+
+        return learning_rate
+
+
 class TrainVAEProcess(BaseTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict):
         super().__init__(process_id, job, config)
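Critic.setup() above lazy-imports Vgg19Critic from a new models/vgg19_critic module that is not shown in this diff. A minimal sketch of what such a head over VGG19 pool_4 feature maps (512 channels) could look like; the layer widths and structure here are assumptions, not the module's actual contents. WGAN-GP critics conventionally avoid batch norm, since the penalty is computed on per-sample gradients:

import torch
from torch import nn

class Vgg19Critic(nn.Module):
    # hypothetical critic head over VGG19 pool_4 feature maps
    def __init__(self, in_channels=512):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(128, 1),  # raw score, no sigmoid: WGAN critics are unbounded
        )

    def forward(self, x):
        return self.main(x)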
@@ -61,6 +195,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
         self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
+        self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
 
         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
         self.writer = self.job.writer
@@ -68,6 +203,22 @@ class TrainVAEProcess(BaseTrainProcess):
         self.save_root = os.path.join(self.training_folder, self.job.name)
         self.vgg_19 = None
         self.progress_bar = None
         self.style_weight_scalers = []
         self.content_weight_scalers = []
 
+        self.step_num = 0
+        self.epoch_num = 0
+
+        self.use_critic = self.get_conf('use_critic', False, as_type=bool)
+        self.critic = None
+
+        if self.use_critic:
+            self.critic = Critic(
+                device=self.device,
+                dtype=self.dtype,
+                process=self,
+                **self.get_conf('critic', {})  # pass any other params
+            )
+
         if self.sample_every is not None and self.sample_sources is None:
             raise ValueError('sample_every is specified but sample_sources is not')
@@ -89,6 +240,16 @@ class TrainVAEProcess(BaseTrainProcess):
         if not os.path.exists(self.save_root):
             os.makedirs(self.save_root, exist_ok=True)
 
+    def update_training_metadata(self):
+        self.add_meta(OrderedDict({"training_info": self.get_training_info()}))
+
+    def get_training_info(self):
+        info = OrderedDict({
+            'step': self.step_num,
+            'epoch': self.epoch_num,
+        })
+        return info
+
     def print(self, message, **kwargs):
         if self.progress_bar is not None:
             self.progress_bar.write(message, **kwargs)
@@ -117,19 +278,46 @@ class TrainVAEProcess(BaseTrainProcess):
 
     def setup_vgg19(self):
         if self.vgg_19 is None:
-            self.vgg_19, self.style_losses, self.content_losses, output = get_style_model_and_losses(
-                single_target=True, device=self.device)
+            self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses(
+                single_target=True,
+                device=self.device,
+                output_layer_name='pool_4',
+                dtype=self.torch_dtype
+            )
             self.vgg_19.to(self.device, dtype=self.torch_dtype)
             self.vgg_19.requires_grad_(False)
 
+            # we run random noise through first to get layer scalers to normalize the loss per layer
+            # bs of 2 because we run pred and target through stacked
+            noise = torch.randn((2, 3, self.resolution, self.resolution), device=self.device, dtype=self.torch_dtype)
+            self.vgg_19(noise)
+            for style_loss in self.style_losses:
+                # get a scaler to normalize to 1
+                scaler = 1 / torch.mean(style_loss.loss).item()
+                self.style_weight_scalers.append(scaler)
+            for content_loss in self.content_losses:
+                # get a scaler to normalize to 1
+                scaler = 1 / torch.mean(content_loss.loss).item()
+                self.content_weight_scalers.append(scaler)
+
+            self.print(f"Style weight scalers: {self.style_weight_scalers}")
+            self.print(f"Content weight scalers: {self.content_weight_scalers}")
 
     def get_style_loss(self):
         if self.style_weight > 0:
-            return torch.sum(torch.stack([loss.loss for loss in self.style_losses]))
+            # scale all losses with loss scalers
+            loss = torch.sum(
+                torch.stack([loss.loss * scaler for loss, scaler in zip(self.style_losses, self.style_weight_scalers)]))
+            return loss
         else:
             return torch.tensor(0.0, device=self.device)
 
     def get_content_loss(self):
         if self.content_weight > 0:
-            return torch.sum(torch.stack([loss.loss for loss in self.content_losses]))
+            # scale all losses with loss scalers
+            loss = torch.sum(torch.stack(
+                [loss.loss * scaler for loss, scaler in zip(self.content_losses, self.content_weight_scalers)]))
+            return loss
         else:
             return torch.tensor(0.0, device=self.device)
@@ -160,7 +348,6 @@ class TrainVAEProcess(BaseTrainProcess):
         else:
             return torch.tensor(0.0, device=self.device)
 
-
     def save(self, step=None):
         if not os.path.exists(self.save_root):
             os.makedirs(self.save_root, exist_ok=True)
@@ -170,6 +357,7 @@ class TrainVAEProcess(BaseTrainProcess):
             # zeropad 9 digits
             step_num = f"_{str(step).zfill(9)}"
 
+        self.update_training_metadata()
         filename = f'{self.job.name}{step_num}.safetensors'
         # prepare meta
         save_meta = get_meta_for_safetensors(self.meta, self.job.name)
@@ -184,7 +372,10 @@ class TrainVAEProcess(BaseTrainProcess):
         # having issues with meta
         save_file(state_dict, os.path.join(self.save_root, filename), save_meta)
 
-        print(f"Saved to {os.path.join(self.save_root, filename)}")
+        self.print(f"Saved to {os.path.join(self.save_root, filename)}")
+
+        if self.use_critic:
+            self.critic.save(step)
 
     def sample(self, step=None):
         sample_folder = os.path.join(self.save_root, 'samples')
@@ -268,6 +459,9 @@ class TrainVAEProcess(BaseTrainProcess):
         num_steps = self.max_steps
         if num_steps is None or num_steps > max_epoch_steps:
             num_steps = max_epoch_steps
+        self.max_steps = num_steps
+        self.epochs = num_epochs
+        start_step = self.step_num
 
         self.print(f"Training VAE")
         self.print(f" - Training folder: {self.training_folder}")
@@ -304,18 +498,14 @@ class TrainVAEProcess(BaseTrainProcess):
             params += list(self.vae.decoder.conv_out.parameters())
             self.vae.decoder.conv_out.requires_grad_(True)
 
-        if self.style_weight > 0 or self.content_weight > 0:
+        if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
             self.setup_vgg19()
             self.vgg_19.requires_grad_(False)
             self.vgg_19.eval()
+        if self.use_critic:
+            self.critic.setup()
 
-        # todo allow other optimizers
-        if self.optimizer_type == 'dadaptation':
-            import dadaptation
-            print("Using DAdaptAdam optimizer")
-            optimizer = dadaptation.DAdaptAdam(params, lr=1)
-        else:
-            optimizer = torch.optim.Adam(params, lr=float(self.learning_rate))
+        optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate)
 
         # setup scheduler
         # todo allow other schedulers
@@ -333,7 +523,6 @@ class TrainVAEProcess(BaseTrainProcess):
             leave=True
         )
 
-        step = 0
         # sample first
         self.sample()
         blank_losses = OrderedDict({
@@ -343,15 +532,17 @@ class TrainVAEProcess(BaseTrainProcess):
             "mse": [],
             "kl": [],
             "tv": [],
+            "crD": [],
+            "crG": [],
         })
         epoch_losses = copy.deepcopy(blank_losses)
         log_losses = copy.deepcopy(blank_losses)
 
-        for epoch in range(num_epochs):
-            if step >= num_steps:
+        # range starts at self.epoch_num and goes to self.epochs
+        for epoch in range(self.epoch_num, self.epochs, 1):
+            if self.step_num >= self.max_steps:
                 break
             for batch in self.data_loader:
-                if step >= num_steps:
+                if self.step_num >= self.max_steps:
                     break
 
                 batch = batch.to(self.device, dtype=self.torch_dtype)
@@ -365,18 +556,27 @@ class TrainVAEProcess(BaseTrainProcess):
                 pred = self.vae.decode(latents).sample
 
                 # Run through VGG19
-                if self.style_weight > 0 or self.content_weight > 0:
+                if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
                     stacked = torch.cat([pred, batch], dim=0)
                     stacked = (stacked / 2 + 0.5).clamp(0, 1)
                     self.vgg_19(stacked)
 
+                if self.use_critic:
+                    critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
+                else:
+                    critic_d_loss = 0.0
+
                 style_loss = self.get_style_loss() * self.style_weight
                 content_loss = self.get_content_loss() * self.content_weight
                 kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
                 mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
                 tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
+                if self.use_critic:
+                    critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
+                else:
+                    critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
 
-                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss
+                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss
 
                 # Backward pass and optimization
                 optimizer.zero_grad()
@@ -398,6 +598,10 @@ class TrainVAEProcess(BaseTrainProcess):
                     loss_string += f" mse: {mse_loss.item():.2e}"
                 if self.tv_weight > 0:
                     loss_string += f" tv: {tv_loss.item():.2e}"
+                if self.use_critic and self.critic_weight > 0:
+                    loss_string += f" crG: {critic_gen_loss.item():.2e}"
+                if self.use_critic:
+                    loss_string += f" crD: {critic_d_loss:.2e}"
 
                 if self.optimizer_type.startswith('dadaptation'):
                     learning_rate = (
@@ -406,7 +610,13 @@ class TrainVAEProcess(BaseTrainProcess):
                     )
                 else:
                     learning_rate = optimizer.param_groups[0]['lr']
-                self.progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
+
+                lr_critic_string = ''
+                if self.use_critic:
+                    lr_critic = self.critic.get_lr()
+                    lr_critic_string = f" lrC: {lr_critic:.1e}"
+
+                self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e}{lr_critic_string} {loss_string}")
                 self.progress_bar.set_description(f"E: {epoch}")
                 self.progress_bar.update(1)
@@ -416,6 +626,8 @@ class TrainVAEProcess(BaseTrainProcess):
                 epoch_losses["mse"].append(mse_loss.item())
                 epoch_losses["kl"].append(kld_loss.item())
                 epoch_losses["tv"].append(tv_loss.item())
+                epoch_losses["crG"].append(critic_gen_loss.item())
+                epoch_losses["crD"].append(critic_d_loss)
 
                 log_losses["total"].append(loss_value)
                 log_losses["style"].append(style_loss.item())
@@ -423,30 +635,33 @@ class TrainVAEProcess(BaseTrainProcess):
                 log_losses["mse"].append(mse_loss.item())
                 log_losses["kl"].append(kld_loss.item())
                 log_losses["tv"].append(tv_loss.item())
+                log_losses["crG"].append(critic_gen_loss.item())
+                log_losses["crD"].append(critic_d_loss)
 
-                if step != 0:
-                    if self.sample_every and step % self.sample_every == 0:
+                # don't do on first step
+                if self.step_num != start_step:
+                    if self.sample_every and self.step_num % self.sample_every == 0:
                         # print above the progress bar
-                        self.print(f"Sampling at step {step}")
-                        self.sample(step)
+                        self.print(f"Sampling at step {self.step_num}")
+                        self.sample(self.step_num)
 
-                    if self.save_every and step % self.save_every == 0:
+                    if self.save_every and self.step_num % self.save_every == 0:
                         # print above the progress bar
-                        self.print(f"Saving at step {step}")
-                        self.save(step)
+                        self.print(f"Saving at step {self.step_num}")
+                        self.save(self.step_num)
 
-                    if self.log_every and step % self.log_every == 0:
+                    if self.log_every and self.step_num % self.log_every == 0:
                         # log to tensorboard
                         if self.writer is not None:
                             # get avg loss
                             for key in log_losses:
                                 log_losses[key] = sum(log_losses[key]) / len(log_losses[key])
-                                if log_losses[key] > 0:
-                                    self.writer.add_scalar(f"loss/{key}", log_losses[key], step)
+                                # if log_losses[key] > 0:
+                                self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num)
                             # reset log losses
                             log_losses = copy.deepcopy(blank_losses)
 
-                step += 1
+                self.step_num += 1
         # end epoch
         if self.writer is not None:
             # get avg loss
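
The get_gradient_penalty helper imported from toolkit.losses is also not part of this diff. A minimal sketch with the call signature used in Critic.step above, following the standard WGAN-GP formulation; the actual toolkit implementation may differ:

import torch

def get_gradient_penalty(critic, real, fake, device):
    # score random interpolations between real and fake samples
    alpha = torch.rand(real.size(0), 1, 1, 1, device=device, dtype=real.dtype)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interpolates)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    # penalize deviation of the per-sample gradient norm from 1
    grads = grads.reshape(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()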