mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Added flux training. Still a WIP. Wont train right without rectified flow working right
This commit is contained in:
@@ -1234,7 +1234,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
||||
|
||||
if self.train_config.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
if self.sd.is_flux:
|
||||
unet.gradient_checkpointing = True
|
||||
else:
|
||||
unet.enable_gradient_checkpointing()
|
||||
if isinstance(text_encoder, list):
|
||||
for te in text_encoder:
|
||||
if hasattr(te, 'enable_gradient_checkpointing'):
|
||||
@@ -1325,6 +1328,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
is_v3=self.model_config.is_v3,
|
||||
is_pixart=self.model_config.is_pixart,
|
||||
is_auraflow=self.model_config.is_auraflow,
|
||||
is_flux=self.model_config.is_flux,
|
||||
is_ssd=self.model_config.is_ssd,
|
||||
is_vega=self.model_config.is_vega,
|
||||
dropout=self.network_config.dropout,
|
||||
|
||||
Reference in New Issue
Block a user