mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 02:31:17 +00:00
Added training for pixart-a
This commit is contained in:
@@ -225,6 +225,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
|
||||
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
|
||||
|
||||
if self.train_config.pred_scaler != 1.0:
|
||||
noise_pred = noise_pred * self.train_config.pred_scaler
|
||||
|
||||
target = None
|
||||
if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask):
|
||||
if self.train_config.correct_pred_norm and not is_reg:
|
||||
@@ -343,7 +346,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
print("Prior loss is nan")
|
||||
prior_loss = None
|
||||
else:
|
||||
prior_loss = prior_loss.mean([1, 2, 3])
|
||||
# prior_loss = prior_loss.mean([1, 2, 3])
|
||||
loss = loss + prior_loss
|
||||
# loss = loss + prior_loss
|
||||
loss = loss.mean([1, 2, 3])
|
||||
if prior_loss is not None:
|
||||
|
||||
Reference in New Issue
Block a user