mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 10:41:28 +00:00
Partial implementation for training auraflow.
This commit is contained in:
@@ -7,7 +7,7 @@ import re
|
||||
import sys
|
||||
from typing import List, Optional, Dict, Type, Union
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel
|
||||
from diffusers import UNet2DConditionModel, PixArtTransformer2DModel, AuraFlowTransformer2DModel
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
from .config_modules import NetworkConfig
|
||||
@@ -158,6 +158,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
is_v2=False,
|
||||
is_v3=False,
|
||||
is_pixart: bool = False,
|
||||
is_auraflow: bool = False,
|
||||
use_bias: bool = False,
|
||||
is_lorm: bool = False,
|
||||
ignore_if_contains = None,
|
||||
@@ -212,6 +213,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
self.is_v2 = is_v2
|
||||
self.is_v3 = is_v3
|
||||
self.is_pixart = is_pixart
|
||||
self.is_auraflow = is_auraflow
|
||||
self.network_type = network_type
|
||||
if self.network_type.lower() == "dora":
|
||||
self.module_class = DoRAModule
|
||||
@@ -246,7 +248,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
target_replace_modules: List[torch.nn.Module],
|
||||
) -> List[LoRAModule]:
|
||||
unet_prefix = self.LORA_PREFIX_UNET
|
||||
if is_pixart or is_v3:
|
||||
if is_pixart or is_v3 or is_auraflow:
|
||||
unet_prefix = f"lora_transformer"
|
||||
|
||||
prefix = (
|
||||
@@ -371,6 +373,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
if is_pixart:
|
||||
target_modules = ["PixArtTransformer2DModel"]
|
||||
|
||||
if is_auraflow:
|
||||
target_modules = ["AuraFlowTransformer2DModel"]
|
||||
|
||||
if train_unet:
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||
else:
|
||||
@@ -408,6 +413,14 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
transformer.pos_embed = self.transformer_pos_embed
|
||||
transformer.proj_out = self.transformer_proj_out
|
||||
|
||||
elif self.is_auraflow:
|
||||
transformer: AuraFlowTransformer2DModel = unet
|
||||
self.transformer_pos_embed = copy.deepcopy(transformer.pos_embed)
|
||||
self.transformer_proj_out = copy.deepcopy(transformer.proj_out)
|
||||
|
||||
transformer.pos_embed = self.transformer_pos_embed
|
||||
transformer.proj_out = self.transformer_proj_out
|
||||
|
||||
else:
|
||||
unet: UNet2DConditionModel = unet
|
||||
unet_conv_in: torch.nn.Conv2d = unet.conv_in
|
||||
@@ -424,7 +437,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
all_params = super().prepare_optimizer_params(text_encoder_lr, unet_lr, default_lr)
|
||||
|
||||
if self.full_train_in_out:
|
||||
if self.is_pixart:
|
||||
if self.is_pixart or self.is_auraflow:
|
||||
all_params.append({"lr": unet_lr, "params": list(self.transformer_pos_embed.parameters())})
|
||||
all_params.append({"lr": unet_lr, "params": list(self.transformer_proj_out.parameters())})
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user