Partial implementation for training AuraFlow.

This commit is contained in:
Jaret Burkett
2024-07-12 12:11:38 -06:00
parent c062b7716c
commit e4558dff4b
9 changed files with 386 additions and 19 deletions

View File

@@ -7,7 +7,7 @@ import re
import sys
from typing import List, Optional, Dict, Type, Union
import torch
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel
from transformers import CLIPTextModel
from .config_modules import NetworkConfig
@@ -158,6 +158,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
is_v2=False,
is_v3=False,
is_pixart: bool = False,
is_auraflow: bool = False,
use_bias: bool = False,
is_lorm: bool = False,
ignore_if_contains = None,
@@ -212,6 +213,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.is_v2 = is_v2
self.is_v3 = is_v3
self.is_pixart = is_pixart
self.is_auraflow = is_auraflow
self.network_type = network_type
if self.network_type.lower() == "dora":
self.module_class = DoRAModule
@@ -246,7 +248,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
target_replace_modules: List[torch.nn.Module],
) -> List[LoRAModule]:
unet_prefix = self.LORA_PREFIX_UNET
if is_pixart or is_v3:
if is_pixart or is_v3 or is_auraflow:
unet_prefix = f"lora_transformer"
prefix = (
@@ -371,6 +373,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if is_pixart:
target_modules = ["PixArtTransformer2DModel"]
if is_auraflow:
target_modules = ["AuraFlowTransformer2DModel"]
if train_unet:
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
else:
@@ -408,6 +413,14 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
transformer.pos_embed = self.transformer_pos_embed
transformer.proj_out = self.transformer_proj_out
elif self.is_auraflow:
transformer: AuraFlowTransformer2DModel = unet
self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
transformer.pos_embed = self.transformer_pos_embed
transformer.proj_out = self.transformer_proj_out
else:
unet: UNet2DConditionModel = unet
unet_conv_in: torch.nn.Conv2d = unet.conv_in
@@ -424,7 +437,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
if self.full_train_in_out:
if self.is_pixart:
if self.is_pixart or self.is_auraflow:
all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
else: