Added Critic support to VAE training. Still tweaking and working on it. Many other fixes

This commit is contained in:
Jaret Burkett
2023-07-19 15:57:32 -06:00
parent 6ada328d8d
commit 557732e7ff
9 changed files with 415 additions and 59 deletions

View File

@@ -21,6 +21,7 @@ class ContentLoss(nn.Module):
self.loss = None
def forward(self, stacked_input):
if self.single_target:
split_size = stacked_input.size()[0] // 2
pred_layer, target_layer = torch.split(stacked_input, split_size, dim=0)
@@ -73,6 +74,8 @@ class StyleLoss(nn.Module):
self.device = device
def forward(self, stacked_input):
input_dtype = stacked_input.dtype
stacked_input = stacked_input.float()
if self.single_target:
split_size = stacked_input.size()[0] // 2
preds, style_target = torch.split(stacked_input, split_size, dim=0)
@@ -94,17 +97,18 @@ class StyleLoss(nn.Module):
itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
# gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
self.loss = loss
return stacked_input
self.loss = loss.to(input_dtype)
return stacked_input.to(input_dtype)
# create a module to normalize input image so we can easily put it in a
# ``nn.Sequential``
class Normalization(nn.Module):
def __init__(self, device):
def __init__(self, device, dtype=torch.float32):
super(Normalization, self).__init__()
mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
std = torch.tensor([0.229, 0.224, 0.225]).to(device)
self.dtype = dtype
# .view the mean and std to make them [C x 1 x 1] so that they can
# directly work with image Tensor of shape [B x C x H x W].
# B is batch size. C is number of channels. H is height and W is width.
@@ -112,9 +116,9 @@ class Normalization(nn.Module):
self.std = torch.tensor(std).view(-1, 1, 1)
def forward(self, stacked_input):
# cast to float 32 if not already
if stacked_input.dtype != torch.float32:
stacked_input = stacked_input.float()
# cast to float 32 if not already # only necessary when processing gram matrix
# if stacked_input.dtype != torch.float32:
# stacked_input = stacked_input.float()
# remove alpha channel if it exists
if stacked_input.shape[1] == 4:
stacked_input = stacked_input[:, :3, :, :]
@@ -123,21 +127,37 @@ class Normalization(nn.Module):
in_max = torch.max(stacked_input)
# norm_stacked_input = (stacked_input - in_min) / (in_max - in_min)
# return (norm_stacked_input - self.mean) / self.std
return (stacked_input - self.mean) / self.std
return ((stacked_input - self.mean) / self.std).to(self.dtype)
class OutputLayer(nn.Module):
def __init__(self, name='output_layer'):
super(OutputLayer, self).__init__()
self.name = name
self.tensor = None
def forward(self, stacked_input):
self.tensor = stacked_input
return stacked_input
def get_style_model_and_losses(
single_target=False,
single_target=True, # false has 3 targets, dont remember why i added this initially, this is old code
device='cuda' if torch.cuda.is_available() else 'cpu',
output_layer_name=None,
dtype=torch.float32
):
# content_layers = ['conv_4']
# style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
content_layers = ['conv4_2']
style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
cnn = models.vgg19(pretrained=True).features.to(device).eval()
cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval()
# set all weights in the model to our dtype
# for layer in cnn.children():
# layer.to(dtype=dtype)
# normalization module
normalization = Normalization(device).to(device)
normalization = Normalization(device, dtype=dtype).to(device)
# just in order to have an iterable access to or list of content/style
# losses
@@ -189,15 +209,15 @@ def get_style_model_and_losses(
style_losses.append(style_loss)
if output_layer_name is not None and name == output_layer_name:
output_layer = layer
output_layer = OutputLayer(name)
model.add_module("output_layer_{}_{}".format(block, i), output_layer)
# now we trim off the layers after the last content and style losses
for i in range(len(model) - 1, -1, -1):
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
break
if output_layer_name is not None and model[i].name == output_layer_name:
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss) or isinstance(model[i], OutputLayer):
break
model = model[:(i + 1)]
model.to(dtype=dtype)
return model, style_losses, content_losses, output_layer