Fix issue where ARA was not working when using memory manager

Author: Jaret Burkett
Date: 2025-10-07 13:39:44 -06:00
parent dfc85f0b51
commit 55b8b0e23e
3 changed files with 29 additions and 4 deletions

View File

@@ -59,6 +59,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
         module_dropout=None,
         network: 'LoRASpecialNetwork' = None,
         use_bias: bool = False,
+        is_ara: bool = False,
         **kwargs
     ):
         self.can_merge_in = True
@@ -68,6 +69,10 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
         self.lora_name = lora_name
         self.orig_module_ref = weakref.ref(org_module)
         self.scalar = torch.tensor(1.0, device=org_module.weight.device)
+        # if is ara lora module, mark it on the layer so memory manager can handle it
+        if is_ara:
+            org_module.ara_lora_ref = weakref.ref(self)
         # check if parent has bias. if not force use_bias to False
         if org_module.bias is None:
             use_bias = False
@@ -193,6 +198,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         is_assistant_adapter: bool = False,
         is_transformer: bool = False,
         base_model: 'StableDiffusion' = None,
+        is_ara: bool = False,
         **kwargs
     ) -> None:
         """
@@ -247,6 +253,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.network_type = network_type
         self.is_assistant_adapter = is_assistant_adapter
         self.full_rank = network_type.lower() == "fullrank"
+        self.is_ara = is_ara
         if self.network_type.lower() == "dora":
             self.module_class = DoRAModule
             module_class = DoRAModule
@@ -419,6 +426,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                 if self.network_type.lower() == "lokr":
                     module_kwargs["factor"] = self.network_config.lokr_factor
+                if self.is_ara:
+                    module_kwargs["is_ara"] = True
                 lora = module_class(
                     lora_name,

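For reference, a minimal, self-contained sketch of the tagging pattern this hunk introduces (TinyWrapper and the layer below are hypothetical stand-ins, not the toolkit's LoRAModule): the adapter stores a weakref to itself on the wrapped layer, so code that only sees the layer, such as the memory manager, can find its ARA adapter without a direct reference.

import weakref
import torch

class TinyWrapper(torch.nn.Module):
    # hypothetical stand-in for an ARA LoRA module
    def __init__(self, org_module: torch.nn.Module, is_ara: bool = False):
        super().__init__()
        self.orig_module_ref = weakref.ref(org_module)
        if is_ara:
            # tag the wrapped layer so other components can discover this wrapper
            org_module.ara_lora_ref = weakref.ref(self)

layer = torch.nn.Linear(8, 8)
adapter = TinyWrapper(layer, is_ara=True)
assert layer.ara_lora_ref() is adapter  # the weakref resolves back to the adapter

Because both directions use weakref, the tag does not create a reference cycle or keep either object alive on its own.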
View File

@@ -560,7 +560,11 @@ class LinearLayerMemoryManager(BaseLayerMemoryManager):
         _move_params_to_cpu_and_pin(self.module)

         # 2) Hijack forward
-        self._original_forward = getattr(self.module, "forward")
+        if hasattr(self.module, "ara_lora_ref"):
+            # ARA, we need to replace the lora forward
+            self._original_forward = getattr(self.module.ara_lora_ref(), "org_forward")
+        else:
+            self._original_forward = getattr(self.module, "forward")

         def _mm_forward(x, *args, **kwargs):
             # ensure we only use expected signature (Linear: x)
@@ -575,7 +579,10 @@ class LinearLayerMemoryManager(BaseLayerMemoryManager):
             # NOTE: do NOT move params to device here; autograd fn streams & bounces them
             return _BouncingLinearFn.apply(x, weight_cpu, bias_cpu, device)

-        self.module.forward = _mm_forward
+        if hasattr(self.module, "ara_lora_ref"):
+            self.module.ara_lora_ref().org_forward = _mm_forward
+        else:
+            self.module.forward = _mm_forward


 class ConvLayerMemoryManager(BaseLayerMemoryManager):
@@ -608,7 +615,11 @@ class ConvLayerMemoryManager(BaseLayerMemoryManager):
         groups = self.module.groups

         # 2) Hijack forward
-        self._original_forward = getattr(self.module, "forward")
+        if hasattr(self.module, "ara_lora_ref"):
+            # ARA, we need to replace the lora forward
+            self._original_forward = getattr(self.module.ara_lora_ref(), "org_forward")
+        else:
+            self._original_forward = getattr(self.module, "forward")

         def _mm_forward(x, *args, **kwargs):
             # Support the typical Conv2d(x) call; if user passes uncommon extras, fallback.
@@ -623,4 +634,7 @@ class ConvLayerMemoryManager(BaseLayerMemoryManager):
                 x, weight_cpu, bias_cpu, device, stride, padding, dilation, groups
             )

-        self.module.forward = _mm_forward
+        if hasattr(self.module, "ara_lora_ref"):
+            self.module.ara_lora_ref().org_forward = _mm_forward
+        else:
+            self.module.forward = _mm_forward

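The hunks above apply the same branch in four places. A compact sketch of the selection logic (hijack_forward is a hypothetical helper, not part of the toolkit), assuming the ARA adapter has already taken over module.forward and routes the base computation through its org_forward:

import torch

def hijack_forward(module: torch.nn.Module, new_forward):
    # hypothetical helper mirroring the diff: choose which slot to patch
    if hasattr(module, "ara_lora_ref"):
        # ARA case: patch the adapter's org_forward so the adapter still runs
        # and only its call into the base layer goes through new_forward
        original = module.ara_lora_ref().org_forward
        module.ara_lora_ref().org_forward = new_forward
    else:
        original = module.forward
        module.forward = new_forward
    return original  # kept so the hijack can later be undone

Patching module.forward directly in the ARA case would either bypass the adapter or capture the adapter's own forward as the "original", which appears to be why ARA was not working once the memory manager was enabled.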
View File

@@ -247,6 +247,7 @@ def quantize_model(
        transformer_only=network_config.transformer_only,
        is_transformer=base_model.is_transformer,
        base_model=base_model,
+       is_ara=True,
        **network_kwargs
    )
    network.apply_to(
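Taken together, a toy end-to-end sketch of the flow under the same assumptions (TinyModule and TinyNetwork are hypothetical stand-ins, not LoRAModule or LoRASpecialNetwork): the is_ara flag is set once at the call site, forwarded to each module via its kwargs, and each module tags its base layer so the memory managers can detect it later.

import weakref
import torch

class TinyModule:
    # hypothetical stand-in for LoRAModule: tags the wrapped layer when is_ara is set
    def __init__(self, org_module, is_ara=False):
        self.org_forward = org_module.forward
        if is_ara:
            org_module.ara_lora_ref = weakref.ref(self)

class TinyNetwork:
    # hypothetical stand-in for LoRASpecialNetwork: forwards is_ara to each module
    def __init__(self, layers, is_ara=False):
        self.is_ara = is_ara
        self.modules = []
        for layer in layers:
            module_kwargs = {}
            if self.is_ara:
                module_kwargs["is_ara"] = True
            self.modules.append(TinyModule(layer, **module_kwargs))

layers = [torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)]
net = TinyNetwork(layers, is_ara=True)  # mirrors is_ara=True at the call site above
assert all(hasattr(layer, "ara_lora_ref") for layer in layers)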