Actually adding the Process class.....

Jaret Burkett
2023-07-20 15:46:23 -06:00
parent 0761656a90
commit aa13251877


@@ -15,7 +15,7 @@ from torchvision.transforms import transforms
 from jobs.process import BaseTrainProcess
 from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
 from toolkit.data_loader import ImageDataset
-from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty
+from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
@@ -196,6 +196,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
         self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
         self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
+        self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
         self.first_step = 0

         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
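
The new `pattern_weight` key is read via `get_conf()` just like the existing loss weights, so it can be set per job and defaults to 1.0. A minimal sketch of the relevant weight knobs in a process config; the dict-literal form and the variable name `vae_train_config` are illustrative only, not the repo's actual config format:

# Illustrative only: the loss-weight keys this process reads via get_conf().
# 'pattern_weight' is the key added in this commit; 1.0 matches its default above.
vae_train_config = {
    'mse_weight': 1.0,
    'tv_weight': 1.0,
    'critic_weight': 1.0,
    'pattern_weight': 1.0,  # set to 0 to zero out the PatternLoss term's contribution
}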
@@ -241,6 +242,8 @@ class TrainVAEProcess(BaseTrainProcess):
         if not os.path.exists(self.save_root):
             os.makedirs(self.save_root, exist_ok=True)

+        self._pattern_loss = None

     def update_training_metadata(self):
         self.add_meta(OrderedDict({"training_info": self.get_training_info()}))
@@ -349,6 +352,12 @@ class TrainVAEProcess(BaseTrainProcess):
         else:
             return torch.tensor(0.0, device=self.device)

+    def get_pattern_loss(self, pred, target):
+        if self._pattern_loss is None:
+            self._pattern_loss = PatternLoss(pattern_size=8, dtype=self.torch_dtype).to(self.device, dtype=self.torch_dtype)
+        loss = torch.mean(self._pattern_loss(pred, target))
+        return loss
+
     def save(self, step=None):
         if not os.path.exists(self.save_root):
             os.makedirs(self.save_root, exist_ok=True)
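
`PatternLoss` itself lives in `toolkit/losses.py` and is not part of this diff; only its call signature is visible here (constructed with `pattern_size=8` and a dtype, called on `pred`/`target`, then mean-reduced in `get_pattern_loss`). A minimal stand-in consistent with that usage, assuming it penalizes differences at the pattern (block) scale, might look like:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatternLossSketch(nn.Module):
    """Hypothetical stand-in for toolkit.losses.PatternLoss (real code not shown in this diff).

    Mirrors the usage above: PatternLoss(pattern_size=8, dtype=...)(pred, target)
    returns an unreduced tensor that get_pattern_loss() averages with torch.mean().
    """

    def __init__(self, pattern_size: int = 8, dtype=torch.float32):
        super().__init__()
        self.pattern_size = pattern_size
        self.dtype = dtype

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Compare pattern_size x pattern_size block averages of the two image batches,
        # penalizing grid/tiling differences at that scale (assumed behavior).
        pred_blocks = F.avg_pool2d(pred.to(self.dtype), self.pattern_size)
        target_blocks = F.avg_pool2d(target.to(self.dtype), self.pattern_size)
        return (pred_blocks - target_blocks) ** 2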
@@ -534,6 +543,7 @@ class TrainVAEProcess(BaseTrainProcess):
"mse": [],
"kl": [],
"tv": [],
"ptn": [],
"crD": [],
"crG": [],
})
@@ -573,12 +583,13 @@ class TrainVAEProcess(BaseTrainProcess):
                 kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
                 mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
                 tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
+                pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight
                 if self.use_critic:
                     critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
                 else:
                     critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

-                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss
+                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss

                 # Backward pass and optimization
                 optimizer.zero_grad()
@@ -600,6 +611,8 @@ class TrainVAEProcess(BaseTrainProcess):
                     loss_string += f" mse: {mse_loss.item():.2e}"
                 if self.tv_weight > 0:
                     loss_string += f" tv: {tv_loss.item():.2e}"
+                if self.pattern_weight > 0:
+                    loss_string += f" ptn: {pattern_loss.item():.2e}"
                 if self.use_critic and self.critic_weight > 0:
                     loss_string += f" crG: {critic_gen_loss.item():.2e}"
                 if self.use_critic:
@@ -628,6 +641,7 @@ class TrainVAEProcess(BaseTrainProcess):
epoch_losses["mse"].append(mse_loss.item())
epoch_losses["kl"].append(kld_loss.item())
epoch_losses["tv"].append(tv_loss.item())
epoch_losses["ptn"].append(pattern_loss.item())
epoch_losses["crG"].append(critic_gen_loss.item())
epoch_losses["crD"].append(critic_d_loss)
@@ -637,6 +651,7 @@ class TrainVAEProcess(BaseTrainProcess):
log_losses["mse"].append(mse_loss.item())
log_losses["kl"].append(kld_loss.item())
log_losses["tv"].append(tv_loss.item())
log_losses["ptn"].append(pattern_loss.item())
log_losses["crG"].append(critic_gen_loss.item())
log_losses["crD"].append(critic_d_loss)