Added flux training. Still a WIP. Wont train right without rectified flow working right

This commit is contained in:
Jaret Burkett
2024-08-02 15:00:30 -06:00
parent 03613c523f
commit 87ba867fdc
6 changed files with 292 additions and 15 deletions

View File

@@ -1234,7 +1234,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
torch.backends.cuda.enable_mem_efficient_sdp(True)
if self.train_config.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if self.sd.is_flux:
unet.gradient_checkpointing = True
else:
unet.enable_gradient_checkpointing()
if isinstance(text_encoder, list):
for te in text_encoder:
if hasattr(te, 'enable_gradient_checkpointing'):
@@ -1325,6 +1328,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
is_v3=self.model_config.is_v3,
is_pixart=self.model_config.is_pixart,
is_auraflow=self.model_config.is_auraflow,
is_flux=self.model_config.is_flux,
is_ssd=self.model_config.is_ssd,
is_vega=self.model_config.is_vega,
dropout=self.network_config.dropout,