Offload the ARA with the layer when doing layer offloading, and add support for offloading the LoRA. Still needs optimizer support.

Jaret Burkett
2025-10-21 06:03:27 -06:00
parent 76ce757e0c
commit 0d8a33dc16
5 changed files with 37 additions and 0 deletions


@@ -208,6 +208,9 @@ class NetworkConfig:
        # for multi stage models
        self.split_multistage_loras = kwargs.get('split_multistage_loras', True)
        # ramtorch, doesn't work yet
        self.layer_offloading = kwargs.get('layer_offloading', False)

AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v']
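As a usage note, here is a minimal sketch of turning the new flag on when the config is built; the import path and the other keyword arguments are assumptions about how NetworkConfig is constructed elsewhere in the toolkit, not part of this diff:

```python
# Hedged sketch: enable the experimental, ramtorch-backed layer offloading
# flag added above. The import path and the extra kwargs are assumptions.
from toolkit.config_modules import NetworkConfig  # assumed module path

network_config = NetworkConfig(
    type="lora",             # assumed kwarg, matching other network options
    linear=16,
    linear_alpha=16,
    layer_offloading=True,   # new in this commit; defaults to False
)
assert network_config.layer_offloading is True
```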


@@ -108,6 +108,14 @@ class MemoryManager:
                LinearLayerMemoryManager.attach(
                    child_module, module._memory_manager
                )
                # attach to ARA as well
                if hasattr(child_module, "ara_lora_ref"):
                    ara = child_module.ara_lora_ref()
                    if ara not in modules_processed:
                        MemoryManager.attach(
                            ara,
                            device,
                        )
                modules_processed.append(child_module)
            elif (
                child_module.__class__.__name__ in CONV_MODULES

@@ -125,6 +133,15 @@ class MemoryManager:
                ConvLayerMemoryManager.attach(
                    child_module, module._memory_manager
                )
                # attach to ARA as well
                if hasattr(child_module, "ara_lora_ref"):
                    ara = child_module.ara_lora_ref()
                    if ara not in modules_processed:
                        MemoryManager.attach(
                            ara,
                            device,
                        )
                        modules_processed.append(ara)
                modules_processed.append(child_module)
            elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
                inc in child_module.__class__.__name__
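These branches assume each managed linear/conv layer may expose ara_lora_ref, a zero-argument callable (presumably a weakref.ref or similar) that resolves to its ARA adapter, and they use modules_processed to avoid attaching to the same adapter twice; note that only the conv branch appends the resolved ARA to modules_processed. A self-contained sketch of that pattern, with illustrative class names that do not come from the toolkit:

```python
# Illustrative sketch of the ara_lora_ref pattern; AdapterModule is a stand-in
# for the real ARA/LoRA adapter, and the attach call is reduced to a comment.
import weakref
import torch.nn as nn

class AdapterModule(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.down = nn.Linear(dim, 4, bias=False)
        self.up = nn.Linear(4, dim, bias=False)

base = nn.Linear(64, 64)
adapter = AdapterModule(64)

# The base layer holds only a weak reference to its adapter, so calling
# ara_lora_ref() dereferences it (and returns None once the adapter is gone).
base.ara_lora_ref = weakref.ref(adapter)

modules_processed = []
if hasattr(base, "ara_lora_ref"):
    ara = base.ara_lora_ref()
    if ara is not None and ara not in modules_processed:
        # MemoryManager.attach(ara, device) would run here in the real code
        modules_processed.append(ara)
```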


@@ -583,6 +583,8 @@ class LinearLayerMemoryManager(BaseLayerMemoryManager):
            self.module.ara_lora_ref().org_forward = _mm_forward
        else:
            self.module.forward = _mm_forward

        self.module._memory_management_device = self.manager.process_device


class ConvLayerMemoryManager(BaseLayerMemoryManager):

@@ -638,3 +640,5 @@ class ConvLayerMemoryManager(BaseLayerMemoryManager):
            self.module.ara_lora_ref().org_forward = _mm_forward
        else:
            self.module.forward = _mm_forward

        self.module._memory_management_device = self.manager.process_device
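The new attribute records the device the manager will process the layer on, separate from where its (possibly offloaded) weights currently live. A small illustration of what downstream code can read from it; the attribute is normally stamped by the managers above, here it is set by hand:

```python
# Illustration only: the stamp is normally set by Linear/ConvLayerMemoryManager.
import torch
import torch.nn as nn

layer = nn.Linear(8, 8)  # weights parked on CPU, as if offloaded
layer._memory_management_device = torch.device("cuda:0")  # compute device

# Prefer the stamp when present, otherwise fall back to the weight's device.
device = getattr(layer, "_memory_management_device", layer.weight.device)
print(device)  # cuda:0
```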


@@ -718,12 +718,18 @@ class ToolkitNetworkMixin:
        if hasattr(first_module, 'lora_down'):
            device = first_module.lora_down.weight.device
            dtype = first_module.lora_down.weight.dtype
            if hasattr(first_module.lora_down, '_memory_management_device'):
                device = first_module.lora_down._memory_management_device
        elif hasattr(first_module, 'lokr_w1'):
            device = first_module.lokr_w1.device
            dtype = first_module.lokr_w1.dtype
            if hasattr(first_module.lokr_w1, '_memory_management_device'):
                device = first_module.lokr_w1._memory_management_device
        elif hasattr(first_module, 'lokr_w1_a'):
            device = first_module.lokr_w1_a.device
            dtype = first_module.lokr_w1_a.dtype
            if hasattr(first_module.lokr_w1_a, '_memory_management_device'):
                device = first_module.lokr_w1_a._memory_management_device
        else:
            raise ValueError("Unknown module type")
        with torch.no_grad():
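With the extra hasattr checks, the device resolved here (and presumably used inside the following torch.no_grad() block) follows the memory manager's stamp rather than the offloaded weight's device. A hedged sketch of the same lookup factored into a helper; _resolve_device_dtype and _FakeLoRAModule are illustrative names, not toolkit code:

```python
import torch
import torch.nn as nn

def _resolve_device_dtype(first_module):
    """Mirror of the branch above: the weight supplies dtype/device, but a
    _memory_management_device stamp on the owner overrides the device."""
    if hasattr(first_module, 'lora_down'):
        owner = first_module.lora_down          # sub-module (LoRA)
        weight = first_module.lora_down.weight
    elif hasattr(first_module, 'lokr_w1'):
        owner = weight = first_module.lokr_w1   # parameter tensor (LoKr)
    elif hasattr(first_module, 'lokr_w1_a'):
        owner = weight = first_module.lokr_w1_a
    else:
        raise ValueError("Unknown module type")
    device, dtype = weight.device, weight.dtype
    if hasattr(owner, '_memory_management_device'):
        device = owner._memory_management_device
    return device, dtype

# Example with a minimal stand-in module carrying a lora_down layer.
class _FakeLoRAModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.lora_down = nn.Linear(16, 4, bias=False)

m = _FakeLoRAModule()
m.lora_down._memory_management_device = torch.device("cuda:0")
print(_resolve_device_dtype(m))  # (device(type='cuda', index=0), torch.float32)
```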