mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Style and content loss working
This commit is contained in:
@@ -15,7 +15,9 @@ from jobs.process import BaseTrainProcess
|
||||
from toolkit.kohya_model_util import load_vae
|
||||
from toolkit.data_loader import ImageDataset
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.style import get_style_model_and_losses
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
from diffusers import AutoencoderKL
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
import numpy as np
|
||||
@@ -27,15 +29,9 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
]
|
||||
)
|
||||
|
||||
INVERSE_IMAGE_TRANSFORMS = transforms.Compose(
|
||||
[
|
||||
transforms.Normalize(
|
||||
mean=[-0.5/0.5],
|
||||
std=[1/0.5]
|
||||
),
|
||||
transforms.ToPILImage(),
|
||||
]
|
||||
)
|
||||
|
||||
def unnormalize(tensor):
|
||||
return (tensor / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
|
||||
class TrainVAEProcess(BaseTrainProcess):
|
||||
@@ -56,8 +52,12 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
self.save_every = self.get_conf('save_every', None)
|
||||
self.dtype = self.get_conf('dtype', 'float32')
|
||||
self.sample_sources = self.get_conf('sample_sources', None)
|
||||
self.style_weight = self.get_conf('style_weight', 1e4)
|
||||
self.content_weight = self.get_conf('content_weight', 1)
|
||||
self.elbo_weight = self.get_conf('elbo_weight', 1e-8)
|
||||
self.torch_dtype = get_torch_dtype(self.dtype)
|
||||
self.save_root = os.path.join(self.training_folder, self.job.name)
|
||||
self.vgg_19 = None
|
||||
|
||||
if self.sample_every is not None and self.sample_sources is None:
|
||||
raise ValueError('sample_every is specified but sample_sources is not')
|
||||
@@ -66,7 +66,6 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
raise ValueError('epochs or max_steps must be specified')
|
||||
|
||||
self.data_loaders = []
|
||||
datasets = []
|
||||
# check datasets
|
||||
assert isinstance(self.datasets_objects, list)
|
||||
for dataset in self.datasets_objects:
|
||||
@@ -95,10 +94,17 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
self.data_loader = DataLoader(
|
||||
concatenated_dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True
|
||||
shuffle=True,
|
||||
num_workers=6
|
||||
)
|
||||
|
||||
def get_loss(self, pred, target):
|
||||
def setup_vgg19(self):
|
||||
if self.vgg_19 is None:
|
||||
self.vgg_19, self.style_losses, self.content_losses = get_style_model_and_losses(
|
||||
single_target=True, device=self.device)
|
||||
self.vgg_19.requires_grad_(False)
|
||||
|
||||
def get_mse_loss(self, pred, target):
|
||||
loss_fn = nn.MSELoss()
|
||||
loss = loss_fn(pred, target)
|
||||
return loss
|
||||
@@ -157,8 +163,18 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
|
||||
input_img = img
|
||||
img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
|
||||
decoded = self.vae(img).sample.squeeze(0)
|
||||
decoded = INVERSE_IMAGE_TRANSFORMS(decoded)
|
||||
img = img
|
||||
decoded = self.vae(img).sample
|
||||
decoded = (decoded / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
|
||||
|
||||
#convert to pillow image
|
||||
decoded = Image.fromarray((decoded * 255).astype(np.uint8))
|
||||
|
||||
# # decoded = decoded - 0.1
|
||||
# decoded = decoded
|
||||
# decoded = INVERSE_IMAGE_TRANSFORMS(decoded)
|
||||
|
||||
# stack input image and decoded image
|
||||
input_img = input_img.resize((self.resolution, self.resolution))
|
||||
@@ -177,7 +193,6 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
i_str = str(i).zfill(2)
|
||||
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
|
||||
output_img.save(os.path.join(sample_folder, filename))
|
||||
self.vae.decoder.train()
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
@@ -204,19 +219,41 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
print(f"Loading VAE")
|
||||
print(f" - Loading VAE: {self.vae_path}")
|
||||
if self.vae is None:
|
||||
# self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
|
||||
self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
|
||||
|
||||
# set decoder to train
|
||||
self.vae.to(self.device, dtype=self.torch_dtype)
|
||||
self.vae.requires_grad_(False)
|
||||
self.vae.eval()
|
||||
|
||||
self.vae.decoder.requires_grad_(True)
|
||||
self.vae.decoder.train()
|
||||
|
||||
parameters = self.vae.decoder.parameters()
|
||||
blocks_to_train = [
|
||||
'mid_block',
|
||||
'up_blocks',
|
||||
]
|
||||
|
||||
optimizer = torch.optim.Adam(parameters, lr=self.learning_rate)
|
||||
params = []
|
||||
|
||||
# only set last 2 layers to trainable
|
||||
for param in self.vae.decoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if 'mid_block' in blocks_to_train:
|
||||
params += list(self.vae.decoder.mid_block.parameters())
|
||||
self.vae.decoder.mid_block.requires_grad_(True)
|
||||
if 'up_blocks' in blocks_to_train:
|
||||
params += list(self.vae.decoder.up_blocks.parameters())
|
||||
self.vae.decoder.up_blocks.requires_grad_(True)
|
||||
|
||||
# self.vae.decoder.train()
|
||||
|
||||
self.setup_vgg19()
|
||||
self.vgg_19.requires_grad_(False)
|
||||
self.vgg_19.eval()
|
||||
|
||||
|
||||
optimizer = torch.optim.Adam(params, lr=self.learning_rate)
|
||||
|
||||
# setup scheduler
|
||||
# scheduler = lr_scheduler.ConstantLR
|
||||
@@ -249,6 +286,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
|
||||
# forward pass
|
||||
# with torch.no_grad():
|
||||
# batch = batch + 0.1
|
||||
dgd = self.vae.encode(batch).latent_dist
|
||||
mu, logvar = dgd.mean, dgd.logvar
|
||||
latents = dgd.sample()
|
||||
@@ -256,7 +294,24 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
|
||||
pred = self.vae.decode(latents).sample
|
||||
|
||||
loss = self.get_elbo_loss(pred, batch, mu, logvar)
|
||||
# pred = pred + 0.1
|
||||
|
||||
# loss = self.get_elbo_loss(pred, batch, mu, logvar)
|
||||
|
||||
stacked = torch.cat([pred, batch], dim=0)
|
||||
stacked = (stacked / 2 + 0.5).clamp(0, 1)
|
||||
self.vgg_19(stacked)
|
||||
# reduce the mean of the style_loss list
|
||||
|
||||
style_loss = torch.sum(torch.stack([loss.loss for loss in self.style_losses]))
|
||||
content_loss = torch.sum(torch.stack([loss.loss for loss in self.content_losses]))
|
||||
elbo_loss = self.get_elbo_loss(pred, batch, mu, logvar)
|
||||
# elbo_loss = torch.zeros(1, device=self.device, dtype=self.torch_dtype)
|
||||
style_loss = style_loss * self.style_weight
|
||||
content_loss = content_loss * self.content_weight
|
||||
elbo_loss = elbo_loss * self.elbo_weight
|
||||
|
||||
loss = style_loss + content_loss + elbo_loss
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
@@ -267,9 +322,9 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
# update progress bar
|
||||
loss_value = loss.item()
|
||||
# get exponent like 3.54e-4
|
||||
loss_string = f"{loss_value:.2e}"
|
||||
loss_string = f"loss: {loss_value:.2e} cnt: {content_loss.item():.2e} sty: {style_loss.item():.2e} elbo: {elbo_loss.item():.2e}"
|
||||
learning_rate = optimizer.param_groups[0]['lr']
|
||||
progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} Loss: {loss_string}")
|
||||
progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
|
||||
progress_bar.set_description(f"E: {epoch} - S: {step} ")
|
||||
progress_bar.update(1)
|
||||
|
||||
@@ -279,7 +334,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
print(f"Sampling at step {step}")
|
||||
self.sample(step)
|
||||
|
||||
if self.save_every and step % self.save_every == 0:
|
||||
if self.save_every and step % self.save_every == 0:
|
||||
# print above the progress bar
|
||||
print(f"Saving at step {step}")
|
||||
self.save(step)
|
||||
|
||||
@@ -51,15 +51,20 @@ class ImageDataset(Dataset):
|
||||
|
||||
# Downscale the source image first
|
||||
img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
|
||||
min_img_size = min(img.size)
|
||||
|
||||
if self.random_crop:
|
||||
if self.random_scale:
|
||||
scale_size = random.randint(int(img.size[0] * self.scale), self.resolution)
|
||||
if self.random_scale and min_img_size > self.resolution:
|
||||
if min_img_size < self.resolution:
|
||||
print(
|
||||
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file}")
|
||||
scale_size = self.resolution
|
||||
else:
|
||||
scale_size = random.randint(self.resolution, int(min_img_size))
|
||||
img = img.resize((scale_size, scale_size), Image.BICUBIC)
|
||||
img = transforms.RandomCrop(self.resolution)(img)
|
||||
else:
|
||||
min_dimension = min(img.size)
|
||||
img = transforms.CenterCrop(min_dimension)(img)
|
||||
img = transforms.CenterCrop(min_img_size)(img)
|
||||
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
|
||||
|
||||
img = self.transform(img)
|
||||
|
||||
194
toolkit/style.py
Normal file
194
toolkit/style.py
Normal file
@@ -0,0 +1,194 @@
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
from torchvision import models
|
||||
|
||||
|
||||
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
def tensor_size(tensor):
|
||||
channels = tensor.shape[1]
|
||||
height = tensor.shape[2]
|
||||
width = tensor.shape[3]
|
||||
return channels * height * width
|
||||
|
||||
class ContentLoss(nn.Module):
|
||||
|
||||
def __init__(self, single_target=False, device='cuda' if torch.cuda.is_available() else 'cpu'):
|
||||
super(ContentLoss, self).__init__()
|
||||
self.single_target = single_target
|
||||
self.device = device
|
||||
self.loss = None
|
||||
|
||||
def forward(self, stacked_input):
|
||||
if self.single_target:
|
||||
split_size = stacked_input.size()[0] // 2
|
||||
pred_layer, target_layer = torch.split(stacked_input, split_size, dim=0)
|
||||
else:
|
||||
split_size = stacked_input.size()[0] // 3
|
||||
pred_layer, _, target_layer = torch.split(stacked_input, split_size, dim=0)
|
||||
|
||||
content_size = tensor_size(pred_layer)
|
||||
|
||||
# Define the separate loss function
|
||||
def separated_loss(y_pred, y_true):
|
||||
diff = torch.abs(y_pred - y_true)
|
||||
l2 = torch.sum(diff ** 2, dim=[1, 2, 3], keepdim=True) / 2.0
|
||||
return 2. * l2 / content_size
|
||||
|
||||
# Calculate itemized loss
|
||||
pred_itemized_loss = separated_loss(pred_layer, target_layer)
|
||||
|
||||
# Calculate the mean of itemized loss
|
||||
loss = torch.mean(pred_itemized_loss, dim=(1, 2, 3), keepdim=True)
|
||||
self.loss = loss
|
||||
|
||||
return stacked_input
|
||||
|
||||
|
||||
def convert_to_gram_matrix(inputs):
|
||||
shape = inputs.size()
|
||||
batch, filters, height, width = shape[0], shape[1], shape[2], shape[3]
|
||||
size = height * width * filters
|
||||
|
||||
feats = inputs.view(batch, filters, height * width)
|
||||
feats_t = feats.transpose(1, 2)
|
||||
grams_raw = torch.matmul(feats, feats_t)
|
||||
gram_matrix = grams_raw / size
|
||||
|
||||
return gram_matrix
|
||||
|
||||
|
||||
######################################################################
|
||||
# Now the style loss module looks almost exactly like the content loss
|
||||
# module. The style distance is also computed using the mean square
|
||||
# error between :math:`G_{XL}` and :math:`G_{SL}`.
|
||||
#
|
||||
|
||||
class StyleLoss(nn.Module):
|
||||
|
||||
def __init__(self, single_target=False, device='cuda' if torch.cuda.is_available() else 'cpu'):
|
||||
super(StyleLoss, self).__init__()
|
||||
self.single_target = single_target
|
||||
self.device = device
|
||||
|
||||
def forward(self, stacked_input):
|
||||
if self.single_target:
|
||||
split_size = stacked_input.size()[0] // 2
|
||||
preds, style_target = torch.split(stacked_input, split_size, dim=0)
|
||||
else:
|
||||
split_size = stacked_input.size()[0] // 3
|
||||
preds, style_target, _ = torch.split(stacked_input, split_size, dim=0)
|
||||
|
||||
def separated_loss(y_pred, y_true):
|
||||
gram_size = y_true.size(1) * y_true.size(2)
|
||||
sum_axis = (1, 2)
|
||||
diff = torch.abs(y_pred - y_true)
|
||||
raw_loss = torch.sum(diff ** 2, dim=sum_axis, keepdim=True)
|
||||
return raw_loss / gram_size
|
||||
|
||||
target_grams = convert_to_gram_matrix(style_target)
|
||||
pred_grams = convert_to_gram_matrix(preds)
|
||||
itemized_loss = separated_loss(pred_grams, target_grams)
|
||||
# reshape itemized loss to be (batch, 1, 1, 1)
|
||||
itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
|
||||
# gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
|
||||
loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
|
||||
self.loss = loss
|
||||
return stacked_input
|
||||
|
||||
|
||||
# create a module to normalize input image so we can easily put it in a
|
||||
# ``nn.Sequential``
|
||||
class Normalization(nn.Module):
|
||||
def __init__(self, device):
|
||||
super(Normalization, self).__init__()
|
||||
mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
|
||||
std = torch.tensor([0.229, 0.224, 0.225]).to(device)
|
||||
# .view the mean and std to make them [C x 1 x 1] so that they can
|
||||
# directly work with image Tensor of shape [B x C x H x W].
|
||||
# B is batch size. C is number of channels. H is height and W is width.
|
||||
self.mean = torch.tensor(mean).view(-1, 1, 1)
|
||||
self.std = torch.tensor(std).view(-1, 1, 1)
|
||||
|
||||
def forward(self, stacked_input):
|
||||
# cast to float 32 if not already
|
||||
if stacked_input.dtype != torch.float32:
|
||||
stacked_input = stacked_input.float()
|
||||
# remove alpha channel if it exists
|
||||
if stacked_input.shape[1] == 4:
|
||||
stacked_input = stacked_input[:, :3, :, :]
|
||||
# normalize to min and max of 0 - 1
|
||||
in_min = torch.min(stacked_input)
|
||||
in_max = torch.max(stacked_input)
|
||||
norm_stacked_input = (stacked_input - in_min) / (in_max - in_min)
|
||||
return (norm_stacked_input - self.mean) / self.std
|
||||
|
||||
|
||||
def get_style_model_and_losses(
|
||||
single_target=False,
|
||||
device='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
):
|
||||
# content_layers = ['conv_4']
|
||||
# style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
|
||||
content_layers = ['conv3_2', 'conv4_2']
|
||||
style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
|
||||
cnn = models.vgg19(pretrained=True).features.to(device).eval()
|
||||
# normalization module
|
||||
normalization = Normalization(device).to(device)
|
||||
|
||||
# just in order to have an iterable access to or list of content/style
|
||||
# losses
|
||||
content_losses = []
|
||||
style_losses = []
|
||||
|
||||
# assuming that ``cnn`` is a ``nn.Sequential``, so we make a new ``nn.Sequential``
|
||||
# to put in modules that are supposed to be activated sequentially
|
||||
model = nn.Sequential(normalization)
|
||||
|
||||
i = 0 # increment every time we see a conv
|
||||
block = 1
|
||||
children = list(cnn.children())
|
||||
|
||||
for layer in children:
|
||||
if isinstance(layer, nn.Conv2d):
|
||||
i += 1
|
||||
name = f'conv{block}_{i}_raw'
|
||||
elif isinstance(layer, nn.ReLU):
|
||||
# name = 'relu_{}'.format(i)
|
||||
name = f'conv{block}_{i}' # target this
|
||||
# The in-place version doesn't play very nicely with the ``ContentLoss``
|
||||
# and ``StyleLoss`` we insert below. So we replace with out-of-place
|
||||
# ones here.
|
||||
layer = nn.ReLU(inplace=False)
|
||||
elif isinstance(layer, nn.MaxPool2d):
|
||||
name = 'pool_{}'.format(i)
|
||||
block += 1
|
||||
i = 0
|
||||
elif isinstance(layer, nn.BatchNorm2d):
|
||||
name = 'bn_{}'.format(i)
|
||||
else:
|
||||
raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
|
||||
|
||||
model.add_module(name, layer)
|
||||
|
||||
if name in content_layers:
|
||||
# add content loss:
|
||||
content_loss = ContentLoss(single_target=single_target, device=device)
|
||||
model.add_module("content_loss_{}_{}".format(block, i), content_loss)
|
||||
content_losses.append(content_loss)
|
||||
|
||||
if name in style_layers:
|
||||
# add style loss:
|
||||
style_loss = StyleLoss(single_target=single_target, device=device)
|
||||
model.add_module("style_loss_{}_{}".format(block, i), style_loss)
|
||||
style_losses.append(style_loss)
|
||||
|
||||
# now we trim off the layers after the last content and style losses
|
||||
for i in range(len(model) - 1, -1, -1):
|
||||
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
|
||||
break
|
||||
|
||||
model = model[:(i + 1)]
|
||||
|
||||
return model, style_losses, content_losses
|
||||
Reference in New Issue
Block a user