mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Added Critic support to VAE training. Still tweaking and working on it. Many other fixes
@@ -1,8 +1,11 @@
+import os
+
 from jobs import BaseJob
 from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
 from collections import OrderedDict
 from typing import List
 from jobs.process import BaseExtractProcess, TrainFineTuneProcess
+from datetime import datetime

 from toolkit.paths import REPOS_ROOT
@@ -45,7 +48,8 @@ class TrainJob(BaseJob):
     def setup_tensorboard(self):
         if self.log_dir:
             from torch.utils.tensorboard import SummaryWriter
-            self.writer = SummaryWriter(
-                log_dir=self.log_dir,
-                filename_suffix=f"_{self.name}"
-            )
+            now = datetime.now()
+            time_str = now.strftime('%Y%m%d-%H%M%S')
+            summary_name = f"{self.name}_{time_str}"
+            summary_dir = os.path.join(self.log_dir, summary_name)
+            self.writer = SummaryWriter(summary_dir)
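The rewritten setup_tensorboard gives every run its own timestamped subdirectory, so repeated runs of the same job show up as separate curves in TensorBoard instead of appending to one event file. A quick sketch of the resulting layout (the path and job name are hypothetical):

    import os
    from datetime import datetime

    log_dir = '/tmp/logs'      # hypothetical
    name = 'my_vae_job'        # hypothetical
    time_str = datetime.now().strftime('%Y%m%d-%H%M%S')
    summary_dir = os.path.join(log_dir, f"{name}_{time_str}")
    print(summary_dir)         # e.g. /tmp/logs/my_vae_job_20230715-093012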
@@ -17,11 +17,23 @@ class BaseProcess:
         self.job = job
         self.config = config
         self.meta = copy.deepcopy(self.job.meta)
         print(json.dumps(self.config, indent=4))

     def get_conf(self, key, default=None, required=False, as_type=None):
-        if key in self.config:
-            value = self.config[key]
-            if as_type is not None and value is not None:
+        # split key by '.' and recursively get the value
+        keys = key.split('.')
+
+        # see if it exists in the config
+        value = self.config
+        for subkey in keys:
+            if subkey in value:
+                value = value[subkey]
+            else:
+                value = None
+                break
+
+        if value is not None:
+            if as_type is not None:
+                value = as_type(value)
+            return value
         elif required:
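The dotted-key lookup above lets callers fetch nested config values with one call. A minimal standalone sketch of the same idea (the config dict and keys here are hypothetical, not from the repo):

    def get_nested(config, key, default=None, as_type=None):
        # walk the dict one '.'-separated segment at a time
        value = config
        for subkey in key.split('.'):
            if isinstance(value, dict) and subkey in value:
                value = value[subkey]
            else:
                return default
        return as_type(value) if as_type is not None else value

    config = {'critic': {'learning_rate': '5e-5'}}
    assert get_nested(config, 'critic.learning_rate', as_type=float) == 5e-5
    assert get_nested(config, 'critic.start_step', default=0) == 0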
@@ -6,7 +6,7 @@ from collections import OrderedDict

 from PIL import Image
 from PIL.ImageOps import exif_transpose
-from safetensors.torch import save_file
+from safetensors.torch import save_file, load_file
 from torch.utils.data import DataLoader, ConcatDataset
 import torch
 from torch import nn
@@ -15,8 +15,9 @@ from torchvision.transforms import transforms
 from jobs.process import BaseTrainProcess
 from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
 from toolkit.data_loader import ImageDataset
-from toolkit.losses import ComparativeTotalVariation
+from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty
 from toolkit.metadata import get_meta_for_safetensors
+from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
 from diffusers import AutoencoderKL
@@ -36,6 +37,139 @@ def unnormalize(tensor):
     return (tensor / 2 + 0.5).clamp(0, 1)


+class Critic:
+    process: 'TrainVAEProcess'
+
+    def __init__(
+            self,
+            learning_rate=1e-5,
+            device='cpu',
+            optimizer='adam',
+            num_critic_per_gen=1,
+            dtype='float32',
+            lambda_gp=10,
+            start_step=0,
+            warmup_steps=1000,
+            process=None
+    ):
+        self.learning_rate = learning_rate
+        self.device = device
+        self.optimizer_type = optimizer
+        self.num_critic_per_gen = num_critic_per_gen
+        self.dtype = dtype
+        self.torch_dtype = get_torch_dtype(self.dtype)
+        self.process = process
+        self.model = None
+        self.optimizer = None
+        self.scheduler = None
+        self.warmup_steps = warmup_steps
+        self.start_step = start_step
+        self.lambda_gp = lambda_gp
+        self.print = self.process.print
+        print(f" Critic config: {self.__dict__}")
+
+    def setup(self):
+        from .models.vgg19_critic import Vgg19Critic
+        self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype)
+        self.load_weights()
+        self.model.train()
+        self.model.requires_grad_(True)
+        params = self.model.parameters()
+        self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate)
+        self.scheduler = torch.optim.lr_scheduler.ConstantLR(
+            self.optimizer,
+            total_iters=self.process.max_steps * self.num_critic_per_gen,
+            factor=1,
+            verbose=False
+        )
+
+    def load_weights(self):
+        path_to_load = None
+        self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}")
+        files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors"))
+        if files and len(files) > 0:
+            latest_file = max(files, key=os.path.getmtime)
+            print(f" - Latest checkpoint is: {latest_file}")
+            path_to_load = latest_file
+        else:
+            self.print(" - No checkpoint found, starting from scratch")
+        if path_to_load:
+            self.model.load_state_dict(load_file(path_to_load))
+
+    def save(self, step=None):
+        self.process.update_training_metadata()
+        save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name)
+        step_num = ''
+        if step is not None:
+            # zeropad 9 digits
+            step_num = f"_{str(step).zfill(9)}"
+        save_path = os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors")
+        save_file(self.model.state_dict(), save_path, save_meta)
+        self.print(f"Saved critic to {save_path}")
+    def get_critic_loss(self, vgg_output):
+        if self.start_step > self.process.step_num:
+            return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device)
+
+        warmup_scaler = 1.0
+        # we need a warmup when the critic first comes online,
+        # scaling the loss by 0.0 at self.start_step and 1.0 at self.start_step + warmup_steps
+        if self.process.step_num < self.start_step + self.warmup_steps:
+            warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps
+        # set model to not train for generator loss
+        self.model.eval()
+        self.model.requires_grad_(False)
+        vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0)
+
+        # run model
+        stacked_output = self.model(vgg_pred)
+
+        return (-torch.mean(stacked_output)) * warmup_scaler
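Concretely, the ramp above means that with start_step=2000 and warmup_steps=1000 (values chosen for illustration), the generator-side critic loss is scaled by 0.0 at step 2000, 0.5 at step 2500, and 1.0 from step 3000 onward. A tiny check of that arithmetic:

    def warmup_scale(step_num, start_step=2000, warmup_steps=1000):
        if step_num < start_step:
            return 0.0
        if step_num < start_step + warmup_steps:
            return (step_num - start_step) / warmup_steps
        return 1.0

    assert [warmup_scale(s) for s in (1999, 2000, 2500, 3000)] == [0.0, 0.0, 0.5, 1.0]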
+    def step(self, vgg_output):
+
+        # train critic here
+        self.model.train()
+        self.model.requires_grad_(True)
+
+        critic_losses = []
+        for i in range(self.num_critic_per_gen):
+            inputs = vgg_output.detach()
+            inputs = inputs.to(self.device, dtype=self.torch_dtype)
+            self.optimizer.zero_grad()
+
+            vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
+
+            stacked_output = self.model(inputs)
+            out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
+
+            # Compute gradient penalty
+            gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
+
+            # Compute WGAN-GP critic loss
+            critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
+            critic_loss.backward()
+            # gradients were already zeroed before the forward pass; zeroing them
+            # again here would wipe them out, so step the optimizer directly
+            self.optimizer.step()
+            self.scheduler.step()
+            critic_losses.append(critic_loss.item())
+
+        # avg loss
+        loss = np.mean(critic_losses)
+        return loss
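For orientation, the update above is the standard WGAN-GP critic objective: maximize E[D(real)] - E[D(fake)] (minimized here as its negative) plus a gradient penalty that keeps the critic approximately 1-Lipschitz. A minimal self-contained sketch of one such update with a toy critic (the toy model and shapes are illustrative, not the trainer's):

    import torch
    import torch.nn as nn

    critic = nn.Sequential(nn.Flatten(), nn.Linear(512 * 32 * 32, 1))  # toy stand-in
    opt = torch.optim.Adam(critic.parameters(), lr=1e-5)

    real = torch.randn(4, 512, 32, 32)   # e.g. VGG pool_4 features of targets
    fake = torch.randn(4, 512, 32, 32)   # e.g. VGG pool_4 features of reconstructions

    opt.zero_grad()
    # Wasserstein estimate: critic should score real high and fake low
    w_loss = -(critic(real).mean() - critic(fake).mean())

    # gradient penalty on random interpolates between real and fake
    alpha = torch.rand(real.size(0), 1, 1, 1)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    grads = torch.autograd.grad(critic(interp).sum(), interp, create_graph=True)[0]
    gp = ((grads.view(grads.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()

    (w_loss + 10 * gp).backward()
    opt.step()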
+    def get_lr(self):
+        if self.optimizer_type.startswith('dadaptation'):
+            learning_rate = (
+                    self.optimizer.param_groups[0]["d"] *
+                    self.optimizer.param_groups[0]["lr"]
+            )
+        else:
+            learning_rate = self.optimizer.param_groups[0]['lr']
+
+        return learning_rate
+
+
 class TrainVAEProcess(BaseTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict):
         super().__init__(process_id, job, config)
@@ -61,6 +195,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
         self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
+        self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)

         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
         self.writer = self.job.writer
@@ -68,6 +203,22 @@ class TrainVAEProcess(BaseTrainProcess):
         self.save_root = os.path.join(self.training_folder, self.job.name)
         self.vgg_19 = None
         self.progress_bar = None
+        self.style_weight_scalers = []
+        self.content_weight_scalers = []
+
+        self.step_num = 0
+        self.epoch_num = 0
+
+        self.use_critic = self.get_conf('use_critic', False, as_type=bool)
+        self.critic = None
+
+        if self.use_critic:
+            self.critic = Critic(
+                device=self.device,
+                dtype=self.dtype,
+                process=self,
+                **self.get_conf('critic', {})  # pass any other params
+            )

         if self.sample_every is not None and self.sample_sources is None:
             raise ValueError('sample_every is specified but sample_sources is not')
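Given the constructor wiring above, a process config could carry a nested critic block whose keys map straight onto Critic.__init__ via the ** expansion. An illustrative sketch (all values hypothetical):

    # the nested 'critic' dict is splatted into Critic.__init__
    # through **self.get_conf('critic', {})
    config = {
        'use_critic': True,
        'critic_weight': 1.0,
        'critic': {
            'learning_rate': 1e-5,
            'num_critic_per_gen': 2,
            'lambda_gp': 10,
            'start_step': 2000,
            'warmup_steps': 1000,
        },
    }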
@@ -89,6 +240,16 @@ class TrainVAEProcess(BaseTrainProcess):
         if not os.path.exists(self.save_root):
             os.makedirs(self.save_root, exist_ok=True)

+    def update_training_metadata(self):
+        self.add_meta(OrderedDict({"training_info": self.get_training_info()}))
+
+    def get_training_info(self):
+        info = OrderedDict({
+            'step': self.step_num,
+            'epoch': self.epoch_num,
+        })
+        return info
+
     def print(self, message, **kwargs):
         if self.progress_bar is not None:
             self.progress_bar.write(message, **kwargs)
@@ -117,19 +278,46 @@ class TrainVAEProcess(BaseTrainProcess):
     def setup_vgg19(self):
         if self.vgg_19 is None:
-            self.vgg_19, self.style_losses, self.content_losses, output = get_style_model_and_losses(
-                single_target=True, device=self.device)
+            self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses(
+                single_target=True,
+                device=self.device,
+                output_layer_name='pool_4',
+                dtype=self.torch_dtype
+            )
+            self.vgg_19.to(self.device, dtype=self.torch_dtype)
             self.vgg_19.requires_grad_(False)

+            # we run random noise through first to get layer scalers to normalize the loss per layer
+            # bs of 2 because we run pred and target through stacked
+            noise = torch.randn((2, 3, self.resolution, self.resolution), device=self.device, dtype=self.torch_dtype)
+            self.vgg_19(noise)
+            for style_loss in self.style_losses:
+                # get a scaler to normalize to 1
+                scaler = 1 / torch.mean(style_loss.loss).item()
+                self.style_weight_scalers.append(scaler)
+            for content_loss in self.content_losses:
+                # get a scaler to normalize to 1
+                scaler = 1 / torch.mean(content_loss.loss).item()
+                self.content_weight_scalers.append(scaler)
+
+            self.print(f"Style weight scalers: {self.style_weight_scalers}")
+            self.print(f"Content weight scalers: {self.content_weight_scalers}")

     def get_style_loss(self):
         if self.style_weight > 0:
-            return torch.sum(torch.stack([loss.loss for loss in self.style_losses]))
+            # scale all losses with loss scalers
+            loss = torch.sum(
+                torch.stack([loss.loss * scaler for loss, scaler in zip(self.style_losses, self.style_weight_scalers)]))
+            return loss
         else:
             return torch.tensor(0.0, device=self.device)

     def get_content_loss(self):
         if self.content_weight > 0:
-            return torch.sum(torch.stack([loss.loss for loss in self.content_losses]))
+            # scale all losses with loss scalers
+            loss = torch.sum(torch.stack(
+                [loss.loss * scaler for loss, scaler in zip(self.content_losses, self.content_weight_scalers)]))
+            return loss
         else:
             return torch.tensor(0.0, device=self.device)
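The random-noise pass above is a calibration trick: raw per-layer VGG losses differ by orders of magnitude, so a scaler of 1/loss(noise) per layer normalizes each layer's initial contribution to roughly 1 before the user-facing style/content weights apply. A toy illustration of the idea (the loss values are made up):

    import torch

    # pretend raw per-layer losses measured on random noise
    raw_layer_losses = [torch.tensor(412.0), torch.tensor(0.0031), torch.tensor(7.9)]

    scalers = [1 / loss.item() for loss in raw_layer_losses]
    # after scaling, every layer starts out contributing ~1.0,
    # so style_weight / content_weight act on comparable magnitudes
    normalized = [loss.item() * s for loss, s in zip(raw_layer_losses, scalers)]
    assert all(abs(v - 1.0) < 1e-6 for v in normalized)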
@@ -160,7 +348,6 @@ class TrainVAEProcess(BaseTrainProcess):
         else:
             return torch.tensor(0.0, device=self.device)
-

     def save(self, step=None):
         if not os.path.exists(self.save_root):
             os.makedirs(self.save_root, exist_ok=True)
@@ -170,6 +357,7 @@ class TrainVAEProcess(BaseTrainProcess):
             # zeropad 9 digits
             step_num = f"_{str(step).zfill(9)}"

+        self.update_training_metadata()
         filename = f'{self.job.name}{step_num}.safetensors'
         # prepare meta
         save_meta = get_meta_for_safetensors(self.meta, self.job.name)
@@ -184,7 +372,10 @@ class TrainVAEProcess(BaseTrainProcess):
         # having issues with meta
         save_file(state_dict, os.path.join(self.save_root, filename), save_meta)

-        print(f"Saved to {os.path.join(self.save_root, filename)}")
+        self.print(f"Saved to {os.path.join(self.save_root, filename)}")
+
+        if self.use_critic:
+            self.critic.save(step)

     def sample(self, step=None):
         sample_folder = os.path.join(self.save_root, 'samples')
@@ -268,6 +459,9 @@ class TrainVAEProcess(BaseTrainProcess):
         num_steps = self.max_steps
         if num_steps is None or num_steps > max_epoch_steps:
             num_steps = max_epoch_steps
+        self.max_steps = num_steps
+        self.epochs = num_epochs
+        start_step = self.step_num

         self.print(f"Training VAE")
         self.print(f" - Training folder: {self.training_folder}")
@@ -304,18 +498,14 @@ class TrainVAEProcess(BaseTrainProcess):
             params += list(self.vae.decoder.conv_out.parameters())
             self.vae.decoder.conv_out.requires_grad_(True)

-        if self.style_weight > 0 or self.content_weight > 0:
+        if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
             self.setup_vgg19()
             self.vgg_19.requires_grad_(False)
             self.vgg_19.eval()
+        if self.use_critic:
+            self.critic.setup()

-        # todo allow other optimizers
-        if self.optimizer_type == 'dadaptation':
-            import dadaptation
-            print("Using DAdaptAdam optimizer")
-            optimizer = dadaptation.DAdaptAdam(params, lr=1)
-        else:
-            optimizer = torch.optim.Adam(params, lr=float(self.learning_rate))
+        optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate)

         # setup scheduler
         # todo allow other schedulers
@@ -333,7 +523,6 @@ class TrainVAEProcess(BaseTrainProcess):
             leave=True
         )

-        step = 0
         # sample first
         self.sample()
         blank_losses = OrderedDict({
@@ -343,15 +532,17 @@ class TrainVAEProcess(BaseTrainProcess):
             "mse": [],
             "kl": [],
             "tv": [],
+            "crD": [],
+            "crG": [],
         })
         epoch_losses = copy.deepcopy(blank_losses)
         log_losses = copy.deepcopy(blank_losses)

-        for epoch in range(num_epochs):
-            if step >= num_steps:
+        # range starts at self.epoch_num and goes to self.epochs
+        for epoch in range(self.epoch_num, self.epochs, 1):
+            if self.step_num >= self.max_steps:
                 break
             for batch in self.data_loader:
-                if step >= num_steps:
+                if self.step_num >= self.max_steps:
                     break

                 batch = batch.to(self.device, dtype=self.torch_dtype)
@@ -365,18 +556,27 @@ class TrainVAEProcess(BaseTrainProcess):
                 pred = self.vae.decode(latents).sample

                 # Run through VGG19
-                if self.style_weight > 0 or self.content_weight > 0:
+                if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
                     stacked = torch.cat([pred, batch], dim=0)
                     stacked = (stacked / 2 + 0.5).clamp(0, 1)
                     self.vgg_19(stacked)

+                if self.use_critic:
+                    critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
+                else:
+                    critic_d_loss = 0.0
+
                 style_loss = self.get_style_loss() * self.style_weight
                 content_loss = self.get_content_loss() * self.content_weight
                 kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
                 mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
                 tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
+                if self.use_critic:
+                    critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
+                else:
+                    critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

-                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss
+                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss

                 # Backward pass and optimization
                 optimizer.zero_grad()
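One subtlety worth noting: pred and batch are concatenated along the batch dimension before the VGG pass, so the captured pool_4 tensor holds reconstruction features in its first half and target features in its second, which is exactly what torch.chunk(…, 2, dim=0) inside the Critic unpacks. A tiny sketch of that convention (dummy tensors):

    import torch

    pred_feats = torch.zeros(4, 512, 32, 32)    # first half: reconstructions
    target_feats = torch.ones(4, 512, 32, 32)   # second half: real images
    stacked = torch.cat([pred_feats, target_feats], dim=0)  # (8, 512, 32, 32)

    vgg_pred, vgg_target = torch.chunk(stacked, 2, dim=0)
    assert torch.equal(vgg_pred, pred_feats) and torch.equal(vgg_target, target_feats)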
@@ -398,6 +598,10 @@ class TrainVAEProcess(BaseTrainProcess):
                     loss_string += f" mse: {mse_loss.item():.2e}"
                 if self.tv_weight > 0:
                     loss_string += f" tv: {tv_loss.item():.2e}"
+                if self.use_critic and self.critic_weight > 0:
+                    loss_string += f" crG: {critic_gen_loss.item():.2e}"
+                if self.use_critic:
+                    loss_string += f" crD: {critic_d_loss:.2e}"

                 if self.optimizer_type.startswith('dadaptation'):
                     learning_rate = (
@@ -406,7 +610,13 @@ class TrainVAEProcess(BaseTrainProcess):
                     )
                 else:
                     learning_rate = optimizer.param_groups[0]['lr']
-                self.progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
+
+                lr_critic_string = ''
+                if self.use_critic:
+                    lr_critic = self.critic.get_lr()
+                    lr_critic_string = f" lrC: {lr_critic:.1e}"
+
+                self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e}{lr_critic_string} {loss_string}")
                 self.progress_bar.set_description(f"E: {epoch}")
                 self.progress_bar.update(1)
@@ -416,6 +626,8 @@ class TrainVAEProcess(BaseTrainProcess):
                 epoch_losses["mse"].append(mse_loss.item())
                 epoch_losses["kl"].append(kld_loss.item())
                 epoch_losses["tv"].append(tv_loss.item())
+                epoch_losses["crG"].append(critic_gen_loss.item())
+                epoch_losses["crD"].append(critic_d_loss)

                 log_losses["total"].append(loss_value)
                 log_losses["style"].append(style_loss.item())
@@ -423,30 +635,33 @@ class TrainVAEProcess(BaseTrainProcess):
                 log_losses["mse"].append(mse_loss.item())
                 log_losses["kl"].append(kld_loss.item())
                 log_losses["tv"].append(tv_loss.item())
+                log_losses["crG"].append(critic_gen_loss.item())
+                log_losses["crD"].append(critic_d_loss)

-                if step != 0:
-                    if self.sample_every and step % self.sample_every == 0:
+                # don't do on first step
+                if self.step_num != start_step:
+                    if self.sample_every and self.step_num % self.sample_every == 0:
                         # print above the progress bar
-                        self.print(f"Sampling at step {step}")
-                        self.sample(step)
+                        self.print(f"Sampling at step {self.step_num}")
+                        self.sample(self.step_num)

-                    if self.save_every and step % self.save_every == 0:
+                    if self.save_every and self.step_num % self.save_every == 0:
                         # print above the progress bar
-                        self.print(f"Saving at step {step}")
-                        self.save(step)
+                        self.print(f"Saving at step {self.step_num}")
+                        self.save(self.step_num)

-                    if self.log_every and step % self.log_every == 0:
+                    if self.log_every and self.step_num % self.log_every == 0:
                         # log to tensorboard
                         if self.writer is not None:
                             # get avg loss
                             for key in log_losses:
                                 log_losses[key] = sum(log_losses[key]) / len(log_losses[key])
-                                if log_losses[key] > 0:
-                                    self.writer.add_scalar(f"loss/{key}", log_losses[key], step)
+                                # if log_losses[key] > 0:
+                                self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num)
                             # reset log losses
                             log_losses = copy.deepcopy(blank_losses)

-                step += 1
+                self.step_num += 1
             # end epoch
             if self.writer is not None:
                 # get avg loss
jobs/process/models/vgg19_critic.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+import torch
+import torch.nn as nn
+
+
+class MeanReduce(nn.Module):
+    def __init__(self):
+        super(MeanReduce, self).__init__()
+
+    def forward(self, inputs):
+        return torch.mean(inputs, dim=(1, 2, 3), keepdim=True)
+
+
+class Vgg19Critic(nn.Module):
+    def __init__(self):
+        # vgg19 input (bs, 3, 512, 512)
+        # pool1 (bs, 64, 256, 256)
+        # pool2 (bs, 128, 128, 128)
+        # pool3 (bs, 256, 64, 64)
+        # pool4 (bs, 512, 32, 32) <- take this input
+
+        super(Vgg19Critic, self).__init__()
+        self.main = nn.Sequential(
+            # input (bs, 512, 32, 32)
+            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+            nn.LeakyReLU(0.2),  # (bs, 512, 16, 16)
+            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+            nn.LeakyReLU(0.2),  # (bs, 512, 8, 8)
+            nn.Conv2d(512, 1, kernel_size=3, stride=2, padding=1),
+            # (bs, 1, 4, 4)
+            MeanReduce(),  # (bs, 1, 1, 1)
+            nn.Flatten(),  # (bs, 1)
+
+            # nn.Flatten(),  # (128*8*8) = 8192
+            # nn.Linear(128 * 8 * 8, 1)
+        )
+
+    def forward(self, inputs):
+        return self.main(inputs)
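As a sanity check on the comments above, the critic maps pool_4 features to one scalar score per sample; a quick shape test (illustrative, assuming the class above is in scope):

    import torch

    critic = Vgg19Critic()
    feats = torch.randn(2, 512, 32, 32)   # e.g. a stacked pred/target pool_4 pair
    scores = critic(feats)
    assert scores.shape == (2, 1)         # one Wasserstein score per sample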
@@ -1,6 +1,7 @@
 import os
 import json
 import oyaml as yaml
+import re
 from collections import OrderedDict

 from toolkit.paths import TOOLKIT_ROOT
@@ -29,6 +30,20 @@ def preprocess_config(config: OrderedDict):
     return config


+# Fixes issue where yaml doesn't load exponents correctly
+fixed_loader = yaml.SafeLoader
+fixed_loader.add_implicit_resolver(
+    u'tag:yaml.org,2002:float',
+    re.compile(u'''^(?:
+     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+    |[-+]?\\.(?:inf|Inf|INF)
+    |\\.(?:nan|NaN|NAN))$''', re.X),
+    list(u'-+0123456789.'))
+
+
 def get_config(config_file_path):
     # first check if it is in the config folder
     config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
@@ -56,7 +71,7 @@ def get_config(config_file_path):
             config = json.load(f, object_pairs_hook=OrderedDict)
     elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'):
         with open(real_config_path, 'r') as f:
-            config = yaml.load(f, Loader=yaml.FullLoader)
+            config = yaml.load(f, Loader=fixed_loader)
     else:
         raise ValueError(f"Config file {config_file_path} must be a json or yaml file")
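The widened resolver works around a long-standing PyYAML quirk: the YAML 1.1 float regex requires a dot in the mantissa, so a bare exponent like 1e-5 parses as a string. A quick demonstration (assuming oyaml's PyYAML-compatible behavior):

    import yaml

    doc = "learning_rate: 1e-5"
    print(yaml.safe_load(doc))   # {'learning_rate': '1e-5'}  <- a string!
    # with the implicit resolver installed on fixed_loader above:
    # yaml.load(doc, Loader=fixed_loader) -> {'learning_rate': 1e-05}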
@@ -11,7 +11,7 @@ def total_variation(image):
     """
     n_elements = image.shape[1] * image.shape[2] * image.shape[3]
     return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
-            torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
+             torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)


 class ComparativeTotalVariation(torch.nn.Module):
@@ -21,3 +21,27 @@ class ComparativeTotalVariation(torch.nn.Module):

     def forward(self, pred, target):
         return torch.abs(total_variation(pred) - total_variation(target))


+# Gradient penalty
+def get_gradient_penalty(critic, real, fake, device):
+    with torch.autocast(device_type='cuda'):
+        alpha = torch.rand(real.size(0), 1, 1, 1).to(device)
+        interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True)
+        d_interpolates = critic(interpolates)
+        # note: this rebinds the name `fake` to the grad_outputs (all ones),
+        # shadowing the fake-sample argument above
+        fake = torch.ones(real.size(0), 1, device=device)

+        gradients = torch.autograd.grad(
+            outputs=d_interpolates,
+            inputs=interpolates,
+            grad_outputs=fake,
+            create_graph=True,
+            retain_graph=True,
+            only_inputs=True,
+        )[0]

+    gradients = gradients.view(gradients.size(0), -1)
+    gradient_norm = gradients.norm(2, dim=1)
+    gradient_penalty = ((gradient_norm - 1) ** 2).mean()
+    return gradient_penalty
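A quick way to see the penalty's fixed point: it is minimized when the critic's gradient norm on the interpolates equals 1, the Lipschitz bound from the Kantorovich-Rubinstein duality behind WGAN. Toy numeric check with a hypothetical linear critic:

    import torch

    # critic(x) = 3*x has gradient norm 3 everywhere -> penalty (3 - 1)^2 = 4
    x = torch.randn(5, 1, requires_grad=True)
    d = (3 * x).sum()
    g = torch.autograd.grad(d, x)[0]
    penalty = ((g.norm(2, dim=1) - 1) ** 2).mean()
    assert torch.isclose(penalty, torch.tensor(4.0))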
@@ -13,6 +13,16 @@ def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict:
     # safetensors can only be one level deep
     for key, value in save_meta.items():
         # if not float, int, bool, or str, convert to json string
-        if not isinstance(value, (float, int, bool, str)):
+        if not isinstance(value, str):
             save_meta[key] = json.dumps(value)
     return save_meta


+def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
+    parsed_meta = OrderedDict()
+    for key, value in meta.items():
+        try:
+            parsed_meta[key] = json.loads(value)
+        except json.decoder.JSONDecodeError:
+            parsed_meta[key] = value
+    return parsed_meta
toolkit/optimizer.py (new file, 18 lines)
@@ -0,0 +1,18 @@
+import torch
+
+
+def get_optimizer(
+        params,
+        optimizer_type='adam',
+        learning_rate=1e-6
+):
+    if optimizer_type == 'dadaptation':
+        # dadaptation optimizer does not use standard learning rate. 1 is the default value
+        import dadaptation
+        print("Using DAdaptAdam optimizer")
+        optimizer = dadaptation.DAdaptAdam(params, lr=1.0)
+    elif optimizer_type == 'adam':
+        optimizer = torch.optim.Adam(params, lr=float(learning_rate))
+    else:
+        raise ValueError(f'Unknown optimizer type {optimizer_type}')
+    return optimizer
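Usage mirrors the trainer and critic above; note that with 'dadaptation' the configured learning_rate is ignored in favor of lr=1, since D-Adaptation tunes its own step size. A hedged sketch (toy model, assuming get_optimizer above is importable):

    import torch.nn as nn

    model = nn.Linear(4, 4)  # toy parameters
    opt = get_optimizer(model.parameters(), optimizer_type='adam', learning_rate=1e-5)
    # get_optimizer(model.parameters(), 'dadaptation')  # requires the dadaptation package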
@@ -21,6 +21,7 @@ class ContentLoss(nn.Module):
         self.loss = None

     def forward(self, stacked_input):
+
         if self.single_target:
             split_size = stacked_input.size()[0] // 2
             pred_layer, target_layer = torch.split(stacked_input, split_size, dim=0)
@@ -73,6 +74,8 @@ class StyleLoss(nn.Module):
         self.device = device

     def forward(self, stacked_input):
+        input_dtype = stacked_input.dtype
+        stacked_input = stacked_input.float()
         if self.single_target:
             split_size = stacked_input.size()[0] // 2
             preds, style_target = torch.split(stacked_input, split_size, dim=0)
@@ -94,17 +97,18 @@ class StyleLoss(nn.Module):
         itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
         # gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
         loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
-        self.loss = loss
-        return stacked_input
+        self.loss = loss.to(input_dtype)
+        return stacked_input.to(input_dtype)


 # create a module to normalize input image so we can easily put it in a
 # ``nn.Sequential``
 class Normalization(nn.Module):
-    def __init__(self, device):
+    def __init__(self, device, dtype=torch.float32):
         super(Normalization, self).__init__()
         mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
         std = torch.tensor([0.229, 0.224, 0.225]).to(device)
+        self.dtype = dtype
         # .view the mean and std to make them [C x 1 x 1] so that they can
         # directly work with image Tensor of shape [B x C x H x W].
         # B is batch size. C is number of channels. H is height and W is width.
@@ -112,9 +116,9 @@ class Normalization(nn.Module):
         self.std = torch.tensor(std).view(-1, 1, 1)

     def forward(self, stacked_input):
-        # cast to float 32 if not already
-        if stacked_input.dtype != torch.float32:
-            stacked_input = stacked_input.float()
+        # cast to float 32 if not already  # only necessary when processing gram matrix
+        # if stacked_input.dtype != torch.float32:
+        #     stacked_input = stacked_input.float()
         # remove alpha channel if it exists
         if stacked_input.shape[1] == 4:
             stacked_input = stacked_input[:, :3, :, :]
@@ -123,21 +127,37 @@ class Normalization(nn.Module):
         in_max = torch.max(stacked_input)
         # norm_stacked_input = (stacked_input - in_min) / (in_max - in_min)
         # return (norm_stacked_input - self.mean) / self.std
-        return (stacked_input - self.mean) / self.std
+        return ((stacked_input - self.mean) / self.std).to(self.dtype)


+class OutputLayer(nn.Module):
+    def __init__(self, name='output_layer'):
+        super(OutputLayer, self).__init__()
+        self.name = name
+        self.tensor = None
+
+    def forward(self, stacked_input):
+        self.tensor = stacked_input
+        return stacked_input
+
+
 def get_style_model_and_losses(
-        single_target=False,
+        single_target=True,  # False has 3 targets; don't remember why I added this initially, this is old code
         device='cuda' if torch.cuda.is_available() else 'cpu',
+        output_layer_name=None,
+        dtype=torch.float32
 ):
     # content_layers = ['conv_4']
     # style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
     content_layers = ['conv4_2']
     style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
-    cnn = models.vgg19(pretrained=True).features.to(device).eval()
+    cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval()
+    # set all weights in the model to our dtype
+    # for layer in cnn.children():
+    #     layer.to(dtype=dtype)

     # normalization module
-    normalization = Normalization(device).to(device)
+    normalization = Normalization(device, dtype=dtype).to(device)

     # just in order to have an iterable access to or list of content/style
     # losses
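OutputLayer is a pass-through module that stashes whatever flows through it, so inserting one after pool_4 lets the trainer read intermediate VGG features (self.vgg19_pool_4.tensor) after an ordinary forward call, with no hooks. A minimal illustration (toy network, assuming the OutputLayer class above is in scope):

    import torch
    import torch.nn as nn

    tap = OutputLayer('pool_4_tap')  # records its input on every forward
    net = nn.Sequential(nn.Identity(), tap, nn.Identity())

    x = torch.randn(2, 8)
    y = net(x)
    assert torch.equal(tap.tensor, x) and torch.equal(y, x)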
@@ -189,15 +209,15 @@ def get_style_model_and_losses(
             style_losses.append(style_loss)

         if output_layer_name is not None and name == output_layer_name:
-            output_layer = layer
+            output_layer = OutputLayer(name)
+            model.add_module("output_layer_{}_{}".format(block, i), output_layer)

     # now we trim off the layers after the last content and style losses
     for i in range(len(model) - 1, -1, -1):
-        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
-            break
-        if output_layer_name is not None and model[i].name == output_layer_name:
+        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss) or isinstance(model[i], OutputLayer):
             break

     model = model[:(i + 1)]
+    model.to(dtype=dtype)

     return model, style_losses, content_losses, output_layer