mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Handle conversions back to ldm for saving
This commit is contained in:
@@ -127,11 +127,12 @@ class Normalization(nn.Module):
|
||||
|
||||
def get_style_model_and_losses(
|
||||
single_target=False,
|
||||
device='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
device='cuda' if torch.cuda.is_available() else 'cpu',
|
||||
output_layer_name=None,
|
||||
):
|
||||
# content_layers = ['conv_4']
|
||||
# style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
|
||||
content_layers = ['conv3_2', 'conv4_2']
|
||||
content_layers = ['conv4_2']
|
||||
style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
|
||||
cnn = models.vgg19(pretrained=True).features.to(device).eval()
|
||||
# normalization module
|
||||
@@ -150,6 +151,8 @@ def get_style_model_and_losses(
|
||||
block = 1
|
||||
children = list(cnn.children())
|
||||
|
||||
output_layer = None
|
||||
|
||||
for layer in children:
|
||||
if isinstance(layer, nn.Conv2d):
|
||||
i += 1
|
||||
@@ -184,11 +187,16 @@ def get_style_model_and_losses(
|
||||
model.add_module("style_loss_{}_{}".format(block, i), style_loss)
|
||||
style_losses.append(style_loss)
|
||||
|
||||
if output_layer_name is not None and name == output_layer_name:
|
||||
output_layer = layer
|
||||
|
||||
# now we trim off the layers after the last content and style losses
|
||||
for i in range(len(model) - 1, -1, -1):
|
||||
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
|
||||
break
|
||||
if output_layer_name is not None and model[i].name == output_layer_name:
|
||||
break
|
||||
|
||||
model = model[:(i + 1)]
|
||||
|
||||
return model, style_losses, content_losses
|
||||
return model, style_losses, content_losses, output_layer
|
||||
|
||||
Reference in New Issue
Block a user