From 94d52572d4e294c7b1fb2b341aea5b3313915016 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 18 Jul 2023 07:47:01 -0600 Subject: [PATCH] Style and content loss working --- jobs/process/TrainVAEProcess.py | 101 +++++++++++++---- toolkit/data_loader.py | 13 ++- toolkit/style.py | 194 ++++++++++++++++++++++++++++++++ 3 files changed, 281 insertions(+), 27 deletions(-) create mode 100644 toolkit/style.py diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py index e30e37a3..3eb4e8c7 100644 --- a/jobs/process/TrainVAEProcess.py +++ b/jobs/process/TrainVAEProcess.py @@ -15,7 +15,9 @@ from jobs.process import BaseTrainProcess from toolkit.kohya_model_util import load_vae from toolkit.data_loader import ImageDataset from toolkit.metadata import get_meta_for_safetensors +from toolkit.style import get_style_model_and_losses from toolkit.train_tools import get_torch_dtype +from diffusers import AutoencoderKL from tqdm import tqdm import time import numpy as np @@ -27,15 +29,9 @@ IMAGE_TRANSFORMS = transforms.Compose( ] ) -INVERSE_IMAGE_TRANSFORMS = transforms.Compose( - [ - transforms.Normalize( - mean=[-0.5/0.5], - std=[1/0.5] - ), - transforms.ToPILImage(), - ] -) + +def unnormalize(tensor): + return (tensor / 2 + 0.5).clamp(0, 1) class TrainVAEProcess(BaseTrainProcess): @@ -56,8 +52,12 @@ class TrainVAEProcess(BaseTrainProcess): self.save_every = self.get_conf('save_every', None) self.dtype = self.get_conf('dtype', 'float32') self.sample_sources = self.get_conf('sample_sources', None) + self.style_weight = self.get_conf('style_weight', 1e4) + self.content_weight = self.get_conf('content_weight', 1) + self.elbo_weight = self.get_conf('elbo_weight', 1e-8) self.torch_dtype = get_torch_dtype(self.dtype) self.save_root = os.path.join(self.training_folder, self.job.name) + self.vgg_19 = None if self.sample_every is not None and self.sample_sources is None: raise ValueError('sample_every is specified but sample_sources is not') @@ -66,7 +66,6 @@ class TrainVAEProcess(BaseTrainProcess): raise ValueError('epochs or max_steps must be specified') self.data_loaders = [] - datasets = [] # check datasets assert isinstance(self.datasets_objects, list) for dataset in self.datasets_objects: @@ -95,10 +94,17 @@ class TrainVAEProcess(BaseTrainProcess): self.data_loader = DataLoader( concatenated_dataset, batch_size=self.batch_size, - shuffle=True + shuffle=True, + num_workers=6 ) - def get_loss(self, pred, target): + def setup_vgg19(self): + if self.vgg_19 is None: + self.vgg_19, self.style_losses, self.content_losses = get_style_model_and_losses( + single_target=True, device=self.device) + self.vgg_19.requires_grad_(False) + + def get_mse_loss(self, pred, target): loss_fn = nn.MSELoss() loss = loss_fn(pred, target) return loss @@ -157,8 +163,18 @@ class TrainVAEProcess(BaseTrainProcess): input_img = img img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype) - decoded = self.vae(img).sample.squeeze(0) - decoded = INVERSE_IMAGE_TRANSFORMS(decoded) + img = img + decoded = self.vae(img).sample + decoded = (decoded / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy() + + #convert to pillow image + decoded = Image.fromarray((decoded * 255).astype(np.uint8)) + + # # decoded = decoded - 0.1 + # decoded = decoded + # decoded = INVERSE_IMAGE_TRANSFORMS(decoded) # stack input image and decoded image input_img = 
input_img.resize((self.resolution, self.resolution)) @@ -177,7 +193,6 @@ class TrainVAEProcess(BaseTrainProcess): i_str = str(i).zfill(2) filename = f"{seconds_since_epoch}{step_num}_{i_str}.png" output_img.save(os.path.join(sample_folder, filename)) - self.vae.decoder.train() def run(self): super().run() @@ -204,19 +219,41 @@ class TrainVAEProcess(BaseTrainProcess): print(f"Loading VAE") print(f" - Loading VAE: {self.vae_path}") if self.vae is None: + # self.vae = load_vae(self.vae_path, dtype=self.torch_dtype) self.vae = load_vae(self.vae_path, dtype=self.torch_dtype) # set decoder to train self.vae.to(self.device, dtype=self.torch_dtype) self.vae.requires_grad_(False) self.vae.eval() - - self.vae.decoder.requires_grad_(True) self.vae.decoder.train() - parameters = self.vae.decoder.parameters() + blocks_to_train = [ + 'mid_block', + 'up_blocks', + ] - optimizer = torch.optim.Adam(parameters, lr=self.learning_rate) + params = [] + + # only set last 2 layers to trainable + for param in self.vae.decoder.parameters(): + param.requires_grad = False + + if 'mid_block' in blocks_to_train: + params += list(self.vae.decoder.mid_block.parameters()) + self.vae.decoder.mid_block.requires_grad_(True) + if 'up_blocks' in blocks_to_train: + params += list(self.vae.decoder.up_blocks.parameters()) + self.vae.decoder.up_blocks.requires_grad_(True) + + # self.vae.decoder.train() + + self.setup_vgg19() + self.vgg_19.requires_grad_(False) + self.vgg_19.eval() + + + optimizer = torch.optim.Adam(params, lr=self.learning_rate) # setup scheduler # scheduler = lr_scheduler.ConstantLR @@ -249,6 +286,7 @@ class TrainVAEProcess(BaseTrainProcess): # forward pass # with torch.no_grad(): + # batch = batch + 0.1 dgd = self.vae.encode(batch).latent_dist mu, logvar = dgd.mean, dgd.logvar latents = dgd.sample() @@ -256,7 +294,24 @@ class TrainVAEProcess(BaseTrainProcess): pred = self.vae.decode(latents).sample - loss = self.get_elbo_loss(pred, batch, mu, logvar) + # pred = pred + 0.1 + + # loss = self.get_elbo_loss(pred, batch, mu, logvar) + + stacked = torch.cat([pred, batch], dim=0) + stacked = (stacked / 2 + 0.5).clamp(0, 1) + self.vgg_19(stacked) + # reduce the mean of the style_loss list + + style_loss = torch.sum(torch.stack([loss.loss for loss in self.style_losses])) + content_loss = torch.sum(torch.stack([loss.loss for loss in self.content_losses])) + elbo_loss = self.get_elbo_loss(pred, batch, mu, logvar) + # elbo_loss = torch.zeros(1, device=self.device, dtype=self.torch_dtype) + style_loss = style_loss * self.style_weight + content_loss = content_loss * self.content_weight + elbo_loss = elbo_loss * self.elbo_weight + + loss = style_loss + content_loss + elbo_loss # Backward pass and optimization optimizer.zero_grad() @@ -267,9 +322,9 @@ class TrainVAEProcess(BaseTrainProcess): # update progress bar loss_value = loss.item() # get exponent like 3.54e-4 - loss_string = f"{loss_value:.2e}" + loss_string = f"loss: {loss_value:.2e} cnt: {content_loss.item():.2e} sty: {style_loss.item():.2e} elbo: {elbo_loss.item():.2e}" learning_rate = optimizer.param_groups[0]['lr'] - progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} Loss: {loss_string}") + progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}") progress_bar.set_description(f"E: {epoch} - S: {step} ") progress_bar.update(1) @@ -279,7 +334,7 @@ class TrainVAEProcess(BaseTrainProcess): print(f"Sampling at step {step}") self.sample(step) - if self.save_every and step % self.save_every == 0: + if self.save_every and step % self.save_every == 0: # 
print above the progress bar
                    print(f"Saving at step {step}")
                    self.save(step)

diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 153888fe..2bd45154 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -51,15 +51,20 @@ class ImageDataset(Dataset):
             # Downscale the source image first
             img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
 
+        min_img_size = min(img.size)
         if self.random_crop:
             if self.random_scale:
-                scale_size = random.randint(int(img.size[0] * self.scale), self.resolution)
+                if min_img_size < self.resolution:
+                    print(
+                        f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file}")
+                    scale_size = self.resolution
+                else:
+                    scale_size = random.randint(self.resolution, int(min_img_size))
                 img = img.resize((scale_size, scale_size), Image.BICUBIC)
             img = transforms.RandomCrop(self.resolution)(img)
         else:
-            min_dimension = min(img.size)
-            img = transforms.CenterCrop(min_dimension)(img)
+            img = transforms.CenterCrop(min_img_size)(img)
             img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
 
         img = self.transform(img)
diff --git a/toolkit/style.py b/toolkit/style.py
new file mode 100644
index 00000000..3c810499
--- /dev/null
+++ b/toolkit/style.py
@@ -0,0 +1,194 @@
+from torch import nn
+import torch.nn.functional as F
+import torch
+from torchvision import models
+
+
+# device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+def tensor_size(tensor):
+    channels = tensor.shape[1]
+    height = tensor.shape[2]
+    width = tensor.shape[3]
+    return channels * height * width
+
+
+class ContentLoss(nn.Module):
+
+    def __init__(self, single_target=False, device='cuda' if torch.cuda.is_available() else 'cpu'):
+        super(ContentLoss, self).__init__()
+        self.single_target = single_target
+        self.device = device
+        self.loss = None
+
+    def forward(self, stacked_input):
+        if self.single_target:
+            split_size = stacked_input.size()[0] // 2
+            pred_layer, target_layer = torch.split(stacked_input, split_size, dim=0)
+        else:
+            split_size = stacked_input.size()[0] // 3
+            pred_layer, _, target_layer = torch.split(stacked_input, split_size, dim=0)
+
+        content_size = tensor_size(pred_layer)
+
+        # Define the separate loss function
+        def separated_loss(y_pred, y_true):
+            diff = torch.abs(y_pred - y_true)
+            l2 = torch.sum(diff ** 2, dim=[1, 2, 3], keepdim=True) / 2.0
+            return 2. * l2 / content_size
+
+        # Calculate itemized loss
+        pred_itemized_loss = separated_loss(pred_layer, target_layer)
+
+        # Calculate the mean of itemized loss
+        loss = torch.mean(pred_itemized_loss, dim=(1, 2, 3), keepdim=True)
+        self.loss = loss
+
+        return stacked_input
+
+
+def convert_to_gram_matrix(inputs):
+    shape = inputs.size()
+    batch, filters, height, width = shape[0], shape[1], shape[2], shape[3]
+    size = height * width * filters
+
+    feats = inputs.view(batch, filters, height * width)
+    feats_t = feats.transpose(1, 2)
+    grams_raw = torch.matmul(feats, feats_t)
+    gram_matrix = grams_raw / size
+
+    return gram_matrix
+
+
+######################################################################
+# Now the style loss module looks almost exactly like the content loss
+# module. The style distance is also computed using the mean square
+# error between :math:`G_{XL}` and :math:`G_{SL}`.
+# + +class StyleLoss(nn.Module): + + def __init__(self, single_target=False, device='cuda' if torch.cuda.is_available() else 'cpu'): + super(StyleLoss, self).__init__() + self.single_target = single_target + self.device = device + + def forward(self, stacked_input): + if self.single_target: + split_size = stacked_input.size()[0] // 2 + preds, style_target = torch.split(stacked_input, split_size, dim=0) + else: + split_size = stacked_input.size()[0] // 3 + preds, style_target, _ = torch.split(stacked_input, split_size, dim=0) + + def separated_loss(y_pred, y_true): + gram_size = y_true.size(1) * y_true.size(2) + sum_axis = (1, 2) + diff = torch.abs(y_pred - y_true) + raw_loss = torch.sum(diff ** 2, dim=sum_axis, keepdim=True) + return raw_loss / gram_size + + target_grams = convert_to_gram_matrix(style_target) + pred_grams = convert_to_gram_matrix(preds) + itemized_loss = separated_loss(pred_grams, target_grams) + # reshape itemized loss to be (batch, 1, 1, 1) + itemized_loss = torch.unsqueeze(itemized_loss, dim=1) + # gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2]) + loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True) + self.loss = loss + return stacked_input + + +# create a module to normalize input image so we can easily put it in a +# ``nn.Sequential`` +class Normalization(nn.Module): + def __init__(self, device): + super(Normalization, self).__init__() + mean = torch.tensor([0.485, 0.456, 0.406]).to(device) + std = torch.tensor([0.229, 0.224, 0.225]).to(device) + # .view the mean and std to make them [C x 1 x 1] so that they can + # directly work with image Tensor of shape [B x C x H x W]. + # B is batch size. C is number of channels. H is height and W is width. + self.mean = torch.tensor(mean).view(-1, 1, 1) + self.std = torch.tensor(std).view(-1, 1, 1) + + def forward(self, stacked_input): + # cast to float 32 if not already + if stacked_input.dtype != torch.float32: + stacked_input = stacked_input.float() + # remove alpha channel if it exists + if stacked_input.shape[1] == 4: + stacked_input = stacked_input[:, :3, :, :] + # normalize to min and max of 0 - 1 + in_min = torch.min(stacked_input) + in_max = torch.max(stacked_input) + norm_stacked_input = (stacked_input - in_min) / (in_max - in_min) + return (norm_stacked_input - self.mean) / self.std + + +def get_style_model_and_losses( + single_target=False, + device='cuda' if torch.cuda.is_available() else 'cpu' +): + # content_layers = ['conv_4'] + # style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'] + content_layers = ['conv3_2', 'conv4_2'] + style_layers = ['conv2_1', 'conv3_1', 'conv4_1'] + cnn = models.vgg19(pretrained=True).features.to(device).eval() + # normalization module + normalization = Normalization(device).to(device) + + # just in order to have an iterable access to or list of content/style + # losses + content_losses = [] + style_losses = [] + + # assuming that ``cnn`` is a ``nn.Sequential``, so we make a new ``nn.Sequential`` + # to put in modules that are supposed to be activated sequentially + model = nn.Sequential(normalization) + + i = 0 # increment every time we see a conv + block = 1 + children = list(cnn.children()) + + for layer in children: + if isinstance(layer, nn.Conv2d): + i += 1 + name = f'conv{block}_{i}_raw' + elif isinstance(layer, nn.ReLU): + # name = 'relu_{}'.format(i) + name = f'conv{block}_{i}' # target this + # The in-place version doesn't play very nicely with the ``ContentLoss`` + # and ``StyleLoss`` we insert below. 
So we replace with out-of-place
+            # ones here.
+            layer = nn.ReLU(inplace=False)
+        elif isinstance(layer, nn.MaxPool2d):
+            name = 'pool_{}'.format(block)  # index pools by block so names stay unique across blocks
+            block += 1
+            i = 0
+        elif isinstance(layer, nn.BatchNorm2d):
+            name = 'bn_{}'.format(i)
+        else:
+            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
+
+        model.add_module(name, layer)
+
+        if name in content_layers:
+            # add content loss:
+            content_loss = ContentLoss(single_target=single_target, device=device)
+            model.add_module("content_loss_{}_{}".format(block, i), content_loss)
+            content_losses.append(content_loss)
+
+        if name in style_layers:
+            # add style loss:
+            style_loss = StyleLoss(single_target=single_target, device=device)
+            model.add_module("style_loss_{}_{}".format(block, i), style_loss)
+            style_losses.append(style_loss)
+
+    # now we trim off the layers after the last content and style losses
+    for i in range(len(model) - 1, -1, -1):
+        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
+            break
+
+    model = model[:(i + 1)]
+
+    return model, style_losses, content_losses
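
A minimal usage sketch of the new toolkit/style.py module, mirroring how TrainVAEProcess.run wires it up: prediction and target images are concatenated along the batch dimension, a single forward pass through the trimmed VGG19 records a .loss tensor on every StyleLoss and ContentLoss module, and those tensors are then summed. The dummy tensors, batch size, and resolution below are illustrative only; in the trainer they are the decoded VAE output and the input batch, rescaled from [-1, 1] to [0, 1].

import torch
from toolkit.style import get_style_model_and_losses

device = 'cuda' if torch.cuda.is_available() else 'cpu'
vgg, style_losses, content_losses = get_style_model_and_losses(single_target=True, device=device)
vgg.requires_grad_(False)
vgg.eval()

# stand-in images in [0, 1]; shape is (batch, channels, height, width)
pred = torch.rand(4, 3, 256, 256, device=device)
target = torch.rand(4, 3, 256, 256, device=device)

# prediction first, target second, as expected in single_target mode
stacked = torch.cat([pred, target], dim=0)
vgg(stacked)  # forward pass only for its side effect of populating each module's .loss

style_loss = torch.sum(torch.stack([sl.loss for sl in style_losses]))
content_loss = torch.sum(torch.stack([cl.loss for cl in content_losses]))
print(f"style: {style_loss.item():.3e}  content: {content_loss.item():.3e}")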
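
The run() changes follow a selective fine-tuning pattern: freeze the whole VAE, then re-enable gradients only for the decoder blocks named in blocks_to_train and hand just those parameters to the optimizer. Below is a condensed sketch of that pattern, assuming a diffusers AutoencoderKL; the patch itself loads its VAE through toolkit.kohya_model_util.load_vae, and the checkpoint name and learning rate here are placeholders.

import torch
from diffusers import AutoencoderKL

# placeholder checkpoint; the training job supplies its own vae_path
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae.requires_grad_(False)
vae.eval()
vae.decoder.train()  # decoder stays in train mode while everything else is frozen

params = []
for block in (vae.decoder.mid_block, vae.decoder.up_blocks):
    block.requires_grad_(True)
    params += list(block.parameters())

optimizer = torch.optim.Adam(params, lr=1e-5)  # placeholder learning rate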