mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Fix issue where ARA was not working when using memory manager
This commit is contained in:
@@ -59,6 +59,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
|
||||
module_dropout=None,
|
||||
network: 'LoRASpecialNetwork' = None,
|
||||
use_bias: bool = False,
|
||||
is_ara: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
self.can_merge_in = True
|
||||
@@ -68,6 +69,10 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
|
||||
self.lora_name = lora_name
|
||||
self.orig_module_ref = weakref.ref(org_module)
|
||||
self.scalar = torch.tensor(1.0, device=org_module.weight.device)
|
||||
|
||||
# if is ara lora module, mark it on the layer so memory manager can handle it
|
||||
if is_ara:
|
||||
org_module.ara_lora_ref = weakref.ref(self)
|
||||
# check if parent has bias. if not force use_bias to False
|
||||
if org_module.bias is None:
|
||||
use_bias = False
|
||||
@@ -193,6 +198,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
is_assistant_adapter: bool = False,
|
||||
is_transformer: bool = False,
|
||||
base_model: 'StableDiffusion' = None,
|
||||
is_ara: bool = False,
|
||||
**kwargs
|
||||
) -> None:
|
||||
"""
|
||||
@@ -247,6 +253,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
self.network_type = network_type
|
||||
self.is_assistant_adapter = is_assistant_adapter
|
||||
self.full_rank = network_type.lower() == "fullrank"
|
||||
self.is_ara = is_ara
|
||||
if self.network_type.lower() == "dora":
|
||||
self.module_class = DoRAModule
|
||||
module_class = DoRAModule
|
||||
@@ -419,6 +426,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
|
||||
|
||||
if self.network_type.lower() == "lokr":
|
||||
module_kwargs["factor"] = self.network_config.lokr_factor
|
||||
|
||||
if self.is_ara:
|
||||
module_kwargs["is_ara"] = True
|
||||
|
||||
lora = module_class(
|
||||
lora_name,
|
||||
|
||||
Reference in New Issue
Block a user