Added flux training. Still a WIP. Wont train right without rectified flow working right

This commit is contained in:
Jaret Burkett
2024-08-02 15:00:30 -06:00
parent 03613c523f
commit 87ba867fdc
6 changed files with 292 additions and 15 deletions

View File

@@ -159,6 +159,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
is_v3=False,
is_pixart: bool = False,
is_auraflow: bool = False,
is_flux: bool = False,
use_bias: bool = False,
is_lorm: bool = False,
ignore_if_contains = None,
@@ -216,6 +217,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.is_v3 = is_v3
self.is_pixart = is_pixart
self.is_auraflow = is_auraflow
self.is_flux = is_flux
self.network_type = network_type
if self.network_type.lower() == "dora":
self.module_class = DoRAModule
@@ -250,7 +252,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
target_replace_modules: List[torch.nn.Module],
) -> List[LoRAModule]:
unet_prefix = self.LORA_PREFIX_UNET
if is_pixart or is_v3 or is_auraflow:
if is_pixart or is_v3 or is_auraflow or is_flux:
unet_prefix = f"lora_transformer"
prefix = (
@@ -293,6 +295,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if self.transformer_only and self.is_pixart and is_unet:
if "transformer_blocks" not in lora_name:
skip = True
if self.transformer_only and self.is_flux and is_unet:
if "transformer_blocks" not in lora_name:
skip = True
if (is_linear or is_conv2d) and not skip:
@@ -393,6 +398,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if is_auraflow:
target_modules = ["AuraFlowTransformer2DModel"]
if is_flux:
target_modules = ["FluxTransformer2DModel"]
if train_unet:
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
else:
@@ -454,7 +462,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
if self.full_train_in_out:
if self.is_pixart or self.is_auraflow:
if self.is_pixart or self.is_auraflow or self.is_flux:
all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
else: