mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Fix issue where ARA was not working when using memory manager
This commit is contained in:
@@ -560,7 +560,11 @@ class LinearLayerMemoryManager(BaseLayerMemoryManager):
|
||||
_move_params_to_cpu_and_pin(self.module)
|
||||
|
||||
# 2) Hijack forward
|
||||
self._original_forward = getattr(self.module, "forward")
|
||||
if hasattr(self.module, "ara_lora_ref"):
|
||||
# ARA, we need to replace the lora forward
|
||||
self._original_forward = getattr(self.module.ara_lora_ref(), "org_forward")
|
||||
else:
|
||||
self._original_forward = getattr(self.module, "forward")
|
||||
|
||||
def _mm_forward(x, *args, **kwargs):
|
||||
# ensure we only use expected signature (Linear: x)
|
||||
@@ -575,7 +579,10 @@ class LinearLayerMemoryManager(BaseLayerMemoryManager):
|
||||
# NOTE: do NOT move params to device here; autograd fn streams & bounces them
|
||||
return _BouncingLinearFn.apply(x, weight_cpu, bias_cpu, device)
|
||||
|
||||
self.module.forward = _mm_forward
|
||||
if hasattr(self.module, "ara_lora_ref"):
|
||||
self.module.ara_lora_ref().org_forward = _mm_forward
|
||||
else:
|
||||
self.module.forward = _mm_forward
|
||||
|
||||
|
||||
class ConvLayerMemoryManager(BaseLayerMemoryManager):
|
||||
@@ -608,7 +615,11 @@ class ConvLayerMemoryManager(BaseLayerMemoryManager):
|
||||
groups = self.module.groups
|
||||
|
||||
# 2) Hijack forward
|
||||
self._original_forward = getattr(self.module, "forward")
|
||||
if hasattr(self.module, "ara_lora_ref"):
|
||||
# ARA, we need to replace the lora forward
|
||||
self._original_forward = getattr(self.module.ara_lora_ref(), "org_forward")
|
||||
else:
|
||||
self._original_forward = getattr(self.module, "forward")
|
||||
|
||||
def _mm_forward(x, *args, **kwargs):
|
||||
# Support the typical Conv2d(x) call; if user passes uncommon extras, fallback.
|
||||
@@ -623,4 +634,7 @@ class ConvLayerMemoryManager(BaseLayerMemoryManager):
|
||||
x, weight_cpu, bias_cpu, device, stride, padding, dilation, groups
|
||||
)
|
||||
|
||||
self.module.forward = _mm_forward
|
||||
if hasattr(self.module, "ara_lora_ref"):
|
||||
self.module.ara_lora_ref().org_forward = _mm_forward
|
||||
else:
|
||||
self.module.forward = _mm_forward
|
||||
|
||||
Reference in New Issue
Block a user