mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added early DoRA support, but will change shortly. Dont use right now.
This commit is contained in:
@@ -15,6 +15,7 @@ from .paths import SD_SCRIPTS_ROOT
|
||||
sys.path.append(SD_SCRIPTS_ROOT)
|
||||
|
||||
from networks.lora import LoRANetwork, get_block_index
|
||||
from toolkit.models.DoRA import DoRAModule
|
||||
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
@@ -159,6 +160,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
attn_only: bool = False,
|
||||
target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
|
||||
target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
|
||||
network_type: str = "lora",
|
||||
**kwargs
|
||||
) -> None:
|
||||
"""
|
||||
@@ -199,6 +201,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
self.is_sdxl = is_sdxl
|
||||
self.is_v2 = is_v2
|
||||
self.is_pixart = is_pixart
|
||||
self.network_type = network_type
|
||||
if self.network_type.lower() == "dora":
|
||||
self.module_class = DoRAModule
|
||||
module_class = DoRAModule
|
||||
|
||||
if modules_dim is not None:
|
||||
print(f"create LoRA network from weights")
|
||||
|
||||
Reference in New Issue
Block a user