Added training for pixart-a

This commit is contained in:
Jaret Burkett
2024-02-13 16:00:04 -07:00
parent 4ec4025cbb
commit 93b52932c1
10 changed files with 288 additions and 24 deletions

View File

@@ -225,6 +225,9 @@ class SDTrainer(BaseSDTrainProcess):
noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
noise_pred = noise_pred * (noise_norm / noise_pred_norm)
if self.train_config.pred_scaler != 1.0:
noise_pred = noise_pred * self.train_config.pred_scaler
target = None
if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask):
if self.train_config.correct_pred_norm and not is_reg:
@@ -343,7 +346,8 @@ class SDTrainer(BaseSDTrainProcess):
print("Prior loss is nan")
prior_loss = None
else:
prior_loss = prior_loss.mean([1, 2, 3])
# prior_loss = prior_loss.mean([1, 2, 3])
loss = loss + prior_loss
# loss = loss + prior_loss
loss = loss.mean([1, 2, 3])
if prior_loss is not None: