mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Actually adding the Process class.....
This commit is contained in:
@@ -15,7 +15,7 @@ from torchvision.transforms import transforms
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
|
||||
from toolkit.data_loader import ImageDataset
|
||||
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty
|
||||
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.optimizer import get_optimizer
|
||||
from toolkit.style import get_style_model_and_losses
|
||||
@@ -196,6 +196,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
|
||||
self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
|
||||
self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
|
||||
self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
|
||||
self.first_step = 0
|
||||
|
||||
self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
|
||||
@@ -241,6 +242,8 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
if not os.path.exists(self.save_root):
|
||||
os.makedirs(self.save_root, exist_ok=True)
|
||||
|
||||
self._pattern_loss = None
|
||||
|
||||
def update_training_metadata(self):
|
||||
self.add_meta(OrderedDict({"training_info": self.get_training_info()}))
|
||||
|
||||
@@ -349,6 +352,12 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
else:
|
||||
return torch.tensor(0.0, device=self.device)
|
||||
|
||||
def get_pattern_loss(self, pred, target):
|
||||
if self._pattern_loss is None:
|
||||
self._pattern_loss = PatternLoss(pattern_size=8, dtype=self.torch_dtype).to(self.device, dtype=self.torch_dtype)
|
||||
loss = torch.mean(self._pattern_loss(pred, target))
|
||||
return loss
|
||||
|
||||
def save(self, step=None):
|
||||
if not os.path.exists(self.save_root):
|
||||
os.makedirs(self.save_root, exist_ok=True)
|
||||
@@ -534,6 +543,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
"mse": [],
|
||||
"kl": [],
|
||||
"tv": [],
|
||||
"ptn": [],
|
||||
"crD": [],
|
||||
"crG": [],
|
||||
})
|
||||
@@ -573,12 +583,13 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
|
||||
mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
|
||||
tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
|
||||
pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight
|
||||
if self.use_critic:
|
||||
critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
|
||||
else:
|
||||
critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss
|
||||
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
@@ -600,6 +611,8 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
loss_string += f" mse: {mse_loss.item():.2e}"
|
||||
if self.tv_weight > 0:
|
||||
loss_string += f" tv: {tv_loss.item():.2e}"
|
||||
if self.pattern_weight > 0:
|
||||
loss_string += f" ptn: {pattern_loss.item():.2e}"
|
||||
if self.use_critic and self.critic_weight > 0:
|
||||
loss_string += f" crG: {critic_gen_loss.item():.2e}"
|
||||
if self.use_critic:
|
||||
@@ -628,6 +641,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
epoch_losses["mse"].append(mse_loss.item())
|
||||
epoch_losses["kl"].append(kld_loss.item())
|
||||
epoch_losses["tv"].append(tv_loss.item())
|
||||
epoch_losses["ptn"].append(pattern_loss.item())
|
||||
epoch_losses["crG"].append(critic_gen_loss.item())
|
||||
epoch_losses["crD"].append(critic_d_loss)
|
||||
|
||||
@@ -637,6 +651,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
log_losses["mse"].append(mse_loss.item())
|
||||
log_losses["kl"].append(kld_loss.item())
|
||||
log_losses["tv"].append(tv_loss.item())
|
||||
log_losses["ptn"].append(pattern_loss.item())
|
||||
log_losses["crG"].append(critic_gen_loss.item())
|
||||
log_losses["crD"].append(critic_d_loss)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user