mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added this not that guidance. Added ability to replace prompts.
This commit is contained in:
@@ -229,9 +229,13 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
prior_mask_multiplier = None
|
||||
target_mask_multiplier = None
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
has_mask = batch.mask_tensor is not None
|
||||
|
||||
with torch.no_grad():
|
||||
loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32)
|
||||
|
||||
if self.train_config.match_noise_norm:
|
||||
# match the norm of the noise
|
||||
noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
|
||||
@@ -364,6 +368,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# loss = loss + prior_loss
|
||||
# loss = loss + prior_loss
|
||||
loss = loss.mean([1, 2, 3])
|
||||
# apply loss multiplier before prior loss
|
||||
loss = loss * loss_multiplier
|
||||
if prior_loss is not None:
|
||||
loss = loss + prior_loss
|
||||
|
||||
|
||||
Reference in New Issue
Block a user