mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 02:01:29 +00:00
Added Critic support to VAE training. Still tweaking and working on it. Many other fixes
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
import oyaml as yaml
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
|
||||
from toolkit.paths import TOOLKIT_ROOT
|
||||
@@ -29,6 +30,20 @@ def preprocess_config(config: OrderedDict):
|
||||
return config
|
||||
|
||||
|
||||
|
||||
# Fixes issue where yaml doesnt load exponents correctly
|
||||
fixed_loader = yaml.SafeLoader
|
||||
fixed_loader.add_implicit_resolver(
|
||||
u'tag:yaml.org,2002:float',
|
||||
re.compile(u'''^(?:
|
||||
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
||||
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
||||
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
||||
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|
||||
|[-+]?\\.(?:inf|Inf|INF)
|
||||
|\\.(?:nan|NaN|NAN))$''', re.X),
|
||||
list(u'-+0123456789.'))
|
||||
|
||||
def get_config(config_file_path):
|
||||
# first check if it is in the config folder
|
||||
config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path)
|
||||
@@ -56,7 +71,7 @@ def get_config(config_file_path):
|
||||
config = json.load(f, object_pairs_hook=OrderedDict)
|
||||
elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'):
|
||||
with open(real_config_path, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
config = yaml.load(f, Loader=fixed_loader)
|
||||
else:
|
||||
raise ValueError(f"Config file {config_file_path} must be a json or yaml file")
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ def total_variation(image):
|
||||
"""
|
||||
n_elements = image.shape[1] * image.shape[2] * image.shape[3]
|
||||
return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) +
|
||||
torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
|
||||
torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements)
|
||||
|
||||
|
||||
class ComparativeTotalVariation(torch.nn.Module):
|
||||
@@ -21,3 +21,27 @@ class ComparativeTotalVariation(torch.nn.Module):
|
||||
|
||||
def forward(self, pred, target):
|
||||
return torch.abs(total_variation(pred) - total_variation(target))
|
||||
|
||||
|
||||
# Gradient penalty
|
||||
def get_gradient_penalty(critic, real, fake, device):
|
||||
with torch.autocast(device_type='cuda'):
|
||||
alpha = torch.rand(real.size(0), 1, 1, 1).to(device)
|
||||
interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True)
|
||||
d_interpolates = critic(interpolates)
|
||||
fake = torch.ones(real.size(0), 1, device=device)
|
||||
|
||||
gradients = torch.autograd.grad(
|
||||
outputs=d_interpolates,
|
||||
inputs=interpolates,
|
||||
grad_outputs=fake,
|
||||
create_graph=True,
|
||||
retain_graph=True,
|
||||
only_inputs=True,
|
||||
)[0]
|
||||
|
||||
gradients = gradients.view(gradients.size(0), -1)
|
||||
gradient_norm = gradients.norm(2, dim=1)
|
||||
gradient_penalty = ((gradient_norm - 1) ** 2).mean()
|
||||
return gradient_penalty
|
||||
|
||||
|
||||
@@ -13,6 +13,16 @@ def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict:
|
||||
# safetensors can only be one level deep
|
||||
for key, value in save_meta.items():
|
||||
# if not float, int, bool, or str, convert to json string
|
||||
if not isinstance(value, (float, int, bool, str)):
|
||||
if not isinstance(value, str):
|
||||
save_meta[key] = json.dumps(value)
|
||||
return save_meta
|
||||
|
||||
|
||||
def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
|
||||
parsed_meta = OrderedDict()
|
||||
for key, value in meta.items():
|
||||
try:
|
||||
parsed_meta[key] = json.loads(value)
|
||||
except json.decoder.JSONDecodeError:
|
||||
parsed_meta[key] = value
|
||||
return meta
|
||||
|
||||
18
toolkit/optimizer.py
Normal file
18
toolkit/optimizer.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
|
||||
|
||||
def get_optimizer(
|
||||
params,
|
||||
optimizer_type='adam',
|
||||
learning_rate=1e-6
|
||||
):
|
||||
if optimizer_type == 'dadaptation':
|
||||
# dadaptation optimizer does not use standard learning rate. 1 is the default value
|
||||
import dadaptation
|
||||
print("Using DAdaptAdam optimizer")
|
||||
optimizer = dadaptation.DAdaptAdam(params, lr=1.0)
|
||||
elif optimizer_type == 'adam':
|
||||
optimizer = torch.optim.Adam(params, lr=float(learning_rate))
|
||||
else:
|
||||
raise ValueError(f'Unknown optimizer type {optimizer_type}')
|
||||
return optimizer
|
||||
@@ -21,6 +21,7 @@ class ContentLoss(nn.Module):
|
||||
self.loss = None
|
||||
|
||||
def forward(self, stacked_input):
|
||||
|
||||
if self.single_target:
|
||||
split_size = stacked_input.size()[0] // 2
|
||||
pred_layer, target_layer = torch.split(stacked_input, split_size, dim=0)
|
||||
@@ -73,6 +74,8 @@ class StyleLoss(nn.Module):
|
||||
self.device = device
|
||||
|
||||
def forward(self, stacked_input):
|
||||
input_dtype = stacked_input.dtype
|
||||
stacked_input = stacked_input.float()
|
||||
if self.single_target:
|
||||
split_size = stacked_input.size()[0] // 2
|
||||
preds, style_target = torch.split(stacked_input, split_size, dim=0)
|
||||
@@ -94,17 +97,18 @@ class StyleLoss(nn.Module):
|
||||
itemized_loss = torch.unsqueeze(itemized_loss, dim=1)
|
||||
# gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2])
|
||||
loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True)
|
||||
self.loss = loss
|
||||
return stacked_input
|
||||
self.loss = loss.to(input_dtype)
|
||||
return stacked_input.to(input_dtype)
|
||||
|
||||
|
||||
# create a module to normalize input image so we can easily put it in a
|
||||
# ``nn.Sequential``
|
||||
class Normalization(nn.Module):
|
||||
def __init__(self, device):
|
||||
def __init__(self, device, dtype=torch.float32):
|
||||
super(Normalization, self).__init__()
|
||||
mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
|
||||
std = torch.tensor([0.229, 0.224, 0.225]).to(device)
|
||||
self.dtype = dtype
|
||||
# .view the mean and std to make them [C x 1 x 1] so that they can
|
||||
# directly work with image Tensor of shape [B x C x H x W].
|
||||
# B is batch size. C is number of channels. H is height and W is width.
|
||||
@@ -112,9 +116,9 @@ class Normalization(nn.Module):
|
||||
self.std = torch.tensor(std).view(-1, 1, 1)
|
||||
|
||||
def forward(self, stacked_input):
|
||||
# cast to float 32 if not already
|
||||
if stacked_input.dtype != torch.float32:
|
||||
stacked_input = stacked_input.float()
|
||||
# cast to float 32 if not already # only necessary when processing gram matrix
|
||||
# if stacked_input.dtype != torch.float32:
|
||||
# stacked_input = stacked_input.float()
|
||||
# remove alpha channel if it exists
|
||||
if stacked_input.shape[1] == 4:
|
||||
stacked_input = stacked_input[:, :3, :, :]
|
||||
@@ -123,21 +127,37 @@ class Normalization(nn.Module):
|
||||
in_max = torch.max(stacked_input)
|
||||
# norm_stacked_input = (stacked_input - in_min) / (in_max - in_min)
|
||||
# return (norm_stacked_input - self.mean) / self.std
|
||||
return (stacked_input - self.mean) / self.std
|
||||
return ((stacked_input - self.mean) / self.std).to(self.dtype)
|
||||
|
||||
|
||||
class OutputLayer(nn.Module):
|
||||
def __init__(self, name='output_layer'):
|
||||
super(OutputLayer, self).__init__()
|
||||
self.name = name
|
||||
self.tensor = None
|
||||
|
||||
def forward(self, stacked_input):
|
||||
self.tensor = stacked_input
|
||||
return stacked_input
|
||||
|
||||
|
||||
def get_style_model_and_losses(
|
||||
single_target=False,
|
||||
single_target=True, # false has 3 targets, dont remember why i added this initially, this is old code
|
||||
device='cuda' if torch.cuda.is_available() else 'cpu',
|
||||
output_layer_name=None,
|
||||
dtype=torch.float32
|
||||
):
|
||||
# content_layers = ['conv_4']
|
||||
# style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
|
||||
content_layers = ['conv4_2']
|
||||
style_layers = ['conv2_1', 'conv3_1', 'conv4_1']
|
||||
cnn = models.vgg19(pretrained=True).features.to(device).eval()
|
||||
cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval()
|
||||
# set all weights in the model to our dtype
|
||||
# for layer in cnn.children():
|
||||
# layer.to(dtype=dtype)
|
||||
|
||||
# normalization module
|
||||
normalization = Normalization(device).to(device)
|
||||
normalization = Normalization(device, dtype=dtype).to(device)
|
||||
|
||||
# just in order to have an iterable access to or list of content/style
|
||||
# losses
|
||||
@@ -189,15 +209,15 @@ def get_style_model_and_losses(
|
||||
style_losses.append(style_loss)
|
||||
|
||||
if output_layer_name is not None and name == output_layer_name:
|
||||
output_layer = layer
|
||||
output_layer = OutputLayer(name)
|
||||
model.add_module("output_layer_{}_{}".format(block, i), output_layer)
|
||||
|
||||
# now we trim off the layers after the last content and style losses
|
||||
for i in range(len(model) - 1, -1, -1):
|
||||
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
|
||||
break
|
||||
if output_layer_name is not None and model[i].name == output_layer_name:
|
||||
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss) or isinstance(model[i], OutputLayer):
|
||||
break
|
||||
|
||||
model = model[:(i + 1)]
|
||||
model.to(dtype=dtype)
|
||||
|
||||
return model, style_losses, content_losses, output_layer
|
||||
|
||||
Reference in New Issue
Block a user