Added early DoRA support, but will change shortly. Dont use right now.

This commit is contained in:
Jaret Burkett
2024-02-23 05:55:41 -07:00
parent 9ffa8c3711
commit 1bd94f0f01
5 changed files with 140 additions and 5 deletions

View File

@@ -15,6 +15,7 @@ from .paths import SD_SCRIPTS_ROOT
sys.path.append(SD_SCRIPTS_ROOT)
from networks.lora import LoRANetwork, get_block_index
from toolkit.models.DoRA import DoRAModule
from torch.utils.checkpoint import checkpoint
@@ -159,6 +160,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
attn_only: bool = False,
target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
network_type: str = "lora",
**kwargs
) -> None:
"""
@@ -199,6 +201,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
self.is_sdxl = is_sdxl
self.is_v2 = is_v2
self.is_pixart = is_pixart
self.network_type = network_type
if self.network_type.lower() == "dora":
self.module_class = DoRAModule
module_class = DoRAModule
if modules_dim is not None:
print(f"create LoRA network from weights")