Added Critic support to VAE training. Still tweaking and working on it. Many other fixes

This commit is contained in:
Jaret Burkett
2023-07-19 15:57:32 -06:00
parent 6ada328d8d
commit 557732e7ff
9 changed files with 415 additions and 59 deletions

View File

@@ -1,8 +1,11 @@
import os
from jobs import BaseJob
from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint
from collections import OrderedDict
from typing import List
from jobs.process import BaseExtractProcess, TrainFineTuneProcess
from datetime import datetime
from toolkit.paths import REPOS_ROOT
@@ -45,7 +48,8 @@ class TrainJob(BaseJob):
def setup_tensorboard(self):
if self.log_dir:
from torch.utils.tensorboard import SummaryWriter
self.writer = SummaryWriter(
log_dir=self.log_dir,
filename_suffix=f"_{self.name}"
)
now = datetime.now()
time_str = now.strftime('%Y%m%d-%H%M%S')
summary_name = f"{self.name}_{time_str}"
summary_dir = os.path.join(self.log_dir, summary_name)
self.writer = SummaryWriter(summary_dir)

View File

@@ -17,11 +17,23 @@ class BaseProcess:
self.job = job
self.config = config
self.meta = copy.deepcopy(self.job.meta)
print(json.dumps(self.config, indent=4))
def get_conf(self, key, default=None, required=False, as_type=None):
if key in self.config:
value = self.config[key]
if as_type is not None and value is not None:
# split key by '.' and recursively get the value
keys = key.split('.')
# see if it exists in the config
value = self.config
for subkey in keys:
if subkey in value:
value = value[subkey]
else:
value = None
break
if value is not None:
if as_type is not None:
value = as_type(value)
return value
elif required:

View File

@@ -6,7 +6,7 @@ from collections import OrderedDict
from PIL import Image
from PIL.ImageOps import exif_transpose
from safetensors.torch import save_file
from safetensors.torch import save_file, load_file
from torch.utils.data import DataLoader, ConcatDataset
import torch
from torch import nn
@@ -15,8 +15,9 @@ from torchvision.transforms import transforms
from jobs.process import BaseTrainProcess
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
from toolkit.data_loader import ImageDataset
from toolkit.losses import ComparativeTotalVariation
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty
from toolkit.metadata import get_meta_for_safetensors
from toolkit.optimizer import get_optimizer
from toolkit.style import get_style_model_and_losses
from toolkit.train_tools import get_torch_dtype
from diffusers import AutoencoderKL
@@ -36,6 +37,139 @@ def unnormalize(tensor):
return (tensor / 2 + 0.5).clamp(0, 1)
class Critic:
process: 'TrainVAEProcess'
def __init__(
self,
learning_rate=1e-5,
device='cpu',
optimizer='adam',
num_critic_per_gen=1,
dtype='float32',
lambda_gp=10,
start_step=0,
warmup_steps=1000,
process=None
):
self.learning_rate = learning_rate
self.device = device
self.optimizer_type = optimizer
self.num_critic_per_gen = num_critic_per_gen
self.dtype = dtype
self.torch_dtype = get_torch_dtype(self.dtype)
self.process = process
self.model = None
self.optimizer = None
self.scheduler = None
self.warmup_steps = warmup_steps
self.start_step = start_step
self.lambda_gp = lambda_gp
self.print = self.process.print
print(f" Critic config: {self.__dict__}")
def setup(self):
from .models.vgg19_critic import Vgg19Critic
self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype)
self.load_weights()
self.model.train()
self.model.requires_grad_(True)
params = self.model.parameters()
self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate)
self.scheduler = torch.optim.lr_scheduler.ConstantLR(
self.optimizer,
total_iters=self.process.max_steps * self.num_critic_per_gen,
factor=1,
verbose=False
)
def load_weights(self):
path_to_load = None
self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}")
files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors"))
if files and len(files) > 0:
latest_file = max(files, key=os.path.getmtime)
print(f" - Latest checkpoint is: {latest_file}")
path_to_load = latest_file
else:
self.print(f" - No checkpoint found, starting from scratch")
if path_to_load:
self.model.load_state_dict(load_file(path_to_load))
def save(self, step=None):
self.process.update_training_metadata()
save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name)
step_num = ''
if step is not None:
# zeropad 9 digits
step_num = f"_{str(step).zfill(9)}"
save_path = os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors")
save_file(self.model.state_dict(), save_path, save_meta)
self.print(f"Saved critic to {save_path}")
def get_critic_loss(self, vgg_output):
if self.start_step > self.process.step_num:
return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device)
warmup_scaler = 1.0
# we need a warmup when we come on of 1000 steps
# we want to scale the loss by 0.0 at self.start_step steps and 1.0 at self.start_step + warmup_steps
if self.process.step_num < self.start_step + self.warmup_steps:
warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps
# set model to not train for generator loss
self.model.eval()
self.model.requires_grad_(False)
vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0)
# run model
stacked_output = self.model(vgg_pred)
return (-torch.mean(stacked_output)) * warmup_scaler
def step(self, vgg_output):
# train critic here
self.model.train()
self.model.requires_grad_(True)
critic_losses = []
for i in range(self.num_critic_per_gen):
inputs = vgg_output.detach()
inputs = inputs.to(self.device, dtype=self.torch_dtype)
self.optimizer.zero_grad()
vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
stacked_output = self.model(inputs)
out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
# Compute gradient penalty
gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
# Compute WGAN-GP critic loss
critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
critic_loss.backward()
self.optimizer.zero_grad()
self.optimizer.step()
self.scheduler.step()
critic_losses.append(critic_loss.item())
# avg loss
loss = np.mean(critic_losses)
return loss
def get_lr(self):
if self.optimizer_type.startswith('dadaptation'):
learning_rate = (
self.optimizer.param_groups[0]["d"] *
self.optimizer.param_groups[0]["lr"]
)
else:
learning_rate = self.optimizer.param_groups[0]['lr']
return learning_rate
class TrainVAEProcess(BaseTrainProcess):
def __init__(self, process_id: int, job, config: OrderedDict):
super().__init__(process_id, job, config)
@@ -61,6 +195,7 @@ class TrainVAEProcess(BaseTrainProcess):
self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
self.writer = self.job.writer
@@ -68,6 +203,22 @@ class TrainVAEProcess(BaseTrainProcess):
self.save_root = os.path.join(self.training_folder, self.job.name)
self.vgg_19 = None
self.progress_bar = None
self.style_weight_scalers = []
self.content_weight_scalers = []
self.step_num = 0
self.epoch_num = 0
self.use_critic = self.get_conf('use_critic', False, as_type=bool)
self.critic = None
if self.use_critic:
self.critic = Critic(
device=self.device,
dtype=self.dtype,
process=self,
**self.get_conf('critic', {}) # pass any other params
)
if self.sample_every is not None and self.sample_sources is None:
raise ValueError('sample_every is specified but sample_sources is not')
@@ -89,6 +240,16 @@ class TrainVAEProcess(BaseTrainProcess):
if not os.path.exists(self.save_root):
os.makedirs(self.save_root, exist_ok=True)
def update_training_metadata(self):
self.add_meta(OrderedDict({"training_info": self.get_training_info()}))
def get_training_info(self):
info = OrderedDict({
'step': self.step_num,
'epoch': self.epoch_num,
})
return info
def print(self, message, **kwargs):
if self.progress_bar is not None:
self.progress_bar.write(message, **kwargs)
@@ -117,19 +278,46 @@ class TrainVAEProcess(BaseTrainProcess):
def setup_vgg19(self):
if self.vgg_19 is None:
self.vgg_19, self.style_losses, self.content_losses, output = get_style_model_and_losses(
single_target=True, device=self.device)
self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses(
single_target=True,
device=self.device,
output_layer_name='pool_4',
dtype=self.torch_dtype
)
self.vgg_19.to(self.device, dtype=self.torch_dtype)
self.vgg_19.requires_grad_(False)
# we run random noise through first to get layer scalers to normalize the loss per layer
# bs of 2 because we run pred and target through stacked
noise = torch.randn((2, 3, self.resolution, self.resolution), device=self.device, dtype=self.torch_dtype)
self.vgg_19(noise)
for style_loss in self.style_losses:
# get a scaler to normalize to 1
scaler = 1 / torch.mean(style_loss.loss).item()
self.style_weight_scalers.append(scaler)
for content_loss in self.content_losses:
# get a scaler to normalize to 1
scaler = 1 / torch.mean(content_loss.loss).item()
self.content_weight_scalers.append(scaler)
self.print(f"Style weight scalers: {self.style_weight_scalers}")
self.print(f"Content weight scalers: {self.content_weight_scalers}")
def get_style_loss(self):
if self.style_weight > 0:
return torch.sum(torch.stack([loss.loss for loss in self.style_losses]))
# scale all losses with loss scalers
loss = torch.sum(
torch.stack([loss.loss * scaler for loss, scaler in zip(self.style_losses, self.style_weight_scalers)]))
return loss
else:
return torch.tensor(0.0, device=self.device)
def get_content_loss(self):
if self.content_weight > 0:
return torch.sum(torch.stack([loss.loss for loss in self.content_losses]))
# scale all losses with loss scalers
loss = torch.sum(torch.stack(
[loss.loss * scaler for loss, scaler in zip(self.content_losses, self.content_weight_scalers)]))
return loss
else:
return torch.tensor(0.0, device=self.device)
@@ -160,7 +348,6 @@ class TrainVAEProcess(BaseTrainProcess):
else:
return torch.tensor(0.0, device=self.device)
def save(self, step=None):
if not os.path.exists(self.save_root):
os.makedirs(self.save_root, exist_ok=True)
@@ -170,6 +357,7 @@ class TrainVAEProcess(BaseTrainProcess):
# zeropad 9 digits
step_num = f"_{str(step).zfill(9)}"
self.update_training_metadata()
filename = f'{self.job.name}{step_num}.safetensors'
# prepare meta
save_meta = get_meta_for_safetensors(self.meta, self.job.name)
@@ -184,7 +372,10 @@ class TrainVAEProcess(BaseTrainProcess):
# having issues with meta
save_file(state_dict, os.path.join(self.save_root, filename), save_meta)
print(f"Saved to {os.path.join(self.save_root, filename)}")
self.print(f"Saved to {os.path.join(self.save_root, filename)}")
if self.use_critic:
self.critic.save(step)
def sample(self, step=None):
sample_folder = os.path.join(self.save_root, 'samples')
@@ -268,6 +459,9 @@ class TrainVAEProcess(BaseTrainProcess):
num_steps = self.max_steps
if num_steps is None or num_steps > max_epoch_steps:
num_steps = max_epoch_steps
self.max_steps = num_steps
self.epochs = num_epochs
start_step = self.step_num
self.print(f"Training VAE")
self.print(f" - Training folder: {self.training_folder}")
@@ -304,18 +498,14 @@ class TrainVAEProcess(BaseTrainProcess):
params += list(self.vae.decoder.conv_out.parameters())
self.vae.decoder.conv_out.requires_grad_(True)
if self.style_weight > 0 or self.content_weight > 0:
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
self.setup_vgg19()
self.vgg_19.requires_grad_(False)
self.vgg_19.eval()
if self.use_critic:
self.critic.setup()
# todo allow other optimizers
if self.optimizer_type == 'dadaptation':
import dadaptation
print("Using DAdaptAdam optimizer")
optimizer = dadaptation.DAdaptAdam(params, lr=1)
else:
optimizer = torch.optim.Adam(params, lr=float(self.learning_rate))
optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate)
# setup scheduler
# todo allow other schedulers
@@ -333,7 +523,6 @@ class TrainVAEProcess(BaseTrainProcess):
leave=True
)
step = 0
# sample first
self.sample()
blank_losses = OrderedDict({
@@ -343,15 +532,17 @@ class TrainVAEProcess(BaseTrainProcess):
"mse": [],
"kl": [],
"tv": [],
"crD": [],
"crG": [],
})
epoch_losses = copy.deepcopy(blank_losses)
log_losses = copy.deepcopy(blank_losses)
for epoch in range(num_epochs):
if step >= num_steps:
# range start at self.epoch_num go to self.epochs
for epoch in range(self.epoch_num, self.epochs, 1):
if self.step_num >= self.max_steps:
break
for batch in self.data_loader:
if step >= num_steps:
if self.step_num >= self.max_steps:
break
batch = batch.to(self.device, dtype=self.torch_dtype)
@@ -365,18 +556,27 @@ class TrainVAEProcess(BaseTrainProcess):
pred = self.vae.decode(latents).sample
# Run through VGG19
if self.style_weight > 0 or self.content_weight > 0:
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
stacked = torch.cat([pred, batch], dim=0)
stacked = (stacked / 2 + 0.5).clamp(0, 1)
self.vgg_19(stacked)
if self.use_critic:
critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
else:
critic_d_loss = 0.0
style_loss = self.get_style_loss() * self.style_weight
content_loss = self.get_content_loss() * self.content_weight
kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
if self.use_critic:
critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
else:
critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss
# Backward pass and optimization
optimizer.zero_grad()
@@ -398,6 +598,10 @@ class TrainVAEProcess(BaseTrainProcess):
loss_string += f" mse: {mse_loss.item():.2e}"
if self.tv_weight > 0:
loss_string += f" tv: {tv_loss.item():.2e}"
if self.use_critic and self.critic_weight > 0:
loss_string += f" crG: {critic_gen_loss.item():.2e}"
if self.use_critic:
loss_string += f" crD: {critic_d_loss:.2e}"
if self.optimizer_type.startswith('dadaptation'):
learning_rate = (
@@ -406,7 +610,13 @@ class TrainVAEProcess(BaseTrainProcess):
)
else:
learning_rate = optimizer.param_groups[0]['lr']
self.progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
lr_critic_string = ''
if self.use_critic:
lr_critic = self.critic.get_lr()
lr_critic_string = f" lrC: {lr_critic:.1e}"
self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e}{lr_critic_string} {loss_string}")
self.progress_bar.set_description(f"E: {epoch}")
self.progress_bar.update(1)
@@ -416,6 +626,8 @@ class TrainVAEProcess(BaseTrainProcess):
epoch_losses["mse"].append(mse_loss.item())
epoch_losses["kl"].append(kld_loss.item())
epoch_losses["tv"].append(tv_loss.item())
epoch_losses["crG"].append(critic_gen_loss.item())
epoch_losses["crD"].append(critic_d_loss)
log_losses["total"].append(loss_value)
log_losses["style"].append(style_loss.item())
@@ -423,30 +635,33 @@ class TrainVAEProcess(BaseTrainProcess):
log_losses["mse"].append(mse_loss.item())
log_losses["kl"].append(kld_loss.item())
log_losses["tv"].append(tv_loss.item())
log_losses["crG"].append(critic_gen_loss.item())
log_losses["crD"].append(critic_d_loss)
if step != 0:
if self.sample_every and step % self.sample_every == 0:
# don't do on first step
if self.step_num != start_step:
if self.sample_every and self.step_num % self.sample_every == 0:
# print above the progress bar
self.print(f"Sampling at step {step}")
self.sample(step)
self.print(f"Sampling at step {self.step_num}")
self.sample(self.step_num)
if self.save_every and step % self.save_every == 0:
if self.save_every and self.step_num % self.save_every == 0:
# print above the progress bar
self.print(f"Saving at step {step}")
self.save(step)
self.print(f"Saving at step {self.step_num}")
self.save(self.step_num)
if self.log_every and step % self.log_every == 0:
if self.log_every and self.step_num % self.log_every == 0:
# log to tensorboard
if self.writer is not None:
# get avg loss
for key in log_losses:
log_losses[key] = sum(log_losses[key]) / len(log_losses[key])
if log_losses[key] > 0:
self.writer.add_scalar(f"loss/{key}", log_losses[key], step)
# if log_losses[key] > 0:
self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num)
# reset log losses
log_losses = copy.deepcopy(blank_losses)
step += 1
self.step_num += 1
# end epoch
if self.writer is not None:
# get avg loss

View File

@@ -0,0 +1,38 @@
import torch
import torch.nn as nn
class MeanReduce(nn.Module):
def __init__(self):
super(MeanReduce, self).__init__()
def forward(self, inputs):
return torch.mean(inputs, dim=(1, 2, 3), keepdim=True)
class Vgg19Critic(nn.Module):
def __init__(self):
# vgg19 input (bs, 3, 512, 512)
# pool1 (bs, 64, 256, 256)
# pool2 (bs, 128, 128, 128)
# pool3 (bs, 256, 64, 64)
# pool4 (bs, 512, 32, 32) <- take this input
super(Vgg19Critic, self).__init__()
self.main = nn.Sequential(
# input (bs, 512, 32, 32)
nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(0.2), # (bs, 512, 16, 16)
nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(0.2), # (bs, 512, 8, 8)
nn.Conv2d(512, 1, kernel_size=3, stride=2, padding=1),
# (bs, 1, 4, 4)
MeanReduce(), # (bs, 1, 1, 1)
nn.Flatten(), # (bs, 1)
# nn.Flatten(), # (128*8*8) = 8192
# nn.Linear(128 * 8 * 8, 1)
)
def forward(self, inputs):
return self.main(inputs)

View File

@@ -1,6 +1,7 @@
import os
import json
import oyaml as yaml
import re
from collections import OrderedDict
from toolkit.paths import TOOLKIT_ROOT
@@ -29,6 +30,20 @@ def preprocess_config(config: OrderedDict):
return config
# Fixes issue where yaml doesnt load exponents correctly
fixed_loader = yaml.SafeLoader
fixed_loader.add_implicit_resolver(
u'tag:yaml.org,2002:float',
re.compile(u'''^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$''', re.X),
list(u'-+0123456789.'))
def get_config(config_file_path):
# first check if it is in the config folder
config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
@@ -56,7 +71,7 @@ def get_config(config_file_path):
config = json.load(f, object_pairs_hook=OrderedDict)
elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'):
with open(real_config_path, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
config = yaml.load(f, Loader=fixed_loader)
else:
raise ValueError(f"Config file {config_file_path} must be a json or yaml file")

View File

@@ -11,7 +11,7 @@ def total_variation(image):
"""
n_elements = image.shape[1] * image.shape[2] * image.shape[3]
return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
class ComparativeTotalVariation(torch.nn.Module):
@@ -21,3 +21,27 @@ class ComparativeTotalVariation(torch.nn.Module):
def forward(self, pred, target):
return torch.abs(total_variation(pred) - total_variation(target))
# Gradient penalty
def get_gradient_penalty(critic, real, fake, device):
with torch.autocast(device_type='cuda'):
alpha = torch.rand(real.size(0), 1, 1, 1).to(device)
interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True)
d_interpolates = critic(interpolates)
fake = torch.ones(real.size(0), 1, device=device)
gradients = torch.autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=fake,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
gradients = gradients.view(gradients.size(0), -1)
gradient_norm = gradients.norm(2, dim=1)
gradient_penalty = ((gradient_norm - 1) ** 2).mean()
return gradient_penalty

View File

@@ -13,6 +13,16 @@ def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict:
# safetensors can only be one level deep
for key, value in save_meta.items():
# if not float, int, bool, or str, convert to json string
if not isinstance(value, (float, int, bool, str)):
if not isinstance(value, str):
save_meta[key] = json.dumps(value)
return save_meta
def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
parsed_meta = OrderedDict()
for key, value in meta.items():
try:
parsed_meta[key] = json.loads(value)
except json.decoder.JSONDecodeError:
parsed_meta[key] = value
return meta

18
toolkit/optimizer.py Normal file
View File

@@ -0,0 +1,18 @@
import torch
def get_optimizer(
params,
optimizer_type='adam',
learning_rate=1e-6
):
if optimizer_type == 'dadaptation':
# dadaptation optimizer does not use standard learning rate. 1 is the default value
import dadaptation
print("Using DAdaptAdam optimizer")
optimizer = dadaptation.DAdaptAdam(params, lr=1.0)
elif optimizer_type == 'adam':
optimizer = torch.optim.Adam(params, lr=float(learning_rate))
else:
raise ValueError(f'Unknown optimizer type {optimizer_type}')
return optimizer

View File

@@ -21,6 +21,7 @@ class ContentLoss(nn.Module):
self.loss = None
def forward(self, stacked_input):
if self.single_target:
split_size = stacked_input.size()[0] // 2
pred_layer, target_layer = torch.split(stacked_input, split_size, dim=0)
@@ -73,6 +74,8 @@ class StyleLoss(nn.Module):
self.device = device
def forward(self, stacked_input):
input_dtype = stacked_input.dtype
stacked_input = stacked_input.float()
if self.single_target:
split_size = stacked_input.size()[0] // 2
preds, style_target = torch.split(stacked_input, split_size, dim=0)
@@ -94,17 +97,18 @@ class StyleLoss(nn.Module):
itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
# gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
self.loss = loss
return stacked_input
self.loss = loss.to(input_dtype)
return stacked_input.to(input_dtype)
# create a module to normalize input image so we can easily put it in a
# ``nn.Sequential``
class Normalization(nn.Module):
def __init__(self, device):
def __init__(self, device, dtype=torch.float32):
super(Normalization, self).__init__()
mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
std = torch.tensor([0.229, 0.224, 0.225]).to(device)
self.dtype = dtype
# .view the mean and std to make them [C x 1 x 1] so that they can
# directly work with image Tensor of shape [B x C x H x W].
# B is batch size. C is number of channels. H is height and W is width.
@@ -112,9 +116,9 @@ class Normalization(nn.Module):
self.std = torch.tensor(std).view(-1, 1, 1)
def forward(self, stacked_input):
# cast to float 32 if not already
if stacked_input.dtype != torch.float32:
stacked_input = stacked_input.float()
# cast to float 32 if not already # only necessary when processing gram matrix
# if stacked_input.dtype != torch.float32:
# stacked_input = stacked_input.float()
# remove alpha channel if it exists
if stacked_input.shape[1] == 4:
stacked_input = stacked_input[:, :3, :, :]
@@ -123,21 +127,37 @@ class Normalization(nn.Module):
in_max = torch.max(stacked_input)
# norm_stacked_input = (stacked_input - in_min) / (in_max - in_min)
# return (norm_stacked_input - self.mean) / self.std
return (stacked_input - self.mean) / self.std
return ((stacked_input - self.mean) / self.std).to(self.dtype)
class OutputLayer(nn.Module):
def __init__(self, name='output_layer'):
super(OutputLayer, self).__init__()
self.name = name
self.tensor = None
def forward(self, stacked_input):
self.tensor = stacked_input
return stacked_input
def get_style_model_and_losses(
single_target=False,
single_target=True, # false has 3 targets, dont remember why i added this initially, this is old code
device='cuda' if torch.cuda.is_available() else 'cpu',
output_layer_name=None,
dtype=torch.float32
):
# content_layers = ['conv_4']
# style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
content_layers = ['conv4_2']
style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
cnn = models.vgg19(pretrained=True).features.to(device).eval()
cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval()
# set all weights in the model to our dtype
# for layer in cnn.children():
# layer.to(dtype=dtype)
# normalization module
normalization = Normalization(device).to(device)
normalization = Normalization(device, dtype=dtype).to(device)
# just in order to have an iterable access to or list of content/style
# losses
@@ -189,15 +209,15 @@ def get_style_model_and_losses(
style_losses.append(style_loss)
if output_layer_name is not None and name == output_layer_name:
output_layer = layer
output_layer = OutputLayer(name)
model.add_module("output_layer_{}_{}".format(block, i), output_layer)
# now we trim off the layers after the last content and style losses
for i in range(len(model) - 1, -1, -1):
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
break
if output_layer_name is not None and model[i].name == output_layer_name:
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss) or isinstance(model[i], OutputLayer):
break
model = model[:(i + 1)]
model.to(dtype=dtype)
return model, style_losses, content_losses, output_layer