Mirror of https://github.com/ostris/ai-toolkit.git
Added flux training. Still a WIP. Won't train right until rectified flow is working correctly.
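For context on the "rectified flow" caveat in the commit message: FLUX is trained as a rectified-flow (flow-matching) model, so its loss interpolates linearly between data and noise and regresses the velocity, rather than using the DDPM epsilon objective. A minimal sketch of that loss follows; it is not code from this repository, and the function, tensor, and argument names are illustrative.

import torch
import torch.nn.functional as F

def rectified_flow_loss(model, x0: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
    # Sample one timestep t in [0, 1] per example, plus Gaussian noise.
    t = torch.rand(x0.shape[0], device=x0.device).view(-1, 1, 1, 1)
    noise = torch.randn_like(x0)
    # Rectified flow interpolates linearly between data and noise...
    x_t = (1.0 - t) * x0 + t * noise
    # ...and regresses the constant velocity along that straight path.
    target = noise - x0
    pred = model(x_t, t.flatten(), cond)
    return F.mse_loss(pred, target)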
@@ -159,6 +159,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             is_v3=False,
             is_pixart: bool = False,
             is_auraflow: bool = False,
+            is_flux: bool = False,
             use_bias: bool = False,
             is_lorm: bool = False,
             ignore_if_contains = None,
@@ -216,6 +217,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.is_v3 = is_v3
         self.is_pixart = is_pixart
         self.is_auraflow = is_auraflow
+        self.is_flux = is_flux
         self.network_type = network_type
         if self.network_type.lower() == "dora":
             self.module_class = DoRAModule
@@ -250,7 +252,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                 target_replace_modules: List[torch.nn.Module],
         ) -> List[LoRAModule]:
             unet_prefix = self.LORA_PREFIX_UNET
-            if is_pixart or is_v3 or is_auraflow:
+            if is_pixart or is_v3 or is_auraflow or is_flux:
                 unet_prefix = f"lora_transformer"
 
             prefix = (
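For reference, this unet prefix feeds the usual kohya-style key naming, where the module path is appended and dots are replaced by underscores. A small sketch under that assumption; the module path is illustrative, not an excerpt from create_modules.

prefix = "lora_transformer"  # used when is_pixart / is_v3 / is_auraflow / is_flux
module_path = "transformer_blocks.0.attn.to_q"  # example module inside FluxTransformer2DModel
lora_name = (prefix + "." + module_path).replace(".", "_")
# -> "lora_transformer_transformer_blocks_0_attn_to_q"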
@@ -293,6 +295,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                     if self.transformer_only and self.is_pixart and is_unet:
                         if "transformer_blocks" not in lora_name:
                             skip = True
+                    if self.transformer_only and self.is_flux and is_unet:
+                        if "transformer_blocks" not in lora_name:
+                            skip = True
 
                     if (is_linear or is_conv2d) and not skip:
 
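To illustrate what the transformer_only filter keeps on a flux model, here is an isolated sketch; the candidate names are hypothetical examples, not output from this code.

# With transformer_only=True and is_flux=True, only names containing
# "transformer_blocks" survive the filter above; embedders and the final
# projection outside the blocks are skipped.
candidates = [
    "lora_transformer_x_embedder",                             # skipped
    "lora_transformer_transformer_blocks_0_attn_to_q",         # kept
    "lora_transformer_single_transformer_blocks_5_proj_mlp",   # kept
    "lora_transformer_proj_out",                               # skipped
]
kept = [name for name in candidates if "transformer_blocks" in name]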
@@ -393,6 +398,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         if is_auraflow:
             target_modules = ["AuraFlowTransformer2DModel"]
 
+        if is_flux:
+            target_modules = ["FluxTransformer2DModel"]
+
         if train_unet:
             self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
         else:
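The target list holds class names, which in the kohya-style pattern are matched against module.__class__.__name__ while walking the model. A self-contained sketch under that assumption; this is not the repository's create_modules, and find_lora_candidates is a hypothetical helper.

import torch

def find_lora_candidates(transformer: torch.nn.Module,
                         target_modules=("FluxTransformer2DModel",)):
    # Matching is done on class names, so diffusers' FluxTransformer2DModel is
    # found without importing it here; Linear/Conv2d children become candidates.
    candidates = []
    for name, module in transformer.named_modules():
        if module.__class__.__name__ in target_modules:
            for child_name, child in module.named_modules():
                if isinstance(child, (torch.nn.Linear, torch.nn.Conv2d)):
                    candidates.append(f"{name}.{child_name}".strip("."))
    return candidates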
@@ -454,7 +462,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
 
         if self.full_train_in_out:
-            if self.is_pixart or self.is_auraflow:
+            if self.is_pixart or self.is_auraflow or self.is_flux:
                 all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
                 all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
             else:
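The dicts appended under full_train_in_out are ordinary torch.optim parameter groups, so the resulting list can be handed straight to an optimizer. A self-contained usage sketch; the Linear modules are stand-ins for transformer_pos_embed and transformer_proj_out.

import torch

pos_embed = torch.nn.Linear(64, 64)  # stand-in for transformer_pos_embed
proj_out = torch.nn.Linear(64, 64)   # stand-in for transformer_proj_out
unet_lr, default_lr = 1e-4, 1e-5

all_params = []  # in the real code this starts from prepare_optimizer_params()
all_params.append({"lr": unet_lr, "params": list(pos_embed.parameters())})
all_params.append({"lr": unet_lr, "params": list(proj_out.parameters())})

# Each dict is a torch.optim parameter group, so the list is accepted directly.
optimizer = torch.optim.AdamW(all_params, lr=default_lr)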