Added this not that guidance. Added ability to replace prompts.

This commit is contained in:
Jaret Burkett
2024-02-28 20:10:14 -07:00
parent 561914d8e6
commit 337945de9a
7 changed files with 114 additions and 5 deletions

View File

@@ -229,9 +229,13 @@ class SDTrainer(BaseSDTrainProcess):
prior_mask_multiplier = None
target_mask_multiplier = None
dtype = get_torch_dtype(self.train_config.dtype)
has_mask = batch.mask_tensor is not None
with torch.no_grad():
loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32)
if self.train_config.match_noise_norm:
# match the norm of the noise
noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
@@ -364,6 +368,8 @@ class SDTrainer(BaseSDTrainProcess):
# loss = loss + prior_loss
# loss = loss + prior_loss
loss = loss.mean([1, 2, 3])
# apply loss multiplier before prior loss
loss = loss * loss_multiplier
if prior_loss is not None:
loss = loss + prior_loss