mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Fixes to the ESRGAN trainer. Moved the logic for SD prompt embeddings out of the diffusers pipeline so that I can manipulate it.
This commit is contained in:
@@ -33,12 +33,17 @@ class ContentLoss(nn.Module):
|
||||
|
||||
# Define the separate loss function
|
||||
def separated_loss(y_pred, y_true):
    """Return the squared-error sum over dims 1-3, divided by ``content_size``.

    Both inputs are cast to float32 before the difference is taken. The
    reduction runs over dims 1, 2 and 3 with ``keepdim=True``, so dim 0 is
    left intact and the reduced dims remain as size 1.

    ``content_size`` is not a parameter; it is read from the surrounding
    scope, which is not visible in this snippet.
    """
    abs_err = (y_pred.float() - y_true.float()).abs()
    half_sq_sum = abs_err.pow(2).sum(dim=(1, 2, 3), keepdim=True) / 2.0
    # The factor of 2 undoes the halving above; both steps are kept so the
    # arithmetic matches the original expression exactly.
    return 2. * half_sq_sum / content_size
|
||||
|
||||
# Calculate itemized loss
|
||||
pred_itemized_loss = separated_loss(pred_layer, target_layer)
|
||||
# warn if any itemized loss value is NaN
|
||||
if torch.isnan(pred_itemized_loss).any():
|
||||
print('pred_itemized_loss is nan')
|
||||
|
||||
# Calculate the mean of itemized loss
|
||||
loss = torch.mean(pred_itemized_loss, dim=(1, 2, 3), keepdim=True)
|
||||
@@ -48,6 +53,7 @@ class ContentLoss(nn.Module):
|
||||
|
||||
|
||||
def convert_to_gram_matrix(inputs):
|
||||
inputs = inputs.float()
|
||||
shape = inputs.size()
|
||||
batch, filters, height, width = shape[0], shape[1], shape[2], shape[3]
|
||||
size = height * width * filters
|
||||
@@ -93,11 +99,14 @@ class StyleLoss(nn.Module):
|
||||
target_grams = convert_to_gram_matrix(style_target)
|
||||
pred_grams = convert_to_gram_matrix(preds)
|
||||
itemized_loss = separated_loss(pred_grams, target_grams)
|
||||
# warn if any itemized loss value is NaN
|
||||
if torch.isnan(itemized_loss).any():
|
||||
print('itemized_loss is nan')
|
||||
# reshape itemized loss to be (batch, 1, 1, 1)
|
||||
itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
|
||||
# gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
|
||||
loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
|
||||
self.loss = loss.to(input_dtype)
|
||||
self.loss = loss.to(input_dtype).float()
|
||||
return stacked_input.to(input_dtype)
|
||||
|
||||
|
||||
@@ -149,7 +158,7 @@ def get_style_model_and_losses(
|
||||
):
|
||||
# content_layers = ['conv_4']
|
||||
# style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
|
||||
content_layers = ['conv4_2']
|
||||
content_layers = ['conv2_2', 'conv3_2', 'conv4_2', 'conv5_2']
|
||||
style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
|
||||
cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval()
|
||||
# set all weights in the model to our dtype
|
||||
|
||||
Reference in New Issue
Block a user