From 557732e7ff32fd2e79ee246fd9f7721a9746a376 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 19 Jul 2023 15:57:32 -0600 Subject: [PATCH] Added Critic support to VAE training. Still tweaking and working on it. Many other fixes --- jobs/TrainJob.py | 12 +- jobs/process/BaseProcess.py | 18 +- jobs/process/TrainVAEProcess.py | 285 ++++++++++++++++++++++++---- jobs/process/models/vgg19_critic.py | 38 ++++ toolkit/config.py | 17 +- toolkit/losses.py | 26 ++- toolkit/metadata.py | 12 +- toolkit/optimizer.py | 18 ++ toolkit/style.py | 48 +++-- 9 files changed, 415 insertions(+), 59 deletions(-) create mode 100644 jobs/process/models/vgg19_critic.py create mode 100644 toolkit/optimizer.py diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py index 5f971462..2f5c66ef 100644 --- a/jobs/TrainJob.py +++ b/jobs/TrainJob.py @@ -1,8 +1,11 @@ +import os + from jobs import BaseJob from toolkit.kohya_model_util import load_models_from_stable_diffusion_checkpoint from collections import OrderedDict from typing import List from jobs.process import BaseExtractProcess, TrainFineTuneProcess +from datetime import datetime from toolkit.paths import REPOS_ROOT @@ -45,7 +48,8 @@ class TrainJob(BaseJob): def setup_tensorboard(self): if self.log_dir: from torch.utils.tensorboard import SummaryWriter - self.writer = SummaryWriter( - log_dir=self.log_dir, - filename_suffix=f"_{self.name}" - ) + now = datetime.now() + time_str = now.strftime('%Y%m%d-%H%M%S') + summary_name = f"{self.name}_{time_str}" + summary_dir = os.path.join(self.log_dir, summary_name) + self.writer = SummaryWriter(summary_dir) diff --git a/jobs/process/BaseProcess.py b/jobs/process/BaseProcess.py index e9c32360..821eebaa 100644 --- a/jobs/process/BaseProcess.py +++ b/jobs/process/BaseProcess.py @@ -17,11 +17,23 @@ class BaseProcess: self.job = job self.config = config self.meta = copy.deepcopy(self.job.meta) + print(json.dumps(self.config, indent=4)) def get_conf(self, key, default=None, required=False, as_type=None): - if key in self.config: - value = self.config[key] - if as_type is not None and value is not None: + # split key by '.' 
and recursively get the value + keys = key.split('.') + + # see if it exists in the config + value = self.config + for subkey in keys: + if subkey in value: + value = value[subkey] + else: + value = None + break + + if value is not None: + if as_type is not None: value = as_type(value) return value elif required: diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py index 5804ded7..ef526b7d 100644 --- a/jobs/process/TrainVAEProcess.py +++ b/jobs/process/TrainVAEProcess.py @@ -6,7 +6,7 @@ from collections import OrderedDict from PIL import Image from PIL.ImageOps import exif_transpose -from safetensors.torch import save_file +from safetensors.torch import save_file, load_file from torch.utils.data import DataLoader, ConcatDataset import torch from torch import nn @@ -15,8 +15,9 @@ from torchvision.transforms import transforms from jobs.process import BaseTrainProcess from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm from toolkit.data_loader import ImageDataset -from toolkit.losses import ComparativeTotalVariation +from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty from toolkit.metadata import get_meta_for_safetensors +from toolkit.optimizer import get_optimizer from toolkit.style import get_style_model_and_losses from toolkit.train_tools import get_torch_dtype from diffusers import AutoencoderKL @@ -36,6 +37,139 @@ def unnormalize(tensor): return (tensor / 2 + 0.5).clamp(0, 1) +class Critic: + process: 'TrainVAEProcess' + + def __init__( + self, + learning_rate=1e-5, + device='cpu', + optimizer='adam', + num_critic_per_gen=1, + dtype='float32', + lambda_gp=10, + start_step=0, + warmup_steps=1000, + process=None + ): + self.learning_rate = learning_rate + self.device = device + self.optimizer_type = optimizer + self.num_critic_per_gen = num_critic_per_gen + self.dtype = dtype + self.torch_dtype = get_torch_dtype(self.dtype) + self.process = process + self.model = None + self.optimizer = None + self.scheduler = None + self.warmup_steps = warmup_steps + self.start_step = start_step + self.lambda_gp = lambda_gp + self.print = self.process.print + print(f" Critic config: {self.__dict__}") + + def setup(self): + from .models.vgg19_critic import Vgg19Critic + self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype) + self.load_weights() + self.model.train() + self.model.requires_grad_(True) + params = self.model.parameters() + self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate) + self.scheduler = torch.optim.lr_scheduler.ConstantLR( + self.optimizer, + total_iters=self.process.max_steps * self.num_critic_per_gen, + factor=1, + verbose=False + ) + + def load_weights(self): + path_to_load = None + self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}") + files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors")) + if files and len(files) > 0: + latest_file = max(files, key=os.path.getmtime) + print(f" - Latest checkpoint is: {latest_file}") + path_to_load = latest_file + else: + self.print(f" - No checkpoint found, starting from scratch") + if path_to_load: + self.model.load_state_dict(load_file(path_to_load)) + + def save(self, step=None): + self.process.update_training_metadata() + save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name) + step_num = '' + if step is not None: + # zeropad 9 digits + step_num = f"_{str(step).zfill(9)}" + save_path = os.path.join(self.process.save_root, 
f"CRITIC_{self.process.job.name}{step_num}.safetensors") + save_file(self.model.state_dict(), save_path, save_meta) + self.print(f"Saved critic to {save_path}") + + def get_critic_loss(self, vgg_output): + if self.start_step > self.process.step_num: + return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device) + + warmup_scaler = 1.0 + # we need a warmup when we come on of 1000 steps + # we want to scale the loss by 0.0 at self.start_step steps and 1.0 at self.start_step + warmup_steps + if self.process.step_num < self.start_step + self.warmup_steps: + warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps + # set model to not train for generator loss + self.model.eval() + self.model.requires_grad_(False) + vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0) + + # run model + stacked_output = self.model(vgg_pred) + + return (-torch.mean(stacked_output)) * warmup_scaler + + def step(self, vgg_output): + + # train critic here + self.model.train() + self.model.requires_grad_(True) + + critic_losses = [] + for i in range(self.num_critic_per_gen): + inputs = vgg_output.detach() + inputs = inputs.to(self.device, dtype=self.torch_dtype) + self.optimizer.zero_grad() + + vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0) + + stacked_output = self.model(inputs) + out_pred, out_target = torch.chunk(stacked_output, 2, dim=0) + + # Compute gradient penalty + gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device) + + # Compute WGAN-GP critic loss + critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty + critic_loss.backward() + self.optimizer.zero_grad() + self.optimizer.step() + self.scheduler.step() + critic_losses.append(critic_loss.item()) + + # avg loss + loss = np.mean(critic_losses) + return loss + + def get_lr(self): + if self.optimizer_type.startswith('dadaptation'): + learning_rate = ( + self.optimizer.param_groups[0]["d"] * + self.optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = self.optimizer.param_groups[0]['lr'] + + return learning_rate + + class TrainVAEProcess(BaseTrainProcess): def __init__(self, process_id: int, job, config: OrderedDict): super().__init__(process_id, job, config) @@ -61,6 +195,7 @@ class TrainVAEProcess(BaseTrainProcess): self.kld_weight = self.get_conf('kld_weight', 0, as_type=float) self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float) self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float) + self.critic_weight = self.get_conf('critic_weight', 1, as_type=float) self.blocks_to_train = self.get_conf('blocks_to_train', ['all']) self.writer = self.job.writer @@ -68,6 +203,22 @@ class TrainVAEProcess(BaseTrainProcess): self.save_root = os.path.join(self.training_folder, self.job.name) self.vgg_19 = None self.progress_bar = None + self.style_weight_scalers = [] + self.content_weight_scalers = [] + + self.step_num = 0 + self.epoch_num = 0 + + self.use_critic = self.get_conf('use_critic', False, as_type=bool) + self.critic = None + + if self.use_critic: + self.critic = Critic( + device=self.device, + dtype=self.dtype, + process=self, + **self.get_conf('critic', {}) # pass any other params + ) if self.sample_every is not None and self.sample_sources is None: raise ValueError('sample_every is specified but sample_sources is not') @@ -89,6 +240,16 @@ class TrainVAEProcess(BaseTrainProcess): if not os.path.exists(self.save_root): os.makedirs(self.save_root, exist_ok=True) + def update_training_metadata(self): + 
self.add_meta(OrderedDict({"training_info": self.get_training_info()})) + + def get_training_info(self): + info = OrderedDict({ + 'step': self.step_num, + 'epoch': self.epoch_num, + }) + return info + def print(self, message, **kwargs): if self.progress_bar is not None: self.progress_bar.write(message, **kwargs) @@ -117,19 +278,46 @@ class TrainVAEProcess(BaseTrainProcess): def setup_vgg19(self): if self.vgg_19 is None: - self.vgg_19, self.style_losses, self.content_losses, output = get_style_model_and_losses( - single_target=True, device=self.device) + self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses( + single_target=True, + device=self.device, + output_layer_name='pool_4', + dtype=self.torch_dtype + ) + self.vgg_19.to(self.device, dtype=self.torch_dtype) self.vgg_19.requires_grad_(False) + # we run random noise through first to get layer scalers to normalize the loss per layer + # bs of 2 because we run pred and target through stacked + noise = torch.randn((2, 3, self.resolution, self.resolution), device=self.device, dtype=self.torch_dtype) + self.vgg_19(noise) + for style_loss in self.style_losses: + # get a scaler to normalize to 1 + scaler = 1 / torch.mean(style_loss.loss).item() + self.style_weight_scalers.append(scaler) + for content_loss in self.content_losses: + # get a scaler to normalize to 1 + scaler = 1 / torch.mean(content_loss.loss).item() + self.content_weight_scalers.append(scaler) + + self.print(f"Style weight scalers: {self.style_weight_scalers}") + self.print(f"Content weight scalers: {self.content_weight_scalers}") + def get_style_loss(self): if self.style_weight > 0: - return torch.sum(torch.stack([loss.loss for loss in self.style_losses])) + # scale all losses with loss scalers + loss = torch.sum( + torch.stack([loss.loss * scaler for loss, scaler in zip(self.style_losses, self.style_weight_scalers)])) + return loss else: return torch.tensor(0.0, device=self.device) def get_content_loss(self): if self.content_weight > 0: - return torch.sum(torch.stack([loss.loss for loss in self.content_losses])) + # scale all losses with loss scalers + loss = torch.sum(torch.stack( + [loss.loss * scaler for loss, scaler in zip(self.content_losses, self.content_weight_scalers)])) + return loss else: return torch.tensor(0.0, device=self.device) @@ -160,7 +348,6 @@ class TrainVAEProcess(BaseTrainProcess): else: return torch.tensor(0.0, device=self.device) - def save(self, step=None): if not os.path.exists(self.save_root): os.makedirs(self.save_root, exist_ok=True) @@ -170,6 +357,7 @@ class TrainVAEProcess(BaseTrainProcess): # zeropad 9 digits step_num = f"_{str(step).zfill(9)}" + self.update_training_metadata() filename = f'{self.job.name}{step_num}.safetensors' # prepare meta save_meta = get_meta_for_safetensors(self.meta, self.job.name) @@ -184,7 +372,10 @@ class TrainVAEProcess(BaseTrainProcess): # having issues with meta save_file(state_dict, os.path.join(self.save_root, filename), save_meta) - print(f"Saved to {os.path.join(self.save_root, filename)}") + self.print(f"Saved to {os.path.join(self.save_root, filename)}") + + if self.use_critic: + self.critic.save(step) def sample(self, step=None): sample_folder = os.path.join(self.save_root, 'samples') @@ -268,6 +459,9 @@ class TrainVAEProcess(BaseTrainProcess): num_steps = self.max_steps if num_steps is None or num_steps > max_epoch_steps: num_steps = max_epoch_steps + self.max_steps = num_steps + self.epochs = num_epochs + start_step = self.step_num self.print(f"Training VAE") 
self.print(f" - Training folder: {self.training_folder}") @@ -304,18 +498,14 @@ class TrainVAEProcess(BaseTrainProcess): params += list(self.vae.decoder.conv_out.parameters()) self.vae.decoder.conv_out.requires_grad_(True) - if self.style_weight > 0 or self.content_weight > 0: + if self.style_weight > 0 or self.content_weight > 0 or self.use_critic: self.setup_vgg19() self.vgg_19.requires_grad_(False) self.vgg_19.eval() + if self.use_critic: + self.critic.setup() - # todo allow other optimizers - if self.optimizer_type == 'dadaptation': - import dadaptation - print("Using DAdaptAdam optimizer") - optimizer = dadaptation.DAdaptAdam(params, lr=1) - else: - optimizer = torch.optim.Adam(params, lr=float(self.learning_rate)) + optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate) # setup scheduler # todo allow other schedulers @@ -333,7 +523,6 @@ class TrainVAEProcess(BaseTrainProcess): leave=True ) - step = 0 # sample first self.sample() blank_losses = OrderedDict({ @@ -343,15 +532,17 @@ class TrainVAEProcess(BaseTrainProcess): "mse": [], "kl": [], "tv": [], + "crD": [], + "crG": [], }) epoch_losses = copy.deepcopy(blank_losses) log_losses = copy.deepcopy(blank_losses) - - for epoch in range(num_epochs): - if step >= num_steps: + # range start at self.epoch_num go to self.epochs + for epoch in range(self.epoch_num, self.epochs, 1): + if self.step_num >= self.max_steps: break for batch in self.data_loader: - if step >= num_steps: + if self.step_num >= self.max_steps: break batch = batch.to(self.device, dtype=self.torch_dtype) @@ -365,18 +556,27 @@ class TrainVAEProcess(BaseTrainProcess): pred = self.vae.decode(latents).sample # Run through VGG19 - if self.style_weight > 0 or self.content_weight > 0: + if self.style_weight > 0 or self.content_weight > 0 or self.use_critic: stacked = torch.cat([pred, batch], dim=0) stacked = (stacked / 2 + 0.5).clamp(0, 1) self.vgg_19(stacked) + if self.use_critic: + critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach()) + else: + critic_d_loss = 0.0 + style_loss = self.get_style_loss() * self.style_weight content_loss = self.get_content_loss() * self.content_weight kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight + if self.use_critic: + critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight + else: + critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) - loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss # Backward pass and optimization optimizer.zero_grad() @@ -398,6 +598,10 @@ class TrainVAEProcess(BaseTrainProcess): loss_string += f" mse: {mse_loss.item():.2e}" if self.tv_weight > 0: loss_string += f" tv: {tv_loss.item():.2e}" + if self.use_critic and self.critic_weight > 0: + loss_string += f" crG: {critic_gen_loss.item():.2e}" + if self.use_critic: + loss_string += f" crD: {critic_d_loss:.2e}" if self.optimizer_type.startswith('dadaptation'): learning_rate = ( @@ -406,7 +610,13 @@ class TrainVAEProcess(BaseTrainProcess): ) else: learning_rate = optimizer.param_groups[0]['lr'] - self.progress_bar.set_postfix_str(f"LR: {learning_rate:.2e} {loss_string}") + + lr_critic_string = '' + if self.use_critic: + lr_critic = self.critic.get_lr() + lr_critic_string = f" lrC: {lr_critic:.1e}" + + self.progress_bar.set_postfix_str(f"lr: 
{learning_rate:.1e}{lr_critic_string} {loss_string}") self.progress_bar.set_description(f"E: {epoch}") self.progress_bar.update(1) @@ -416,6 +626,8 @@ class TrainVAEProcess(BaseTrainProcess): epoch_losses["mse"].append(mse_loss.item()) epoch_losses["kl"].append(kld_loss.item()) epoch_losses["tv"].append(tv_loss.item()) + epoch_losses["crG"].append(critic_gen_loss.item()) + epoch_losses["crD"].append(critic_d_loss) log_losses["total"].append(loss_value) log_losses["style"].append(style_loss.item()) @@ -423,30 +635,33 @@ class TrainVAEProcess(BaseTrainProcess): log_losses["mse"].append(mse_loss.item()) log_losses["kl"].append(kld_loss.item()) log_losses["tv"].append(tv_loss.item()) + log_losses["crG"].append(critic_gen_loss.item()) + log_losses["crD"].append(critic_d_loss) - if step != 0: - if self.sample_every and step % self.sample_every == 0: + # don't do on first step + if self.step_num != start_step: + if self.sample_every and self.step_num % self.sample_every == 0: # print above the progress bar - self.print(f"Sampling at step {step}") - self.sample(step) + self.print(f"Sampling at step {self.step_num}") + self.sample(self.step_num) - if self.save_every and step % self.save_every == 0: + if self.save_every and self.step_num % self.save_every == 0: # print above the progress bar - self.print(f"Saving at step {step}") - self.save(step) + self.print(f"Saving at step {self.step_num}") + self.save(self.step_num) - if self.log_every and step % self.log_every == 0: + if self.log_every and self.step_num % self.log_every == 0: # log to tensorboard if self.writer is not None: # get avg loss for key in log_losses: log_losses[key] = sum(log_losses[key]) / len(log_losses[key]) - if log_losses[key] > 0: - self.writer.add_scalar(f"loss/{key}", log_losses[key], step) + # if log_losses[key] > 0: + self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num) # reset log losses log_losses = copy.deepcopy(blank_losses) - step += 1 + self.step_num += 1 # end epoch if self.writer is not None: # get avg loss diff --git a/jobs/process/models/vgg19_critic.py b/jobs/process/models/vgg19_critic.py new file mode 100644 index 00000000..808d63b8 --- /dev/null +++ b/jobs/process/models/vgg19_critic.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn + + +class MeanReduce(nn.Module): + def __init__(self): + super(MeanReduce, self).__init__() + + def forward(self, inputs): + return torch.mean(inputs, dim=(1, 2, 3), keepdim=True) + + +class Vgg19Critic(nn.Module): + def __init__(self): + # vgg19 input (bs, 3, 512, 512) + # pool1 (bs, 64, 256, 256) + # pool2 (bs, 128, 128, 128) + # pool3 (bs, 256, 64, 64) + # pool4 (bs, 512, 32, 32) <- take this input + + super(Vgg19Critic, self).__init__() + self.main = nn.Sequential( + # input (bs, 512, 32, 32) + nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2), # (bs, 512, 16, 16) + nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2), # (bs, 512, 8, 8) + nn.Conv2d(512, 1, kernel_size=3, stride=2, padding=1), + # (bs, 1, 4, 4) + MeanReduce(), # (bs, 1, 1, 1) + nn.Flatten(), # (bs, 1) + + # nn.Flatten(), # (128*8*8) = 8192 + # nn.Linear(128 * 8 * 8, 1) + ) + + def forward(self, inputs): + return self.main(inputs) diff --git a/toolkit/config.py b/toolkit/config.py index d0d3f15a..9d51c3be 100644 --- a/toolkit/config.py +++ b/toolkit/config.py @@ -1,6 +1,7 @@ import os import json import oyaml as yaml +import re from collections import OrderedDict from toolkit.paths import TOOLKIT_ROOT @@ -29,6 +30,20 @@ def 
preprocess_config(config: OrderedDict): return config + +# Fixes issue where yaml doesnt load exponents correctly +fixed_loader = yaml.SafeLoader +fixed_loader.add_implicit_resolver( + u'tag:yaml.org,2002:float', + re.compile(u'''^(?: + [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? + |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) + |\\.[0-9_]+(?:[eE][-+][0-9]+)? + |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* + |[-+]?\\.(?:inf|Inf|INF) + |\\.(?:nan|NaN|NAN))$''', re.X), + list(u'-+0123456789.')) + def get_config(config_file_path): # first check if it is in the config folder config_path = os.path.join(TOOLKIT_ROOT, 'config', config_file_path) @@ -56,7 +71,7 @@ def get_config(config_file_path): config = json.load(f, object_pairs_hook=OrderedDict) elif real_config_path.endswith('.yaml') or real_config_path.endswith('.yml'): with open(real_config_path, 'r') as f: - config = yaml.load(f, Loader=yaml.FullLoader) + config = yaml.load(f, Loader=fixed_loader) else: raise ValueError(f"Config file {config_file_path} must be a json or yaml file") diff --git a/toolkit/losses.py b/toolkit/losses.py index 9c0ae097..9158c505 100644 --- a/toolkit/losses.py +++ b/toolkit/losses.py @@ -11,7 +11,7 @@ def total_variation(image): """ n_elements = image.shape[1] * image.shape[2] * image.shape[3] return ((torch.sum(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + - torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements) + torch.sum(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))) / n_elements) class ComparativeTotalVariation(torch.nn.Module): @@ -21,3 +21,27 @@ class ComparativeTotalVariation(torch.nn.Module): def forward(self, pred, target): return torch.abs(total_variation(pred) - total_variation(target)) + + +# Gradient penalty +def get_gradient_penalty(critic, real, fake, device): + with torch.autocast(device_type='cuda'): + alpha = torch.rand(real.size(0), 1, 1, 1).to(device) + interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True) + d_interpolates = critic(interpolates) + fake = torch.ones(real.size(0), 1, device=device) + + gradients = torch.autograd.grad( + outputs=d_interpolates, + inputs=interpolates, + grad_outputs=fake, + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + + gradients = gradients.view(gradients.size(0), -1) + gradient_norm = gradients.norm(2, dim=1) + gradient_penalty = ((gradient_norm - 1) ** 2).mean() + return gradient_penalty + diff --git a/toolkit/metadata.py b/toolkit/metadata.py index 0a99da70..e5b1ce9e 100644 --- a/toolkit/metadata.py +++ b/toolkit/metadata.py @@ -13,6 +13,16 @@ def get_meta_for_safetensors(meta: OrderedDict, name=None) -> OrderedDict: # safetensors can only be one level deep for key, value in save_meta.items(): # if not float, int, bool, or str, convert to json string - if not isinstance(value, (float, int, bool, str)): + if not isinstance(value, str): save_meta[key] = json.dumps(value) return save_meta + + +def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict: + parsed_meta = OrderedDict() + for key, value in meta.items(): + try: + parsed_meta[key] = json.loads(value) + except json.decoder.JSONDecodeError: + parsed_meta[key] = value + return meta diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py new file mode 100644 index 00000000..f58d90ff --- /dev/null +++ b/toolkit/optimizer.py @@ -0,0 +1,18 @@ +import torch + + +def get_optimizer( + params, + optimizer_type='adam', + learning_rate=1e-6 +): + if optimizer_type == 'dadaptation': + # dadaptation optimizer does not 
use standard learning rate. 1 is the default value + import dadaptation + print("Using DAdaptAdam optimizer") + optimizer = dadaptation.DAdaptAdam(params, lr=1.0) + elif optimizer_type == 'adam': + optimizer = torch.optim.Adam(params, lr=float(learning_rate)) + else: + raise ValueError(f'Unknown optimizer type {optimizer_type}') + return optimizer diff --git a/toolkit/style.py b/toolkit/style.py index 01fbec73..4282a230 100644 --- a/toolkit/style.py +++ b/toolkit/style.py @@ -21,6 +21,7 @@ class ContentLoss(nn.Module): self.loss = None def forward(self, stacked_input): + if self.single_target: split_size = stacked_input.size()[0] // 2 pred_layer, target_layer = torch.split(stacked_input, split_size, dim=0) @@ -73,6 +74,8 @@ class StyleLoss(nn.Module): self.device = device def forward(self, stacked_input): + input_dtype = stacked_input.dtype + stacked_input = stacked_input.float() if self.single_target: split_size = stacked_input.size()[0] // 2 preds, style_target = torch.split(stacked_input, split_size, dim=0) @@ -94,17 +97,18 @@ class StyleLoss(nn.Module): itemized_loss = torch.unsqueeze(itemized_loss, dim=1) # gram_size = (tf.shape(target_grams)[1] * tf.shape(target_grams)[2]) loss = torch.mean(itemized_loss, dim=(1, 2), keepdim=True) - self.loss = loss - return stacked_input + self.loss = loss.to(input_dtype) + return stacked_input.to(input_dtype) # create a module to normalize input image so we can easily put it in a # ``nn.Sequential`` class Normalization(nn.Module): - def __init__(self, device): + def __init__(self, device, dtype=torch.float32): super(Normalization, self).__init__() mean = torch.tensor([0.485, 0.456, 0.406]).to(device) std = torch.tensor([0.229, 0.224, 0.225]).to(device) + self.dtype = dtype # .view the mean and std to make them [C x 1 x 1] so that they can # directly work with image Tensor of shape [B x C x H x W]. # B is batch size. C is number of channels. H is height and W is width. 
@@ -112,9 +116,9 @@ class Normalization(nn.Module): self.std = torch.tensor(std).view(-1, 1, 1) def forward(self, stacked_input): - # cast to float 32 if not already - if stacked_input.dtype != torch.float32: - stacked_input = stacked_input.float() + # cast to float 32 if not already # only necessary when processing gram matrix + # if stacked_input.dtype != torch.float32: + # stacked_input = stacked_input.float() # remove alpha channel if it exists if stacked_input.shape[1] == 4: stacked_input = stacked_input[:, :3, :, :] @@ -123,21 +127,37 @@ class Normalization(nn.Module): in_max = torch.max(stacked_input) # norm_stacked_input = (stacked_input - in_min) / (in_max - in_min) # return (norm_stacked_input - self.mean) / self.std - return (stacked_input - self.mean) / self.std + return ((stacked_input - self.mean) / self.std).to(self.dtype) + + +class OutputLayer(nn.Module): + def __init__(self, name='output_layer'): + super(OutputLayer, self).__init__() + self.name = name + self.tensor = None + + def forward(self, stacked_input): + self.tensor = stacked_input + return stacked_input def get_style_model_and_losses( - single_target=False, + single_target=True, # false has 3 targets, dont remember why i added this initially, this is old code device='cuda' if torch.cuda.is_available() else 'cpu', output_layer_name=None, + dtype=torch.float32 ): # content_layers = ['conv_4'] # style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5'] content_layers = ['conv4_2'] style_layers = ['conv2_1', 'conv3_1', 'conv4_1'] - cnn = models.vgg19(pretrained=True).features.to(device).eval() + cnn = models.vgg19(pretrained=True).features.to(device, dtype=dtype).eval() + # set all weights in the model to our dtype + # for layer in cnn.children(): + # layer.to(dtype=dtype) + # normalization module - normalization = Normalization(device).to(device) + normalization = Normalization(device, dtype=dtype).to(device) # just in order to have an iterable access to or list of content/style # losses @@ -189,15 +209,15 @@ def get_style_model_and_losses( style_losses.append(style_loss) if output_layer_name is not None and name == output_layer_name: - output_layer = layer + output_layer = OutputLayer(name) + model.add_module("output_layer_{}_{}".format(block, i), output_layer) # now we trim off the layers after the last content and style losses for i in range(len(model) - 1, -1, -1): - if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss): - break - if output_layer_name is not None and model[i].name == output_layer_name: + if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss) or isinstance(model[i], OutputLayer): break model = model[:(i + 1)] + model.to(dtype=dtype) return model, style_losses, content_losses, output_layer
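
Note: the new critic is driven entirely by config. Below is a hypothetical excerpt, in dict form, of the keys this patch reads; the key names come from TrainVAEProcess.__init__ and Critic.__init__, and the values shown simply mirror the defaults in the code rather than a recommendation.

# Hypothetical config excerpt; key names come from this patch, values mirror
# the defaults in Critic.__init__ / TrainVAEProcess.__init__.
process_config = {
    "use_critic": True,        # enables the VGG19-feature critic
    "critic_weight": 1.0,      # weight of the generator-side critic loss
    "critic": {                # forwarded to Critic(**...) by TrainVAEProcess
        "learning_rate": 1e-5,
        "optimizer": "adam",   # or "dadaptation" (see toolkit/optimizer.py)
        "num_critic_per_gen": 1,
        "start_step": 0,       # step at which the critic loss starts to apply
        "warmup_steps": 1000,  # linear ramp of the critic loss after start_step
        "lambda_gp": 10,       # gradient-penalty weight
    },
}

# The reworked BaseProcess.get_conf also resolves dotted keys, e.g.:
#   self.get_conf("critic.learning_rate", 1e-5, as_type=float)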
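
For readers unfamiliar with the objective the Critic class implements: it is a WGAN-GP critic run on the VGG19 pool_4 activations of the reconstruction and the target, which the training loop stacks along the batch dimension. What follows is a minimal, self-contained sketch of the two halves of that update; the toy critic is only a stand-in for Vgg19Critic, and while the shapes and lambda_gp mirror the patch, the code is illustrative rather than the patch itself.

import torch
import torch.nn as nn


def gradient_penalty(critic, real, fake, device):
    # WGAN-GP penalty: push the critic's gradient norm toward 1 on points
    # interpolated between real and fake feature maps.
    alpha = torch.rand(real.size(0), 1, 1, 1, device=device)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_interpolates = critic(interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()


device = "cpu"
lambda_gp = 10

# Toy critic standing in for Vgg19Critic: (bs, 512, 32, 32) -> (bs, 1)
critic = nn.Sequential(
    nn.Conv2d(512, 1, kernel_size=3, stride=2, padding=1),  # (bs, 1, 16, 16)
    nn.Flatten(),
    nn.Linear(16 * 16, 1),
).to(device)
critic_opt = torch.optim.Adam(critic.parameters(), lr=1e-5)

# pool_4 features of the target image and of the VAE reconstruction
real = torch.randn(2, 512, 32, 32, device=device)
fake = torch.randn(2, 512, 32, 32, device=device)

# Critic step: maximize D(real) - D(fake), regularized by the gradient penalty.
critic_opt.zero_grad()
d_loss = -(critic(real).mean() - critic(fake).mean()) \
         + lambda_gp * gradient_penalty(critic, real, fake, device)
d_loss.backward()
critic_opt.step()

# Generator-side term added to the VAE loss (Critic.get_critic_loss applies a
# warmup scaler and critic_weight on top of this): -D(fake)
g_loss = -critic(fake).mean()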