mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
A lot of pixart sigma training tweaks
This commit is contained in:
@@ -169,6 +169,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
|
||||
network_type: str = "lora",
|
||||
full_train_in_out: bool = False,
|
||||
transformer_only: bool = False,
|
||||
**kwargs
|
||||
) -> None:
|
||||
"""
|
||||
@@ -193,6 +194,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
if ignore_if_contains is None:
|
||||
ignore_if_contains = []
|
||||
self.ignore_if_contains = ignore_if_contains
|
||||
self.transformer_only = transformer_only
|
||||
|
||||
self.only_if_contains: Union[List, None] = only_if_contains
|
||||
|
||||
@@ -271,6 +273,15 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
is_conv2d = child_module.__class__.__name__ in CONV_MODULES
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
|
||||
|
||||
lora_name = [prefix, name, child_name]
|
||||
# filter out blank
|
||||
lora_name = [x for x in lora_name if x and x != ""]
|
||||
lora_name = ".".join(lora_name)
|
||||
# if it doesnt have a name, it wil have two dots
|
||||
lora_name.replace("..", ".")
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
skip = False
|
||||
if any([word in child_name for word in self.ignore_if_contains]):
|
||||
skip = True
|
||||
@@ -279,9 +290,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
if count_parameters(child_module) < parameter_threshold:
|
||||
skip = True
|
||||
|
||||
if self.transformer_only and self.is_pixart and is_unet:
|
||||
if "transformer_blocks" not in lora_name:
|
||||
skip = True
|
||||
|
||||
if (is_linear or is_conv2d) and not skip:
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
if self.only_if_contains is not None and not any([word in lora_name for word in self.only_if_contains]):
|
||||
continue
|
||||
@@ -356,8 +369,12 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
index = None
|
||||
print(f"create LoRA for Text Encoder:")
|
||||
|
||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder,
|
||||
LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
replace_modules = LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
|
||||
if self.is_pixart:
|
||||
replace_modules = ["T5EncoderModel"]
|
||||
|
||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, replace_modules)
|
||||
self.text_encoder_loras.extend(text_encoder_loras)
|
||||
skipped_te += skipped
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
Reference in New Issue
Block a user