Fix issue where ARA was not working when using memory manager

Author: Jaret Burkett
Date: 2025-10-07 13:39:44 -06:00
parent dfc85f0b51
commit 55b8b0e23e
3 changed files with 29 additions and 4 deletions


@@ -560,7 +560,11 @@ class LinearLayerMemoryManager(BaseLayerMemoryManager):
         _move_params_to_cpu_and_pin(self.module)

         # 2) Hijack forward
-        self._original_forward = getattr(self.module, "forward")
+        if hasattr(self.module, "ara_lora_ref"):
+            # ARA, we need to replace the lora forward
+            self._original_forward = getattr(self.module.ara_lora_ref(), "org_forward")
+        else:
+            self._original_forward = getattr(self.module, "forward")

         def _mm_forward(x, *args, **kwargs):
             # ensure we only use expected signature (Linear: x)
@@ -575,7 +579,10 @@ class LinearLayerMemoryManager(BaseLayerMemoryManager):
             # NOTE: do NOT move params to device here; autograd fn streams & bounces them
             return _BouncingLinearFn.apply(x, weight_cpu, bias_cpu, device)

-        self.module.forward = _mm_forward
+        if hasattr(self.module, "ara_lora_ref"):
+            self.module.ara_lora_ref().org_forward = _mm_forward
+        else:
+            self.module.forward = _mm_forward


 class ConvLayerMemoryManager(BaseLayerMemoryManager):
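The hijacked forward above hands the pinned CPU tensors to _BouncingLinearFn, whose definition is outside this diff. As a rough sketch of what such a streaming/bouncing autograd function can look like (the class name and internals here are illustrative assumptions, not the implementation in this repo):

import torch
import torch.nn.functional as F


class BouncingLinearFn(torch.autograd.Function):
    # Illustrative only: parameters live in pinned CPU memory and are
    # copied to the compute device just for the duration of each call.
    # Assumes a batched input of shape (..., in_features).

    @staticmethod
    def forward(ctx, x, weight_cpu, bias_cpu, device):
        weight = weight_cpu.to(device, non_blocking=True)
        bias = None if bias_cpu is None else bias_cpu.to(device, non_blocking=True)
        ctx.save_for_backward(x, weight_cpu)
        ctx.has_bias = bias_cpu is not None
        ctx.device = device
        return F.linear(x, weight, bias)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight_cpu = ctx.saved_tensors
        weight = weight_cpu.to(ctx.device, non_blocking=True)
        grad_x = grad_out @ weight
        grad_w = grad_out.flatten(0, -2).t() @ x.flatten(0, -2)
        grad_b = grad_out.flatten(0, -2).sum(0) if ctx.has_bias else None
        # gradients for the CPU-resident parameters go back to CPU
        return grad_x, grad_w.cpu(), None if grad_b is None else grad_b.cpu(), None

The relevant point for this commit is only the last step of the hunk: whichever callable is installed as the layer's forward must be the one that actually runs at inference/training time, otherwise the bounce never happens.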
@@ -608,7 +615,11 @@ class ConvLayerMemoryManager(BaseLayerMemoryManager):
         groups = self.module.groups

         # 2) Hijack forward
-        self._original_forward = getattr(self.module, "forward")
+        if hasattr(self.module, "ara_lora_ref"):
+            # ARA, we need to replace the lora forward
+            self._original_forward = getattr(self.module.ara_lora_ref(), "org_forward")
+        else:
+            self._original_forward = getattr(self.module, "forward")

         def _mm_forward(x, *args, **kwargs):
             # Support the typical Conv2d(x) call; if user passes uncommon extras, fallback.
@@ -623,4 +634,7 @@ class ConvLayerMemoryManager(BaseLayerMemoryManager):
                 x, weight_cpu, bias_cpu, device, stride, padding, dilation, groups
             )

-        self.module.forward = _mm_forward
+        if hasattr(self.module, "ara_lora_ref"):
+            self.module.ara_lora_ref().org_forward = _mm_forward
+        else:
+            self.module.forward = _mm_forward
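Why patching module.forward alone was not enough: ara_lora_ref is called like a function before org_forward is read, which is consistent with a weakref.ref back to the ARA/LoRA adapter that has already taken over the layer's forward and stashed the original in org_forward. A minimal sketch of that interplay (the adapter class and everything except the ara_lora_ref / org_forward names is assumed for illustration, not the actual ARA code):

import weakref

import torch
import torch.nn as nn


class ToyARAAdapter(nn.Module):
    # Hypothetical stand-in for an ARA/LoRA wrapper: it hijacks the base
    # layer's forward and keeps the original callable in `org_forward`.
    def __init__(self, base: nn.Linear, rank: int = 4):
        super().__init__()
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        self.org_forward = base.forward            # original Linear forward (bound method)
        base.forward = self.forward                # adapter now owns base.forward
        base.ara_lora_ref = weakref.ref(self)      # back-reference the memory manager checks

    def forward(self, x):
        return self.org_forward(x) + self.up(self.down(x))


base = nn.Linear(16, 16)
adapter = ToyARAAdapter(base)

# What the memory manager must do: if the layer is ARA-wrapped, the callable
# that actually runs the heavy matmul is adapter.org_forward, so that is the
# one to save and replace; patching base.forward would be a no-op here.
if hasattr(base, "ara_lora_ref"):
    original_forward = base.ara_lora_ref().org_forward
else:
    original_forward = base.forward

y = base(torch.randn(2, 16))   # routes through the adapter, which calls org_forward

Under that wrapping, the adapter keeps calling the org_forward it captured, so a memory manager that only swapped module.forward never ran and the offloaded weights were never bounced to the device. The fix swaps org_forward on the adapter instead, and captures _original_forward from the same place so it can be restored consistently.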