From aa132518778bce4f319daa4d43b210dfb6bc6c8b Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Thu, 20 Jul 2023 15:46:23 -0600 Subject: [PATCH] Actually addind the Process class..... --- jobs/process/TrainVAEProcess.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py index 90efff03..f2e0a518 100644 --- a/jobs/process/TrainVAEProcess.py +++ b/jobs/process/TrainVAEProcess.py @@ -15,7 +15,7 @@ from torchvision.transforms import transforms from jobs.process import BaseTrainProcess from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm from toolkit.data_loader import ImageDataset -from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty +from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss from toolkit.metadata import get_meta_for_safetensors from toolkit.optimizer import get_optimizer from toolkit.style import get_style_model_and_losses @@ -196,6 +196,7 @@ class TrainVAEProcess(BaseTrainProcess): self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float) self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float) self.critic_weight = self.get_conf('critic_weight', 1, as_type=float) + self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float) self.first_step = 0 self.blocks_to_train = self.get_conf('blocks_to_train', ['all']) @@ -241,6 +242,8 @@ class TrainVAEProcess(BaseTrainProcess): if not os.path.exists(self.save_root): os.makedirs(self.save_root, exist_ok=True) + self._pattern_loss = None + def update_training_metadata(self): self.add_meta(OrderedDict({"training_info": self.get_training_info()})) @@ -349,6 +352,12 @@ class TrainVAEProcess(BaseTrainProcess): else: return torch.tensor(0.0, device=self.device) + def get_pattern_loss(self, pred, target): + if self._pattern_loss is None: + self._pattern_loss = PatternLoss(pattern_size=8, dtype=self.torch_dtype).to(self.device, dtype=self.torch_dtype) + loss = torch.mean(self._pattern_loss(pred, target)) + return loss + def save(self, step=None): if not os.path.exists(self.save_root): os.makedirs(self.save_root, exist_ok=True) @@ -534,6 +543,7 @@ class TrainVAEProcess(BaseTrainProcess): "mse": [], "kl": [], "tv": [], + "ptn": [], "crD": [], "crG": [], }) @@ -573,12 +583,13 @@ class TrainVAEProcess(BaseTrainProcess): kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight + pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight if self.use_critic: critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight else: critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) - loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss # Backward pass and optimization optimizer.zero_grad() @@ -600,6 +611,8 @@ class TrainVAEProcess(BaseTrainProcess): loss_string += f" mse: {mse_loss.item():.2e}" if self.tv_weight > 0: loss_string += f" tv: {tv_loss.item():.2e}" + if self.pattern_weight > 0: + loss_string += f" ptn: {pattern_loss.item():.2e}" if self.use_critic and self.critic_weight > 0: loss_string += f" crG: {critic_gen_loss.item():.2e}" if self.use_critic: @@ -628,6 +641,7 @@ class TrainVAEProcess(BaseTrainProcess): epoch_losses["mse"].append(mse_loss.item()) epoch_losses["kl"].append(kld_loss.item()) epoch_losses["tv"].append(tv_loss.item()) + epoch_losses["ptn"].append(pattern_loss.item()) epoch_losses["crG"].append(critic_gen_loss.item()) epoch_losses["crD"].append(critic_d_loss) @@ -637,6 +651,7 @@ class TrainVAEProcess(BaseTrainProcess): log_losses["mse"].append(mse_loss.item()) log_losses["kl"].append(kld_loss.item()) log_losses["tv"].append(tv_loss.item()) + log_losses["ptn"].append(pattern_loss.item()) log_losses["crG"].append(critic_gen_loss.item()) log_losses["crD"].append(critic_d_loss)