mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Imitial lumina3 support
This commit is contained in:
@@ -335,6 +335,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
o_dict['ss_base_model_version'] = 'sdxl_1.0'
|
||||
elif self.model_config.is_flux:
|
||||
o_dict['ss_base_model_version'] = 'flux.1'
|
||||
elif self.model_config.is_lumina2:
|
||||
o_dict['ss_base_model_version'] = 'lumina2'
|
||||
else:
|
||||
o_dict['ss_base_model_version'] = 'sd_1.5'
|
||||
|
||||
@@ -1387,12 +1389,19 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.load_training_state_from_metadata(latest_save_path)
|
||||
|
||||
# get the noise scheduler
|
||||
arch = 'sd'
|
||||
if self.model_config.is_pixart:
|
||||
arch = 'pixart'
|
||||
if self.model_config.is_flux:
|
||||
arch = 'flux'
|
||||
if self.model_config.is_lumina2:
|
||||
arch = 'lumina2'
|
||||
sampler = get_sampler(
|
||||
self.train_config.noise_scheduler,
|
||||
{
|
||||
"prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
|
||||
},
|
||||
'sd' if not self.model_config.is_pixart else 'pixart'
|
||||
arch=arch,
|
||||
)
|
||||
|
||||
if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
|
||||
@@ -1452,10 +1461,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# print_acc("sage attention is not installed. Using SDP instead")
|
||||
|
||||
if self.train_config.gradient_checkpointing:
|
||||
if self.sd.is_flux:
|
||||
# if has method enable_gradient_checkpointing
|
||||
if hasattr(unet, 'enable_gradient_checkpointing'):
|
||||
unet.enable_gradient_checkpointing()
|
||||
elif hasattr(unet, 'gradient_checkpointing'):
|
||||
unet.gradient_checkpointing = True
|
||||
else:
|
||||
unet.enable_gradient_checkpointing()
|
||||
print("Gradient checkpointing not supported on this model")
|
||||
if isinstance(text_encoder, list):
|
||||
for te in text_encoder:
|
||||
if hasattr(te, 'enable_gradient_checkpointing'):
|
||||
@@ -1547,6 +1559,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
is_pixart=self.model_config.is_pixart,
|
||||
is_auraflow=self.model_config.is_auraflow,
|
||||
is_flux=self.model_config.is_flux,
|
||||
is_lumina2=self.model_config.is_lumina2,
|
||||
is_ssd=self.model_config.is_ssd,
|
||||
is_vega=self.model_config.is_vega,
|
||||
dropout=self.network_config.dropout,
|
||||
@@ -2165,6 +2178,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
tags.append("stable-diffusion-xl")
|
||||
if self.model_config.is_flux:
|
||||
tags.append("flux")
|
||||
if self.model_config.is_lumina2:
|
||||
tags.append("lumina2")
|
||||
if self.model_config.is_v3:
|
||||
tags.append("sd3")
|
||||
if self.network_config:
|
||||
|
||||
Reference in New Issue
Block a user