Fixes to the ESRGAN trainer. Moved the logic for SD prompt embeddings out of the diffusers pipeline so I can manipulate it

Jaret Burkett
2023-09-16 17:41:07 -06:00
parent 27f343fc08
commit c698837241
11 changed files with 214 additions and 78 deletions
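
For context, the prompt-embedding part of this change follows the usual diffusers pattern of building the text embeddings outside the pipeline call and passing them back in through prompt_embeds / negative_prompt_embeds. The sketch below only illustrates that pattern, not the trainer's actual code; the model id, the argument values, and the encode_prompt helper (older diffusers versions expose it as _encode_prompt) are assumptions.

import torch
from diffusers import StableDiffusionPipeline

# Illustrative sketch: compute SD prompt embeddings outside the pipeline so they
# can be manipulated before the call (model id and arguments are placeholders).
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

prompt_embeds, negative_embeds = pipe.encode_prompt(
    prompt="a photo of a cat",
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)

# ... manipulate prompt_embeds here (scale, blend, swap token embeddings, etc.) ...

image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_embeds,
).images[0]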


@@ -33,12 +33,17 @@ class ContentLoss(nn.Module):
        # Define the separate loss function
        def separated_loss(y_pred, y_true):
            y_pred = y_pred.float()
            y_true = y_true.float()
            diff = torch.abs(y_pred - y_true)
            l2 = torch.sum(diff ** 2, dim=[1, 2, 3], keepdim=True) / 2.0
            return 2. * l2 / content_size
        # Calculate itemized loss
        pred_itemized_loss = separated_loss(pred_layer, target_layer)
        # check if is nan
        if torch.isnan(pred_itemized_loss).any():
            print('pred_itemized_loss is nan')
        # Calculate the mean of itemized loss
        loss = torch.mean(pred_itemized_loss, dim=(1, 2, 3), keepdim=True)
@@ -48,6 +53,7 @@ class ContentLoss(nn.Module):
def convert_to_gram_matrix(inputs):
    inputs = inputs.float()
    shape = inputs.size()
    batch, filters, height, width = shape[0], shape[1], shape[2], shape[3]
    size = height * width * filters
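
The hunk above only shows the first lines of convert_to_gram_matrix. For reference, a batched Gram-matrix helper of this shape typically continues along the lines below; this is an assumed completion for illustration, and the file's exact normalization may differ.

import torch

def gram_matrix(inputs: torch.Tensor) -> torch.Tensor:
    # Correlate each pair of feature maps, per batch item.
    inputs = inputs.float()
    batch, filters, height, width = inputs.size()
    size = height * width * filters
    feats = inputs.view(batch, filters, height * width)  # (B, C, H*W)
    grams = torch.bmm(feats, feats.transpose(1, 2))      # (B, C, C)
    return grams / size                                  # one common normalization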
@@ -93,11 +99,14 @@ class StyleLoss(nn.Module):
        target_grams = convert_to_gram_matrix(style_target)
        pred_grams = convert_to_gram_matrix(preds)
        itemized_loss = separated_loss(pred_grams, target_grams)
        # check if is nan
        if torch.isnan(itemized_loss).any():
            print('itemized_loss is nan')
        # reshape itemized loss to be (batch, 1, 1, 1)
        itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
        # gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
        loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
        self.loss = loss.to(input_dtype)
        self.loss = loss.to(input_dtype).float()
        return stacked_input.to(input_dtype)
@@ -149,7 +158,7 @@ def get_style_model_and_losses(
):
    # content_layers = ['conv_4']
    # style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
    content_layers = ['conv4_2']
    content_layers = ['conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']
    style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
    cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval()
    # set all weights in the model to our dtype
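
The layer names used above (conv4_2, conv2_1, and so on) follow the usual VGG19 block_conv numbering. A small sketch of how such names can be generated while walking vgg19.features is shown below; this is an assumed naming scheme for illustration, not necessarily the exact code in this file.

import torch.nn as nn
from torchvision import models

vgg = models.vgg19(pretrained=True).features.eval()
names = []
block, conv_in_block = 1, 0
for layer in vgg.children():
    if isinstance(layer, nn.Conv2d):
        conv_in_block += 1
        names.append(f"conv{block}_{conv_in_block}")
    elif isinstance(layer, nn.ReLU):
        names.append(f"relu{block}_{conv_in_block}")
    elif isinstance(layer, nn.MaxPool2d):
        names.append(f"pool{block}")
        block += 1
        conv_in_block = 0
print(names[:6])  # ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1']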