Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-29 02:31:17 +00:00
Commit: Style and content loss working
jobs/process/TrainVAEProcess.py

@@ -15,7 +15,9 @@ from jobs.process import BaseTrainProcess
 from toolkit.kohya_model_util import load_vae
 from toolkit.data_loader import ImageDataset
 from toolkit.metadata import get_meta_for_safetensors
+from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
+from diffusers import AutoencoderKL
 from tqdm import tqdm
 import time
 import numpy as np
@@ -27,15 +29,9 @@ IMAGE_TRANSFORMS = transforms.Compose(
     ]
 )
 
-INVERSE_IMAGE_TRANSFORMS = transforms.Compose(
-    [
-        transforms.Normalize(
-            mean=[-0.5/0.5],
-            std=[1/0.5]
-        ),
-        transforms.ToPILImage(),
-    ]
-)
+
+def unnormalize(tensor):
+    return (tensor / 2 + 0.5).clamp(0, 1)
 
 
 class TrainVAEProcess(BaseTrainProcess):
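
Note: the removed INVERSE_IMAGE_TRANSFORMS was Normalize(mean=-1, std=2) followed by ToPILImage(), i.e. the inverse of Normalize(mean=0.5, std=0.5). The new unnormalize() computes the same mapping x -> x/2 + 0.5 directly on the tensor and clamps, leaving PIL conversion to the caller. A minimal sketch of the arithmetic (values illustrative):

import torch

t = torch.tensor([-1.0, 0.0, 1.0])   # the [-1, 1] range produced by Normalize(0.5, 0.5)
print((t / 2 + 0.5).clamp(0, 1))     # tensor([0.0000, 0.5000, 1.0000])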
@@ -56,8 +52,12 @@ class TrainVAEProcess(BaseTrainProcess):
         self.save_every = self.get_conf('save_every', None)
         self.dtype = self.get_conf('dtype', 'float32')
         self.sample_sources = self.get_conf('sample_sources', None)
+        self.style_weight = self.get_conf('style_weight', 1e4)
+        self.content_weight = self.get_conf('content_weight', 1)
+        self.elbo_weight = self.get_conf('elbo_weight', 1e-8)
         self.torch_dtype = get_torch_dtype(self.dtype)
         self.save_root = os.path.join(self.training_folder, self.job.name)
+        self.vgg_19 = None
 
         if self.sample_every is not None and self.sample_sources is None:
             raise ValueError('sample_every is specified but sample_sources is not')
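
For context, the get_conf() keys above map to the job's process config. A hypothetical excerpt (key names from the code; values are the defaults shown above, not from the commit):

process_config = {
    'save_every': None,
    'dtype': 'float32',
    'sample_sources': ['/path/to/sample/images'],   # illustrative path
    'style_weight': 1e4,     # scales the VGG19 Gram-matrix (style) loss
    'content_weight': 1,     # scales the VGG19 feature (content) loss
    'elbo_weight': 1e-8,     # scales the ELBO term
}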
@@ -66,7 +66,6 @@ class TrainVAEProcess(BaseTrainProcess):
             raise ValueError('epochs or max_steps must be specified')
 
         self.data_loaders = []
-        datasets = []
         # check datasets
         assert isinstance(self.datasets_objects, list)
         for dataset in self.datasets_objects:
@@ -95,10 +94,17 @@ class TrainVAEProcess(BaseTrainProcess):
         self.data_loader = DataLoader(
             concatenated_dataset,
             batch_size=self.batch_size,
-            shuffle=True
+            shuffle=True,
+            num_workers=6
         )
 
-    def get_loss(self, pred, target):
+    def setup_vgg19(self):
+        if self.vgg_19 is None:
+            self.vgg_19, self.style_losses, self.content_losses = get_style_model_and_losses(
+                single_target=True, device=self.device)
+            self.vgg_19.requires_grad_(False)
+
+    def get_mse_loss(self, pred, target):
         loss_fn = nn.MSELoss()
         loss = loss_fn(pred, target)
         return loss
@@ -157,8 +163,18 @@ class TrainVAEProcess(BaseTrainProcess):
 
             input_img = img
             img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
-            decoded = self.vae(img).sample.squeeze(0)
-            decoded = INVERSE_IMAGE_TRANSFORMS(decoded)
+            img = img
+            decoded = self.vae(img).sample
+            decoded = (decoded / 2 + 0.5).clamp(0, 1)
+            # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+            decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
+
+            #convert to pillow image
+            decoded = Image.fromarray((decoded * 255).astype(np.uint8))
+
+            # # decoded = decoded - 0.1
+            # decoded = decoded
+            # decoded = INVERSE_IMAGE_TRANSFORMS(decoded)
 
             # stack input image and decoded image
             input_img = input_img.resize((self.resolution, self.resolution))
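
The decode path above replaces the old tensor-transform round trip with an explicit BCHW -> HWC -> uint8 conversion. A self-contained sketch of the same steps (the random tensor stands in for self.vae(img).sample):

import numpy as np
import torch
from PIL import Image

decoded = torch.rand(1, 3, 512, 512) * 2 - 1              # stand-in VAE output in [-1, 1]
decoded = (decoded / 2 + 0.5).clamp(0, 1)                  # back to [0, 1]
decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0)     # BCHW -> HWC
decoded = decoded.float().numpy()                          # float32 also covers bfloat16 inputs
pil_img = Image.fromarray((decoded * 255).astype(np.uint8))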
@@ -177,7 +193,6 @@ class TrainVAEProcess(BaseTrainProcess):
             i_str = str(i).zfill(2)
             filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
             output_img.save(os.path.join(sample_folder, filename))
-        self.vae.decoder.train()
 
     def run(self):
         super().run()
@@ -204,19 +219,41 @@ class TrainVAEProcess(BaseTrainProcess):
         print(f"Loading VAE")
         print(f" - Loading VAE: {self.vae_path}")
         if self.vae is None:
+            # self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
             self.vae = load_vae(self.vae_path, dtype=self.torch_dtype)
 
         # set decoder to train
         self.vae.to(self.device, dtype=self.torch_dtype)
         self.vae.requires_grad_(False)
         self.vae.eval()
 
-        self.vae.decoder.requires_grad_(True)
         self.vae.decoder.train()
 
-        parameters = self.vae.decoder.parameters()
+        blocks_to_train = [
+            'mid_block',
+            'up_blocks',
+        ]
 
-        optimizer = torch.optim.Adam(parameters, lr=self.learning_rate)
+        params = []
+
+        # only set last 2 layers to trainable
+        for param in self.vae.decoder.parameters():
+            param.requires_grad = False
+
+        if 'mid_block' in blocks_to_train:
+            params += list(self.vae.decoder.mid_block.parameters())
+            self.vae.decoder.mid_block.requires_grad_(True)
+        if 'up_blocks' in blocks_to_train:
+            params += list(self.vae.decoder.up_blocks.parameters())
+            self.vae.decoder.up_blocks.requires_grad_(True)
+
+        # self.vae.decoder.train()
+
+        self.setup_vgg19()
+        self.vgg_19.requires_grad_(False)
+        self.vgg_19.eval()
+
+        optimizer = torch.optim.Adam(params, lr=self.learning_rate)
 
         # setup scheduler
         # scheduler = lr_scheduler.ConstantLR
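
A quick way to confirm the selective freeze above behaves as intended (illustrative; `vae` stands in for self.vae):

trainable = [name for name, p in vae.decoder.named_parameters() if p.requires_grad]
assert all(name.startswith(('mid_block', 'up_blocks')) for name in trainable)
print(f"{len(trainable)} trainable decoder parameter tensors")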
@@ -249,6 +286,7 @@ class TrainVAEProcess(BaseTrainProcess):
 
                 # forward pass
                 # with torch.no_grad():
+                # batch = batch + 0.1
                 dgd = self.vae.encode(batch).latent_dist
                 mu, logvar = dgd.mean, dgd.logvar
                 latents = dgd.sample()
@@ -256,7 +294,24 @@ class TrainVAEProcess(BaseTrainProcess):
 
                 pred = self.vae.decode(latents).sample
 
-                loss = self.get_elbo_loss(pred, batch, mu, logvar)
+                # pred = pred + 0.1
+
+                # loss = self.get_elbo_loss(pred, batch, mu, logvar)
+
+                stacked = torch.cat([pred, batch], dim=0)
+                stacked = (stacked / 2 + 0.5).clamp(0, 1)
+                self.vgg_19(stacked)
+                # reduce the mean of the style_loss list
+
+                style_loss = torch.sum(torch.stack([loss.loss for loss in self.style_losses]))
+                content_loss = torch.sum(torch.stack([loss.loss for loss in self.content_losses]))
+                elbo_loss = self.get_elbo_loss(pred, batch, mu, logvar)
+                # elbo_loss = torch.zeros(1, device=self.device, dtype=self.torch_dtype)
+                style_loss = style_loss * self.style_weight
+                content_loss = content_loss * self.content_weight
+                elbo_loss = elbo_loss * self.elbo_weight
+
+                loss = style_loss + content_loss + elbo_loss
 
                 # Backward pass and optimization
                 optimizer.zero_grad()
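
get_elbo_loss() itself is not part of this diff. A common formulation, shown only as a sketch of what the mu/logvar arguments suggest (the repo's implementation may differ):

import torch
import torch.nn.functional as F

def elbo_loss_sketch(pred, target, mu, logvar):
    recon = F.mse_loss(pred, target)                                 # reconstruction term
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())   # KL(N(mu, sigma) || N(0, 1))
    return recon + kld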
@@ -267,9 +322,9 @@ class TrainVAEProcess(BaseTrainProcess):
                 # update progress bar
                 loss_value = loss.item()
                 # get exponent like 3.54e-4
-                loss_string = f"{loss_value:.2e}"
+                loss_string = f"loss: {loss_value:.2e} cnt: {content_loss.item():.2e} sty: {style_loss.item():.2e} elbo: {elbo_loss.item():.2e}"
                 learning_rate = optimizer.param_groups[0]['lr']
-                progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} Loss: {loss_string}")
+                progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}")
                 progress_bar.set_description(f"E: {epoch} - S: {step} ")
                 progress_bar.update(1)
 
@@ -279,7 +334,7 @@ class TrainVAEProcess(BaseTrainProcess):
                     print(f"Sampling at step {step}")
                     self.sample(step)
 
                 if self.save_every and step % self.save_every == 0:
                     # print above the progress bar
                     print(f"Saving at step {step}")
                     self.save(step)
toolkit/data_loader.py

@@ -51,15 +51,20 @@ class ImageDataset(Dataset):
 
         # Downscale the source image first
         img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
+        min_img_size = min(img.size)
 
         if self.random_crop:
-            if self.random_scale:
-                scale_size = random.randint(int(img.size[0] * self.scale), self.resolution)
+            if self.random_scale and min_img_size > self.resolution:
+                if min_img_size < self.resolution:
+                    print(
+                        f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file}")
+                    scale_size = self.resolution
+                else:
+                    scale_size = random.randint(self.resolution, int(min_img_size))
                 img = img.resize((scale_size, scale_size), Image.BICUBIC)
             img = transforms.RandomCrop(self.resolution)(img)
         else:
-            min_dimension = min(img.size)
-            img = transforms.CenterCrop(min_dimension)(img)
+            img = transforms.CenterCrop(min_img_size)(img)
             img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
 
         img = self.transform(img)
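
The rewritten random_scale branch samples scale_size from [self.resolution, min_img_size] instead of the old [scaled width, self.resolution] range, so RandomCrop(self.resolution) always has enough pixels to crop from. The sizing rule in isolation (hypothetical helper, not part of the commit):

import random

def pick_scale_size(min_img_size, resolution):
    # never scale below the crop resolution
    if min_img_size < resolution:
        return resolution
    return random.randint(resolution, int(min_img_size))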
toolkit/style.py (new file, 194 lines)

@@ -0,0 +1,194 @@
+from torch import nn
+import torch.nn.functional as F
+import torch
+from torchvision import models
+
+
+# device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+def tensor_size(tensor):
+    channels = tensor.shape[1]
+    height = tensor.shape[2]
+    width = tensor.shape[3]
+    return channels * height * width
+
+class ContentLoss(nn.Module):
+
+    def __init__(self, single_target=False, device='cuda' if torch.cuda.is_available() else 'cpu'):
+        super(ContentLoss, self).__init__()
+        self.single_target = single_target
+        self.device = device
+        self.loss = None
+
+    def forward(self, stacked_input):
+        if self.single_target:
+            split_size = stacked_input.size()[0] // 2
+            pred_layer, target_layer = torch.split(stacked_input, split_size, dim=0)
+        else:
+            split_size = stacked_input.size()[0] // 3
+            pred_layer, _, target_layer = torch.split(stacked_input, split_size, dim=0)
+
+        content_size = tensor_size(pred_layer)
+
+        # Define the separate loss function
+        def separated_loss(y_pred, y_true):
+            diff = torch.abs(y_pred - y_true)
+            l2 = torch.sum(diff ** 2, dim=[1, 2, 3], keepdim=True) / 2.0
+            return 2. * l2 / content_size
+
+        # Calculate itemized loss
+        pred_itemized_loss = separated_loss(pred_layer, target_layer)
+
+        # Calculate the mean of itemized loss
+        loss = torch.mean(pred_itemized_loss, dim=(1, 2, 3), keepdim=True)
+        self.loss = loss
+
+        return stacked_input
+
+
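Despite the indirection, ContentLoss reduces to a per-image mean squared error over the feature map: 2 * (sum(diff**2) / 2) / (C*H*W) == mean(diff**2). A small check (shapes illustrative):

import torch

pred, target = torch.rand(2, 8, 4, 4), torch.rand(2, 8, 4, 4)
itemized = torch.sum((pred - target) ** 2, dim=[1, 2, 3], keepdim=True) / pred[0].numel()
assert torch.allclose(itemized.flatten(), torch.mean((pred - target) ** 2, dim=[1, 2, 3]))
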
+def convert_to_gram_matrix(inputs):
+    shape = inputs.size()
+    batch, filters, height, width = shape[0], shape[1], shape[2], shape[3]
+    size = height * width * filters
+
+    feats = inputs.view(batch, filters, height * width)
+    feats_t = feats.transpose(1, 2)
+    grams_raw = torch.matmul(feats, feats_t)
+    gram_matrix = grams_raw / size
+
+    return gram_matrix
+
+
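convert_to_gram_matrix() computes, per batch element, the C x C matrix of inner products between flattened feature channels, scaled by C*H*W. An equivalent one-liner (shapes illustrative):

import torch

x = torch.rand(2, 16, 8, 8)                        # (B, C, H, W)
feats = x.view(2, 16, -1)                          # (B, C, H*W)
gram = feats @ feats.transpose(1, 2) / (16 * 8 * 8)
print(gram.shape)                                  # torch.Size([2, 16, 16])
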
+######################################################################
+# Now the style loss module looks almost exactly like the content loss
+# module. The style distance is also computed using the mean square
+# error between :math:`G_{XL}` and :math:`G_{SL}`.
+#
+
+class StyleLoss(nn.Module):
+
+    def __init__(self, single_target=False, device='cuda' if torch.cuda.is_available() else 'cpu'):
+        super(StyleLoss, self).__init__()
+        self.single_target = single_target
+        self.device = device
+
+    def forward(self, stacked_input):
+        if self.single_target:
+            split_size = stacked_input.size()[0] // 2
+            preds, style_target = torch.split(stacked_input, split_size, dim=0)
+        else:
+            split_size = stacked_input.size()[0] // 3
+            preds, style_target, _ = torch.split(stacked_input, split_size, dim=0)
+
+        def separated_loss(y_pred, y_true):
+            gram_size = y_true.size(1) * y_true.size(2)
+            sum_axis = (1, 2)
+            diff = torch.abs(y_pred - y_true)
+            raw_loss = torch.sum(diff ** 2, dim=sum_axis, keepdim=True)
+            return raw_loss / gram_size
+
+        target_grams = convert_to_gram_matrix(style_target)
+        pred_grams = convert_to_gram_matrix(preds)
+        itemized_loss = separated_loss(pred_grams, target_grams)
+        # reshape itemized loss to be (batch, 1, 1, 1)
+        itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
+        # gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
+        loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
+        self.loss = loss
+        return stacked_input
+
+
+# create a module to normalize input image so we can easily put it in a
+# ``nn.Sequential``
+class Normalization(nn.Module):
+    def __init__(self, device):
+        super(Normalization, self).__init__()
+        mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
+        std = torch.tensor([0.229, 0.224, 0.225]).to(device)
+        # .view the mean and std to make them [C x 1 x 1] so that they can
+        # directly work with image Tensor of shape [B x C x H x W].
+        # B is batch size. C is number of channels. H is height and W is width.
+        self.mean = torch.tensor(mean).view(-1, 1, 1)
+        self.std = torch.tensor(std).view(-1, 1, 1)
+
+    def forward(self, stacked_input):
+        # cast to float 32 if not already
+        if stacked_input.dtype != torch.float32:
+            stacked_input = stacked_input.float()
+        # remove alpha channel if it exists
+        if stacked_input.shape[1] == 4:
+            stacked_input = stacked_input[:, :3, :, :]
+        # normalize to min and max of 0 - 1
+        in_min = torch.min(stacked_input)
+        in_max = torch.max(stacked_input)
+        norm_stacked_input = (stacked_input - in_min) / (in_max - in_min)
+        return (norm_stacked_input - self.mean) / self.std
+
+
+def get_style_model_and_losses(
+        single_target=False,
+        device='cuda' if torch.cuda.is_available() else 'cpu'
+):
+    # content_layers = ['conv_4']
+    # style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
+    content_layers = ['conv3_2', 'conv4_2']
+    style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
+    cnn = models.vgg19(pretrained=True).features.to(device).eval()
+    # normalization module
+    normalization = Normalization(device).to(device)
+
+    # just in order to have an iterable access to or list of content/style
+    # losses
+    content_losses = []
+    style_losses = []
+
+    # assuming that ``cnn`` is a ``nn.Sequential``, so we make a new ``nn.Sequential``
+    # to put in modules that are supposed to be activated sequentially
+    model = nn.Sequential(normalization)
+
+    i = 0  # increment every time we see a conv
+    block = 1
+    children = list(cnn.children())
+
+    for layer in children:
+        if isinstance(layer, nn.Conv2d):
+            i += 1
+            name = f'conv{block}_{i}_raw'
+        elif isinstance(layer, nn.ReLU):
+            # name = 'relu_{}'.format(i)
+            name = f'conv{block}_{i}'  # target this
+            # The in-place version doesn't play very nicely with the ``ContentLoss``
+            # and ``StyleLoss`` we insert below. So we replace with out-of-place
+            # ones here.
+            layer = nn.ReLU(inplace=False)
+        elif isinstance(layer, nn.MaxPool2d):
+            name = 'pool_{}'.format(i)
+            block += 1
+            i = 0
+        elif isinstance(layer, nn.BatchNorm2d):
+            name = 'bn_{}'.format(i)
+        else:
+            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
+
+        model.add_module(name, layer)
+
+        if name in content_layers:
+            # add content loss:
+            content_loss = ContentLoss(single_target=single_target, device=device)
+            model.add_module("content_loss_{}_{}".format(block, i), content_loss)
+            content_losses.append(content_loss)
+
+        if name in style_layers:
+            # add style loss:
+            style_loss = StyleLoss(single_target=single_target, device=device)
+            model.add_module("style_loss_{}_{}".format(block, i), style_loss)
+            style_losses.append(style_loss)
+
+    # now we trim off the layers after the last content and style losses
+    for i in range(len(model) - 1, -1, -1):
+        if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
+            break
+
+    model = model[:(i + 1)]
+
+    return model, style_losses, content_losses
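
Usage mirrors the training loop in TrainVAEProcess above: stack prediction and target into one batch, run it through the returned model once, then read the .loss attribute each inserted module recorded. A minimal sketch (CPU, random tensors; downloads VGG19 weights on first run):

import torch
from toolkit.style import get_style_model_and_losses

model, style_losses, content_losses = get_style_model_and_losses(single_target=True, device='cpu')
pred, target = torch.rand(1, 3, 256, 256), torch.rand(1, 3, 256, 256)
model(torch.cat([pred, target], dim=0))   # forward pass populates each .loss
style_loss = torch.sum(torch.stack([sl.loss for sl in style_losses]))
content_loss = torch.sum(torch.stack([cl.loss for cl in content_losses]))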