From 55b8b0e23e0a4328bf7ed3faef45b009a431e741 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Tue, 7 Oct 2025 13:39:44 -0600
Subject: [PATCH] Fix issue where ARA was not working when using memory manager

---
 toolkit/lora_special.py                      | 10 +++++++++
 toolkit/memory_management/manager_modules.py | 22 ++++++++++++++++----
 toolkit/util/quantize.py                     |  1 +
 3 files changed, 29 insertions(+), 4 deletions(-)

diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index 2c4c11a8..cd454656 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -59,6 +59,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
         module_dropout=None,
         network: 'LoRASpecialNetwork' = None,
         use_bias: bool = False,
+        is_ara: bool = False,
         **kwargs
     ):
         self.can_merge_in = True
@@ -68,6 +69,10 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
         self.lora_name = lora_name
         self.orig_module_ref = weakref.ref(org_module)
         self.scalar = torch.tensor(1.0, device=org_module.weight.device)
+
+        # if is ara lora module, mark it on the layer so memory manager can handle it
+        if is_ara:
+            org_module.ara_lora_ref = weakref.ref(self)
         # check if parent has bias. if not force use_bias to False
         if org_module.bias is None:
             use_bias = False
@@ -193,6 +198,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         is_assistant_adapter: bool = False,
         is_transformer: bool = False,
         base_model: 'StableDiffusion' = None,
+        is_ara: bool = False,
         **kwargs
     ) -> None:
         """
@@ -247,6 +253,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.network_type = network_type
         self.is_assistant_adapter = is_assistant_adapter
         self.full_rank = network_type.lower() == "fullrank"
+        self.is_ara = is_ara
         if self.network_type.lower() == "dora":
             self.module_class = DoRAModule
             module_class = DoRAModule
@@ -419,6 +426,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
 
                         if self.network_type.lower() == "lokr":
                             module_kwargs["factor"] = self.network_config.lokr_factor
+
+                        if self.is_ara:
+                            module_kwargs["is_ara"] = True
 
                         lora = module_class(
                             lora_name,
diff --git a/toolkit/memory_management/manager_modules.py b/toolkit/memory_management/manager_modules.py
index 02852918..d116b49a 100644
--- a/toolkit/memory_management/manager_modules.py
+++ b/toolkit/memory_management/manager_modules.py
@@ -560,7 +560,11 @@ class LinearLayerMemoryManager(BaseLayerMemoryManager):
         _move_params_to_cpu_and_pin(self.module)
 
         # 2) Hijack forward
-        self._original_forward = getattr(self.module, "forward")
+        if hasattr(self.module, "ara_lora_ref"):
+            # ARA, we need to replace the lora forward
+            self._original_forward = getattr(self.module.ara_lora_ref(), "org_forward")
+        else:
+            self._original_forward = getattr(self.module, "forward")
 
         def _mm_forward(x, *args, **kwargs):
             # ensure we only use expected signature (Linear: x)
@@ -575,7 +579,10 @@ class LinearLayerMemoryManager(BaseLayerMemoryManager):
             # NOTE: do NOT move params to device here; autograd fn streams & bounces them
             return _BouncingLinearFn.apply(x, weight_cpu, bias_cpu, device)
 
-        self.module.forward = _mm_forward
+        if hasattr(self.module, "ara_lora_ref"):
+            self.module.ara_lora_ref().org_forward = _mm_forward
+        else:
+            self.module.forward = _mm_forward
 
 
 class ConvLayerMemoryManager(BaseLayerMemoryManager):
@@ -608,7 +615,11 @@ class ConvLayerMemoryManager(BaseLayerMemoryManager):
         groups = self.module.groups
 
         # 2) Hijack forward
-        self._original_forward = getattr(self.module, "forward")
+        if hasattr(self.module, "ara_lora_ref"):
+            # ARA, we need to replace the lora forward
+            self._original_forward = getattr(self.module.ara_lora_ref(), "org_forward")
+        else:
+            self._original_forward = getattr(self.module, "forward")
 
         def _mm_forward(x, *args, **kwargs):
             # Support the typical Conv2d(x) call; if user passes uncommon extras, fallback.
@@ -623,4 +634,7 @@ class ConvLayerMemoryManager(BaseLayerMemoryManager):
                 x, weight_cpu, bias_cpu, device, stride, padding, dilation, groups
             )
 
-        self.module.forward = _mm_forward
+        if hasattr(self.module, "ara_lora_ref"):
+            self.module.ara_lora_ref().org_forward = _mm_forward
+        else:
+            self.module.forward = _mm_forward
diff --git a/toolkit/util/quantize.py b/toolkit/util/quantize.py
index e3e81940..27d6733a 100644
--- a/toolkit/util/quantize.py
+++ b/toolkit/util/quantize.py
@@ -247,6 +247,7 @@ def quantize_model(
             transformer_only=network_config.transformer_only,
            is_transformer=base_model.is_transformer,
             base_model=base_model,
+            is_ara=True,
             **network_kwargs
         )
         network.apply_to(
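
Note (not part of the patch): below is a minimal, runnable sketch of the weakref handshake the diff relies on. TinyAraWrapper and TinyMemoryManager are hypothetical stand-ins for LoRAModule and LinearLayerMemoryManager; the real manager also pins, streams, and bounces weights, which the sketch omits.

# Sketch only: shows how an ARA wrapper marks its layer with ara_lora_ref,
# and how a memory manager then hijacks the wrapper's org_forward instead of
# the layer's forward, mirroring the branches added in this patch.
import weakref

import torch
import torch.nn as nn


class TinyAraWrapper(nn.Module):
    def __init__(self, org_module: nn.Linear, is_ara: bool = True):
        super().__init__()
        # keep the layer's original forward; the wrapper calls it internally
        self.org_forward = org_module.forward
        if is_ara:
            # mark the wrapped layer so a memory manager can find this wrapper
            org_module.ara_lora_ref = weakref.ref(self)

    def forward(self, x):
        # a real LoRA module would add a low-rank update here; just delegate
        return self.org_forward(x)


class TinyMemoryManager:
    def __init__(self, module: nn.Linear):
        self.module = module
        # ARA layer: hijack the wrapper's org_forward; otherwise hijack the
        # layer's forward directly (same branch structure as the patch)
        if hasattr(module, "ara_lora_ref"):
            self._original_forward = module.ara_lora_ref().org_forward
        else:
            self._original_forward = module.forward

        def _mm_forward(x, *args, **kwargs):
            # a real manager would stream weights to the device here
            return self._original_forward(x)

        if hasattr(module, "ara_lora_ref"):
            module.ara_lora_ref().org_forward = _mm_forward
        else:
            module.forward = _mm_forward


linear = nn.Linear(4, 4)
wrapper = TinyAraWrapper(linear)         # sets linear.ara_lora_ref
TinyMemoryManager(linear)                # hooks the wrapper, not linear.forward
print(wrapper(torch.randn(2, 4)).shape)  # torch.Size([2, 4])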