Offload the ARA with its layer when doing layer offloading. Add support for offloading the LoRA. Still needs optimizer support.

Jaret Burkett
2025-10-21 06:03:27 -06:00
parent 76ce757e0c
commit 0d8a33dc16
5 changed files with 37 additions and 0 deletions

@@ -718,12 +718,18 @@ class ToolkitNetworkMixin:
         if hasattr(first_module, 'lora_down'):
             device = first_module.lora_down.weight.device
             dtype = first_module.lora_down.weight.dtype
+            if hasattr(first_module.lora_down, '_memory_management_device'):
+                device = first_module.lora_down._memory_management_device
         elif hasattr(first_module, 'lokr_w1'):
             device = first_module.lokr_w1.device
             dtype = first_module.lokr_w1.dtype
+            if hasattr(first_module.lokr_w1, '_memory_management_device'):
+                device = first_module.lokr_w1._memory_management_device
         elif hasattr(first_module, 'lokr_w1_a'):
             device = first_module.lokr_w1_a.device
             dtype = first_module.lokr_w1_a.dtype
+            if hasattr(first_module.lokr_w1_a, '_memory_management_device'):
+                device = first_module.lokr_w1_a._memory_management_device
         else:
             raise ValueError("Unknown module type")
         with torch.no_grad():
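
The added checks rely on the offloading system tagging parameters with a `_memory_management_device` attribute that records the intended compute device while the weights themselves are parked on CPU. A minimal sketch of that pattern, assuming hypothetical helpers `offload_module` and `resolve_device` (only the `_memory_management_device` attribute name comes from the diff above):

import torch
import torch.nn as nn

def offload_module(module: nn.Module, compute_device: torch.device):
    """Hypothetical helper: park weights on CPU, remember the compute device."""
    for param in module.parameters():
        param.data = param.data.to('cpu')
        # Same attribute the diff above checks for on lora_down / lokr_w1 / lokr_w1_a.
        param._memory_management_device = compute_device

def resolve_device(param: torch.Tensor) -> torch.device:
    """Prefer the recorded compute device over the current storage device."""
    if hasattr(param, '_memory_management_device'):
        return param._memory_management_device
    return param.device

# Usage: an offloaded LoRA down-projection still resolves to its compute device.
lora_down = nn.Linear(64, 8, bias=False)
offload_module(lora_down, torch.device('cuda:0'))
assert lora_down.weight.device.type == 'cpu'
assert resolve_device(lora_down.weight) == torch.device('cuda:0')

Under this scheme, reading `param.device` on an offloaded module would report CPU, which is why the diff prefers `_memory_management_device` when it is present before creating new tensors on the module's behalf.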