Style and content loss working

This commit is contained in:
Jaret Burkett
2023-07-18 07:47:01 -06:00
parent 439310e4dc
commit 94d52572d4
3 changed files with 281 additions and 27 deletions

View File

@@ -15,7 +15,9 @@ from jobs.process import BaseTrainProcess
from toolkit.kohya_model_util import load_vae
from toolkit.data_loader import ImageDataset
from toolkit.metadata import get_meta_for_safetensors
from toolkit.style import get_style_model_and_losses
from toolkit.train_tools import get_torch_dtype
from diffusers import AutoencoderKL
from tqdm import tqdm
import time
import numpy as np
@@ -27,15 +29,9 @@ IMAGE_TRANSFORMS = transforms.Compose(
]
)
INVERSE_IMAGE_TRANSFORMS = transforms.Compose(
[
transforms.Normalize(
mean=[-0.5/0.5],
std=[1/0.5]
),
transforms.ToPILImage(),
]
)
def unnormalize(tensor):
return (tensor / 2 + 0.5).clamp(0, 1)
class TrainVAEProcess(BaseTrainProcess):
@@ -56,8 +52,12 @@ class TrainVAEProcess(BaseTrainProcess):
self.save_every = self.get_conf('save_every', None)
self.dtype = self.get_conf('dtype', 'float32')
self.sample_sources = self.get_conf('sample_sources', None)
self.style_weight = self.get_conf('style_weight', 1e4)
self.content_weight = self.get_conf('content_weight', 1)
self.elbo_weight = self.get_conf('elbo_weight', 1e-8)
self.torch_dtype = get_torch_dtype(self.dtype)
self.save_root = os.path.join(self.training_folder, self.job.name)
self.vgg_19 = None
if self.sample_every is not None and self.sample_sources is None:
raise ValueError('sample_every is specified but sample_sources is not')
@@ -66,7 +66,6 @@ class TrainVAEProcess(BaseTrainProcess):
raise ValueError('epochs or max_steps must be specified')
self.data_loaders = []
datasets = []
# check datasets
assert isinstance(self.datasets_objects, list)
for dataset in self.datasets_objects:
@@ -95,10 +94,17 @@ class TrainVAEProcess(BaseTrainProcess):
self.data_loader = DataLoader(
concatenated_dataset,
batch_size=self.batch_size,
shuffle=True
shuffle=True,
num_workers=6
)
def get_loss(self, pred, target):
def setup_vgg19(self):
if self.vgg_19 is None:
self.vgg_19, self.style_losses, self.content_losses = get_style_model_and_losses(
single_target=True, device=self.device)
self.vgg_19.requires_grad_(False)
def get_mse_loss(self, pred, target):
loss_fn = nn.MSELoss()
loss = loss_fn(pred, target)
return loss
@@ -157,8 +163,18 @@ class TrainVAEProcess(BaseTrainProcess):
input_img = img
img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
decoded = self.vae(img).sample.squeeze(0)
decoded = INVERSE_IMAGE_TRANSFORMS(decoded)
img = img
decoded = self.vae(img).sample
decoded = (decoded / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
#convert to pillow image
decoded = Image.fromarray((decoded * 255).astype(np.uint8))
# # decoded = decoded - 0.1
# decoded = decoded
# decoded = INVERSE_IMAGE_TRANSFORMS(decoded)
# stack input image and decoded image
input_img = input_img.resize((self.resolution, self.resolution))
@@ -177,7 +193,6 @@ class TrainVAEProcess(BaseTrainProcess):
i_str = str(i).zfill(2)
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
output_img.save(os.path.join(sample_folder, filename))
self.vae.decoder.train()
def run(self):
super().run()
@@ -204,19 +219,41 @@ class TrainVAEProcess(BaseTrainProcess):
print(f"Loading VAE")
print(f" - Loading VAE: {self.vae_path}")
if self.vae is None:
# self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
# set decoder to train
self.vae.to(self.device, dtype=self.torch_dtype)
self.vae.requires_grad_(False)
self.vae.eval()
self.vae.decoder.requires_grad_(True)
self.vae.decoder.train()
parameters = self.vae.decoder.parameters()
blocks_to_train = [
'mid_block',
'up_blocks',
]
optimizer = torch.optim.Adam(parameters, lr=self.learning_rate)
params = []
# only set last 2 layers to trainable
for param in self.vae.decoder.parameters():
param.requires_grad = False
if 'mid_block' in blocks_to_train:
params += list(self.vae.decoder.mid_block.parameters())
self.vae.decoder.mid_block.requires_grad_(True)
if 'up_blocks' in blocks_to_train:
params += list(self.vae.decoder.up_blocks.parameters())
self.vae.decoder.up_blocks.requires_grad_(True)
# self.vae.decoder.train()
self.setup_vgg19()
self.vgg_19.requires_grad_(False)
self.vgg_19.eval()
optimizer = torch.optim.Adam(params, lr=self.learning_rate)
# setup scheduler
# scheduler = lr_scheduler.ConstantLR
@@ -249,6 +286,7 @@ class TrainVAEProcess(BaseTrainProcess):
# forward pass
# with torch.no_grad():
# batch = batch + 0.1
dgd = self.vae.encode(batch).latent_dist
mu, logvar = dgd.mean, dgd.logvar
latents = dgd.sample()
@@ -256,7 +294,24 @@ class TrainVAEProcess(BaseTrainProcess):
pred = self.vae.decode(latents).sample
loss = self.get_elbo_loss(pred, batch, mu, logvar)
# pred = pred + 0.1
# loss = self.get_elbo_loss(pred, batch, mu, logvar)
stacked = torch.cat([pred, batch], dim=0)
stacked = (stacked / 2 + 0.5).clamp(0, 1)
self.vgg_19(stacked)
# reduce the mean of the style_loss list
style_loss = torch.sum(torch.stack([loss.loss for loss in self.style_losses]))
content_loss = torch.sum(torch.stack([loss.loss for loss in self.content_losses]))
elbo_loss = self.get_elbo_loss(pred, batch, mu, logvar)
# elbo_loss = torch.zeros(1, device=self.device, dtype=self.torch_dtype)
style_loss = style_loss * self.style_weight
content_loss = content_loss * self.content_weight
elbo_loss = elbo_loss * self.elbo_weight
loss = style_loss + content_loss + elbo_loss
# Backward pass and optimization
optimizer.zero_grad()
@@ -267,9 +322,9 @@ class TrainVAEProcess(BaseTrainProcess):
# update progress bar
loss_value = loss.item()
# get exponent like 3.54e-4
loss_string = f"{loss_value:.2e}"
loss_string = f"loss: {loss_value:.2e} cnt: {content_loss.item():.2e} sty: {style_loss.item():.2e} elbo: {elbo_loss.item():.2e}"
learning_rate = optimizer.param_groups[0]['lr']
progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} Loss: {loss_string}")
progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
progress_bar.set_description(f"E: {epoch} - S: {step} ")
progress_bar.update(1)
@@ -279,7 +334,7 @@ class TrainVAEProcess(BaseTrainProcess):
print(f"Sampling at step {step}")
self.sample(step)
if self.save_every and step % self.save_every == 0:
if self.save_every and step % self.save_every == 0:
# print above the progress bar
print(f"Saving at step {step}")
self.save(step)

View File

@@ -51,15 +51,20 @@ class ImageDataset(Dataset):
# Downscale the source image first
img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
min_img_size = min(img.size)
if self.random_crop:
if self.random_scale:
scale_size = random.randint(int(img.size[0] * self.scale), self.resolution)
if self.random_scale and min_img_size > self.resolution:
if min_img_size < self.resolution:
print(
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file}")
scale_size = self.resolution
else:
scale_size = random.randint(self.resolution, int(min_img_size))
img = img.resize((scale_size, scale_size), Image.BICUBIC)
img = transforms.RandomCrop(self.resolution)(img)
else:
min_dimension = min(img.size)
img = transforms.CenterCrop(min_dimension)(img)
img = transforms.CenterCrop(min_img_size)(img)
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
img = self.transform(img)

194
toolkit/style.py Normal file
View File

@@ -0,0 +1,194 @@
from torch import nn
import torch.nn.functional as F
import torch
from torchvision import models
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
def tensor_size(tensor):
channels = tensor.shape[1]
height = tensor.shape[2]
width = tensor.shape[3]
return channels * height * width
class ContentLoss(nn.Module):
def __init__(self, single_target=False, device='cuda' if torch.cuda.is_available() else 'cpu'):
super(ContentLoss, self).__init__()
self.single_target = single_target
self.device = device
self.loss = None
def forward(self, stacked_input):
if self.single_target:
split_size = stacked_input.size()[0] // 2
pred_layer, target_layer = torch.split(stacked_input, split_size, dim=0)
else:
split_size = stacked_input.size()[0] // 3
pred_layer, _, target_layer = torch.split(stacked_input, split_size, dim=0)
content_size = tensor_size(pred_layer)
# Define the separate loss function
def separated_loss(y_pred, y_true):
diff = torch.abs(y_pred - y_true)
l2 = torch.sum(diff ** 2, dim=[1, 2, 3], keepdim=True) / 2.0
return 2. * l2 / content_size
# Calculate itemized loss
pred_itemized_loss = separated_loss(pred_layer, target_layer)
# Calculate the mean of itemized loss
loss = torch.mean(pred_itemized_loss, dim=(1, 2, 3), keepdim=True)
self.loss = loss
return stacked_input
def convert_to_gram_matrix(inputs):
shape = inputs.size()
batch, filters, height, width = shape[0], shape[1], shape[2], shape[3]
size = height * width * filters
feats = inputs.view(batch, filters, height * width)
feats_t = feats.transpose(1, 2)
grams_raw = torch.matmul(feats, feats_t)
gram_matrix = grams_raw / size
return gram_matrix
######################################################################
# Now the style loss module looks almost exactly like the content loss
# module. The style distance is also computed using the mean square
# error between :math:`G_{XL}` and :math:`G_{SL}`.
#
class StyleLoss(nn.Module):
def __init__(self, single_target=False, device='cuda' if torch.cuda.is_available() else 'cpu'):
super(StyleLoss, self).__init__()
self.single_target = single_target
self.device = device
def forward(self, stacked_input):
if self.single_target:
split_size = stacked_input.size()[0] // 2
preds, style_target = torch.split(stacked_input, split_size, dim=0)
else:
split_size = stacked_input.size()[0] // 3
preds, style_target, _ = torch.split(stacked_input, split_size, dim=0)
def separated_loss(y_pred, y_true):
gram_size = y_true.size(1) * y_true.size(2)
sum_axis = (1, 2)
diff = torch.abs(y_pred - y_true)
raw_loss = torch.sum(diff ** 2, dim=sum_axis, keepdim=True)
return raw_loss / gram_size
target_grams = convert_to_gram_matrix(style_target)
pred_grams = convert_to_gram_matrix(preds)
itemized_loss = separated_loss(pred_grams, target_grams)
# reshape itemized loss to be (batch, 1, 1, 1)
itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
# gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
self.loss = loss
return stacked_input
# create a module to normalize input image so we can easily put it in a
# ``nn.Sequential``
class Normalization(nn.Module):
def __init__(self, device):
super(Normalization, self).__init__()
mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
std = torch.tensor([0.229, 0.224, 0.225]).to(device)
# .view the mean and std to make them [C x 1 x 1] so that they can
# directly work with image Tensor of shape [B x C x H x W].
# B is batch size. C is number of channels. H is height and W is width.
self.mean = torch.tensor(mean).view(-1, 1, 1)
self.std = torch.tensor(std).view(-1, 1, 1)
def forward(self, stacked_input):
# cast to float 32 if not already
if stacked_input.dtype != torch.float32:
stacked_input = stacked_input.float()
# remove alpha channel if it exists
if stacked_input.shape[1] == 4:
stacked_input = stacked_input[:, :3, :, :]
# normalize to min and max of 0 - 1
in_min = torch.min(stacked_input)
in_max = torch.max(stacked_input)
norm_stacked_input = (stacked_input - in_min) / (in_max - in_min)
return (norm_stacked_input - self.mean) / self.std
def get_style_model_and_losses(
single_target=False,
device='cuda' if torch.cuda.is_available() else 'cpu'
):
# content_layers = ['conv_4']
# style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
content_layers = ['conv3_2', 'conv4_2']
style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
cnn = models.vgg19(pretrained=True).features.to(device).eval()
# normalization module
normalization = Normalization(device).to(device)
# just in order to have an iterable access to or list of content/style
# losses
content_losses = []
style_losses = []
# assuming that ``cnn`` is a ``nn.Sequential``, so we make a new ``nn.Sequential``
# to put in modules that are supposed to be activated sequentially
model = nn.Sequential(normalization)
i = 0 # increment every time we see a conv
block = 1
children = list(cnn.children())
for layer in children:
if isinstance(layer, nn.Conv2d):
i += 1
name = f'conv{block}_{i}_raw'
elif isinstance(layer, nn.ReLU):
# name = 'relu_{}'.format(i)
name = f'conv{block}_{i}' # target this
# The in-place version doesn't play very nicely with the ``ContentLoss``
# and ``StyleLoss`` we insert below. So we replace with out-of-place
# ones here.
layer = nn.ReLU(inplace=False)
elif isinstance(layer, nn.MaxPool2d):
name = 'pool_{}'.format(i)
block += 1
i = 0
elif isinstance(layer, nn.BatchNorm2d):
name = 'bn_{}'.format(i)
else:
raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
model.add_module(name, layer)
if name in content_layers:
# add content loss:
content_loss = ContentLoss(single_target=single_target, device=device)
model.add_module("content_loss_{}_{}".format(block, i), content_loss)
content_losses.append(content_loss)
if name in style_layers:
# add style loss:
style_loss = StyleLoss(single_target=single_target, device=device)
model.add_module("style_loss_{}_{}".format(block, i), style_loss)
style_losses.append(style_loss)
# now we trim off the layers after the last content and style losses
for i in range(len(model) - 1, -1, -1):
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
break
model = model[:(i + 1)]
return model, style_losses, content_losses