More guidance work. Improved LoRA module resolver for unet. Added vega mappings and LoRA training for it. Various other bigfixes and changes

This commit is contained in:
Jaret Burkett
2023-12-15 06:02:10 -07:00
parent e5177833b2
commit 39870411d8
14 changed files with 3501 additions and 106 deletions

View File

@@ -913,7 +913,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.load_training_state_from_metadata(latest_save_path)
# get the noise scheduler
sampler = get_sampler(self.train_config.noise_scheduler)
sampler = get_sampler(
self.train_config.noise_scheduler,
{
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
}
)
if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner')
@@ -1051,6 +1056,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
is_v2=self.model_config.is_v2,
is_ssd=self.model_config.is_ssd,
is_vega=self.model_config.is_vega,
dropout=self.network_config.dropout,
use_text_encoder_1=self.model_config.use_text_encoder_1,
use_text_encoder_2=self.model_config.use_text_encoder_2,