Offload the ARA together with its layer when layer offloading is enabled, and add support for offloading the LoRA network. Optimizer support is still needed.
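At a glance, the change reduces to a new NetworkConfig flag plus one extra MemoryManager.attach call on the LoRA network. A minimal sketch follows, assuming the toolkit package from this repo is importable; the toolkit.config_modules path, the stand-in network/device objects, and constructing NetworkConfig with only the new kwarg are assumptions, while the flag name and the attach signature come from the hunks below:

import torch
from toolkit.config_modules import NetworkConfig   # module path assumed
from toolkit.memory_management import MemoryManager

# The flag is a plain kwarg on NetworkConfig and defaults to False.
network_config = NetworkConfig(layer_offloading=True)

network = torch.nn.Sequential(torch.nn.Linear(8, 8))  # stands in for self.network
device = torch.device("cuda")                          # stands in for self.device_torch

# New in this commit: when enabled, the LoRA network itself is handed to the
# memory manager, which walks its Linear/Conv layers (and any ARA modules
# they reference) and offloads them the same way it does for the base model.
if network_config.layer_offloading:
    MemoryManager.attach(network, device)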

Jaret Burkett
2025-10-21 06:03:27 -06:00
parent 76ce757e0c
commit 0d8a33dc16
5 changed files with 37 additions and 0 deletions

View File

@@ -21,6 +21,7 @@ import torch
 import torch.backends.cuda
 from huggingface_hub import HfApi, Repository, interpreter_login
 from huggingface_hub.utils import HfFolder
+from toolkit.memory_management import MemoryManager
 from toolkit.basic import value_map
 from toolkit.clip_vision_adapter import ClipVisionAdapter
@@ -1812,6 +1813,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
 extra_weights = self.load_weights(latest_save_path)
 self.network.multiplier = 1.0
+if self.network_config.layer_offloading:
+    MemoryManager.attach(
+        self.network,
+        self.device_torch
+    )
 if self.embed_config is not None:
     # we are doing embedding training as well
     self.embedding = Embedding(

View File

@@ -209,6 +209,9 @@ class NetworkConfig:
 # for multi stage models
 self.split_multistage_loras = kwargs.get('split_multistage_loras', True)
+# ramtorch, doesn't work yet
+self.layer_offloading = kwargs.get('layer_offloading', False)

 AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v']

View File

@@ -108,6 +108,14 @@ class MemoryManager:
     LinearLayerMemoryManager.attach(
         child_module, module._memory_manager
     )
+    # attach to ARA as well
+    if hasattr(child_module, "ara_lora_ref"):
+        ara = child_module.ara_lora_ref()
+        if ara not in modules_processed:
+            MemoryManager.attach(
+                ara,
+                device,
+            )
     modules_processed.append(child_module)
 elif (
     child_module.__class__.__name__ in CONV_MODULES
@@ -125,6 +133,15 @@ class MemoryManager:
     ConvLayerMemoryManager.attach(
         child_module, module._memory_manager
     )
+    # attach to ARA as well
+    if hasattr(child_module, "ara_lora_ref"):
+        ara = child_module.ara_lora_ref()
+        if ara not in modules_processed:
+            MemoryManager.attach(
+                ara,
+                device,
+            )
+            modules_processed.append(ara)
     modules_processed.append(child_module)
 elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
     inc in child_module.__class__.__name__
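The ara_lora_ref attribute probed in both branches is treated as a callable that resolves to the ARA adapter hanging off a wrapped layer; the call syntax suggests a weakref.ref, though that is an inference. A minimal self-contained sketch of that pattern, with everything except the attribute name assumed:

import weakref
import torch.nn as nn

# Stand-ins: an ARA/LoRA-style adapter and the base layer it augments.
ara_adapter = nn.Linear(16, 16)
base_layer = nn.Linear(1024, 1024)

# Store a weak reference on the wrapped layer so the adapter can be found
# later without creating a reference cycle; calling it yields the module,
# or None once the adapter has been garbage collected.
base_layer.ara_lora_ref = weakref.ref(ara_adapter)

if hasattr(base_layer, "ara_lora_ref"):
    ara = base_layer.ara_lora_ref()
    if ara is not None:
        pass  # this is the point where MemoryManager.attach(ara, device) runs

Because a weakref object is not an nn.Module, it is never registered as a child module, so the adapter's parameters are not double-counted by the base layer.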

View File

@@ -584,6 +584,8 @@ class LinearLayerMemoryManager(BaseLayerMemoryManager):
 else:
     self.module.forward = _mm_forward
+self.module._memory_management_device = self.manager.process_device

 class ConvLayerMemoryManager(BaseLayerMemoryManager):
     def __init__(
@@ -638,3 +640,5 @@ class ConvLayerMemoryManager(BaseLayerMemoryManager):
     self.module.ara_lora_ref().org_forward = _mm_forward
 else:
     self.module.forward = _mm_forward
+self.module._memory_management_device = self.manager.process_device

View File

@@ -718,12 +718,18 @@ class ToolkitNetworkMixin:
 if hasattr(first_module, 'lora_down'):
     device = first_module.lora_down.weight.device
     dtype = first_module.lora_down.weight.dtype
+    if hasattr(first_module.lora_down, '_memory_management_device'):
+        device = first_module.lora_down._memory_management_device
 elif hasattr(first_module, 'lokr_w1'):
     device = first_module.lokr_w1.device
     dtype = first_module.lokr_w1.dtype
+    if hasattr(first_module.lokr_w1, '_memory_management_device'):
+        device = first_module.lokr_w1._memory_management_device
 elif hasattr(first_module, 'lokr_w1_a'):
     device = first_module.lokr_w1_a.device
     dtype = first_module.lokr_w1_a.dtype
+    if hasattr(first_module.lokr_w1_a, '_memory_management_device'):
+        device = first_module.lokr_w1_a._memory_management_device
 else:
     raise ValueError("Unknown module type")
 with torch.no_grad():
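The _memory_management_device fallback exists because, once a layer is managed, its weights may be parked on the CPU between forward passes, so weight.device no longer reflects where compute happens. A hedged helper illustrating the lookup order; the function itself is not part of the toolkit, only the attribute name comes from this diff:

import torch

def resolve_compute_device(module: torch.nn.Module) -> torch.device:
    # Prefer the memory manager's process device over the parameter's
    # current device, since offloaded weights can be sitting on the CPU.
    if hasattr(module, "_memory_management_device"):
        return module._memory_management_device
    return next(module.parameters()).device

layer = torch.nn.Linear(8, 8)                            # weights live on CPU
layer._memory_management_device = torch.device("cuda")   # set by the manager
print(resolve_compute_device(layer))                     # -> cuda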