mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added my good ole pattern loss. God I love that thing, conv transpose pattern instantly wiped from vae
This commit is contained in:
@@ -15,7 +15,7 @@ sys.path.append(REPOS_ROOT)
|
|||||||
|
|
||||||
process_dict = {
|
process_dict = {
|
||||||
'vae': 'TrainVAEProcess',
|
'vae': 'TrainVAEProcess',
|
||||||
'finetune': 'TrainFineTuneProcess'
|
'slider': 'TrainSliderProcess',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -27,8 +27,8 @@ class TrainJob(BaseJob):
|
|||||||
self.training_folder = self.get_conf('training_folder', required=True)
|
self.training_folder = self.get_conf('training_folder', required=True)
|
||||||
self.is_v2 = self.get_conf('is_v2', False)
|
self.is_v2 = self.get_conf('is_v2', False)
|
||||||
self.device = self.get_conf('device', 'cpu')
|
self.device = self.get_conf('device', 'cpu')
|
||||||
self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1)
|
# self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1)
|
||||||
self.mixed_precision = self.get_conf('mixed_precision', False) # fp16
|
# self.mixed_precision = self.get_conf('mixed_precision', False) # fp16
|
||||||
self.log_dir = self.get_conf('log_dir', None)
|
self.log_dir = self.get_conf('log_dir', None)
|
||||||
|
|
||||||
self.writer = None
|
self.writer = None
|
||||||
|
|||||||
@@ -5,13 +5,14 @@ import itertools
|
|||||||
|
|
||||||
|
|
||||||
class LosslessLatentDecoder(nn.Module):
|
class LosslessLatentDecoder(nn.Module):
|
||||||
def __init__(self, in_channels, latent_depth):
|
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
|
||||||
super(LosslessLatentDecoder, self).__init__()
|
super(LosslessLatentDecoder, self).__init__()
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
self.latent_depth = latent_depth
|
self.latent_depth = latent_depth
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = int(in_channels // (latent_depth * latent_depth))
|
self.out_channels = int(in_channels // (latent_depth * latent_depth))
|
||||||
numpy_kernel = self.build_kernel(in_channels, latent_depth)
|
numpy_kernel = self.build_kernel(in_channels, latent_depth)
|
||||||
self.kernel = torch.from_numpy(numpy_kernel).float()
|
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
def build_kernel(self, in_channels, latent_depth):
|
def build_kernel(self, in_channels, latent_depth):
|
||||||
# my old code from tensorflow.
|
# my old code from tensorflow.
|
||||||
@@ -39,13 +40,15 @@ class LosslessLatentDecoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class LosslessLatentEncoder(nn.Module):
|
class LosslessLatentEncoder(nn.Module):
|
||||||
def __init__(self, in_channels, latent_depth):
|
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
|
||||||
super(LosslessLatentEncoder, self).__init__()
|
super(LosslessLatentEncoder, self).__init__()
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
self.latent_depth = latent_depth
|
self.latent_depth = latent_depth
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = int(in_channels * (latent_depth * latent_depth))
|
self.out_channels = int(in_channels * (latent_depth * latent_depth))
|
||||||
numpy_kernel = self.build_kernel(in_channels, latent_depth)
|
numpy_kernel = self.build_kernel(in_channels, latent_depth)
|
||||||
self.kernel = torch.from_numpy(numpy_kernel).float()
|
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def build_kernel(self, in_channels, latent_depth):
|
def build_kernel(self, in_channels, latent_depth):
|
||||||
# my old code from tensorflow.
|
# my old code from tensorflow.
|
||||||
@@ -72,13 +75,13 @@ class LosslessLatentEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class LosslessLatentVAE(nn.Module):
|
class LosslessLatentVAE(nn.Module):
|
||||||
def __init__(self, in_channels, latent_depth):
|
def __init__(self, in_channels, latent_depth, dtype=torch.float32):
|
||||||
super(LosslessLatentVAE, self).__init__()
|
super(LosslessLatentVAE, self).__init__()
|
||||||
self.latent_depth = latent_depth
|
self.latent_depth = latent_depth
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.encoder = LosslessLatentEncoder(in_channels, latent_depth)
|
self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype)
|
||||||
encoder_out_channels = self.encoder.out_channels
|
encoder_out_channels = self.encoder.out_channels
|
||||||
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth)
|
self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
latent = self.latent_encoder(x)
|
latent = self.latent_encoder(x)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from .llvae import LosslessLatentEncoder
|
||||||
|
|
||||||
|
|
||||||
def total_variation(image):
|
def total_variation(image):
|
||||||
@@ -45,3 +46,40 @@ def get_gradient_penalty(critic, real, fake, device):
|
|||||||
gradient_penalty = ((gradient_norm - 1) ** 2).mean()
|
gradient_penalty = ((gradient_norm - 1) ** 2).mean()
|
||||||
return gradient_penalty
|
return gradient_penalty
|
||||||
|
|
||||||
|
|
||||||
|
class PatternLoss(torch.nn.Module):
|
||||||
|
def __init__(self, pattern_size=4, dtype=torch.float32):
|
||||||
|
super().__init__()
|
||||||
|
self.pattern_size = pattern_size
|
||||||
|
self.llvae_encoder = LosslessLatentEncoder(3, pattern_size, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, pred, target):
|
||||||
|
pred_latents = self.llvae_encoder(pred)
|
||||||
|
target_latents = self.llvae_encoder(target)
|
||||||
|
|
||||||
|
matrix_pixels = self.pattern_size * self.pattern_size
|
||||||
|
|
||||||
|
color_chans = pred_latents.shape[1] // 3
|
||||||
|
# pytorch
|
||||||
|
r_chans, g_chans, b_chans = torch.split(pred_latents, [color_chans, color_chans, color_chans], 1)
|
||||||
|
r_chans_target, g_chans_target, b_chans_target = torch.split(target_latents, [color_chans, color_chans, color_chans], 1)
|
||||||
|
|
||||||
|
def separated_chan_loss(latent_chan):
|
||||||
|
nonlocal matrix_pixels
|
||||||
|
chan_mean = torch.mean(latent_chan, dim=[1, 2, 3])
|
||||||
|
chan_splits = torch.split(latent_chan, [1 for i in range(matrix_pixels)], 1)
|
||||||
|
chan_loss = None
|
||||||
|
for chan in chan_splits:
|
||||||
|
this_mean = torch.mean(chan, dim=[1, 2, 3])
|
||||||
|
this_chan_loss = torch.abs(this_mean - chan_mean)
|
||||||
|
if chan_loss is None:
|
||||||
|
chan_loss = this_chan_loss
|
||||||
|
else:
|
||||||
|
chan_loss = chan_loss + this_chan_loss
|
||||||
|
chan_loss = chan_loss * (1 / matrix_pixels)
|
||||||
|
return chan_loss
|
||||||
|
|
||||||
|
r_chan_loss = torch.abs(separated_chan_loss(r_chans) - separated_chan_loss(r_chans_target))
|
||||||
|
g_chan_loss = torch.abs(separated_chan_loss(g_chans) - separated_chan_loss(g_chans_target))
|
||||||
|
b_chan_loss = torch.abs(separated_chan_loss(b_chans) - separated_chan_loss(b_chans_target))
|
||||||
|
return (r_chan_loss + g_chan_loss + b_chan_loss) * 0.3333
|
||||||
|
|||||||
@@ -6,13 +6,16 @@ def get_optimizer(
|
|||||||
optimizer_type='adam',
|
optimizer_type='adam',
|
||||||
learning_rate=1e-6
|
learning_rate=1e-6
|
||||||
):
|
):
|
||||||
if optimizer_type == 'dadaptation':
|
lower_type = optimizer_type.lower()
|
||||||
|
if lower_type == 'dadaptation':
|
||||||
# dadaptation optimizer does not use standard learning rate. 1 is the default value
|
# dadaptation optimizer does not use standard learning rate. 1 is the default value
|
||||||
import dadaptation
|
import dadaptation
|
||||||
print("Using DAdaptAdam optimizer")
|
print("Using DAdaptAdam optimizer")
|
||||||
optimizer = dadaptation.DAdaptAdam(params, lr=1.0)
|
optimizer = dadaptation.DAdaptAdam(params, lr=1.0)
|
||||||
elif optimizer_type == 'adam':
|
elif lower_type == 'adam':
|
||||||
optimizer = torch.optim.Adam(params, lr=float(learning_rate))
|
optimizer = torch.optim.Adam(params, lr=float(learning_rate))
|
||||||
|
elif lower_type == 'adamw':
|
||||||
|
optimizer = torch.optim.AdamW(params, lr=float(learning_rate))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unknown optimizer type {optimizer_type}')
|
raise ValueError(f'Unknown optimizer type {optimizer_type}')
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|||||||
Reference in New Issue
Block a user