Added training for pixart-a

This commit is contained in:
Jaret Burkett
2024-02-13 16:00:04 -07:00
parent 4ec4025cbb
commit 93b52932c1
10 changed files with 288 additions and 24 deletions

View File

@@ -1054,7 +1054,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.train_config.noise_scheduler,
{
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
}
},
'sd' if not self.model_config.is_pixart else 'pixart'
)
if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None: