merged in lumina2 branch

This commit is contained in:
Jaret Burkett
2025-02-12 09:33:03 -07:00
11 changed files with 986 additions and 17 deletions

View File

@@ -335,6 +335,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
o_dict['ss_base_model_version'] = 'sdxl_1.0'
elif self.model_config.is_flux:
o_dict['ss_base_model_version'] = 'flux.1'
elif self.model_config.is_lumina2:
o_dict['ss_base_model_version'] = 'lumina2'
else:
o_dict['ss_base_model_version'] = 'sd_1.5'
@@ -1392,12 +1394,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
arch = 'pixart'
if self.model_config.is_flux:
arch = 'flux'
if self.model_config.is_lumina2:
arch = 'lumina2'
sampler = get_sampler(
self.train_config.noise_scheduler,
{
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
},
arch
arch=arch,
)
if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
@@ -1457,10 +1461,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
# print_acc("sage attention is not installed. Using SDP instead")
if self.train_config.gradient_checkpointing:
<<<<<<< HEAD
=======
# if has method enable_gradient_checkpointing
>>>>>>> lumina2
if hasattr(unet, 'enable_gradient_checkpointing'):
unet.enable_gradient_checkpointing()
elif hasattr(unet, 'gradient_checkpointing'):
unet.gradient_checkpointing = True
<<<<<<< HEAD
=======
else:
print("Gradient checkpointing not supported on this model")
>>>>>>> lumina2
if isinstance(text_encoder, list):
for te in text_encoder:
if hasattr(te, 'enable_gradient_checkpointing'):
@@ -1552,6 +1565,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
is_pixart=self.model_config.is_pixart,
is_auraflow=self.model_config.is_auraflow,
is_flux=self.model_config.is_flux,
is_lumina2=self.model_config.is_lumina2,
is_ssd=self.model_config.is_ssd,
is_vega=self.model_config.is_vega,
dropout=self.network_config.dropout,
@@ -2170,6 +2184,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
tags.append("stable-diffusion-xl")
if self.model_config.is_flux:
tags.append("flux")
if self.model_config.is_lumina2:
tags.append("lumina2")
if self.model_config.is_v3:
tags.append("sd3")
if self.network_config: