diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index a5c842d4..db6c43a3 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -21,6 +21,7 @@
 import torch
 import torch.backends.cuda
 from huggingface_hub import HfApi, Repository, interpreter_login
 from huggingface_hub.utils import HfFolder
+from toolkit.memory_management import MemoryManager
 from toolkit.basic import value_map
 from toolkit.clip_vision_adapter import ClipVisionAdapter
@@ -1811,6 +1812,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 print_acc(f"Loading from {latest_save_path}")
                 extra_weights = self.load_weights(latest_save_path)
                 self.network.multiplier = 1.0
+
+            if self.network_config.layer_offloading:
+                MemoryManager.attach(
+                    self.network,
+                    self.device_torch
+                )
 
         if self.embed_config is not None:
             # we are doing embedding training as well
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index a6c89576..44f47a71 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -208,6 +208,9 @@ class NetworkConfig:
 
         # for multi stage models
         self.split_multistage_loras = kwargs.get('split_multistage_loras', True)
+
+        # ramtorch, doesn't work yet
+        self.layer_offloading = kwargs.get('layer_offloading', False)
 
 
 AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v']
diff --git a/toolkit/memory_management/manager.py b/toolkit/memory_management/manager.py
index 871fd405..a43d3f31 100644
--- a/toolkit/memory_management/manager.py
+++ b/toolkit/memory_management/manager.py
@@ -108,6 +108,14 @@ class MemoryManager:
                 LinearLayerMemoryManager.attach(
                     child_module, module._memory_manager
                 )
+                # attach to ARA as well
+                if hasattr(child_module, "ara_lora_ref"):
+                    ara = child_module.ara_lora_ref()
+                    if ara not in modules_processed:
+                        MemoryManager.attach(
+                            ara,
+                            device,
+                        )
                 modules_processed.append(child_module)
             elif (
                 child_module.__class__.__name__ in CONV_MODULES
@@ -125,6 +133,15 @@ class MemoryManager:
                 ConvLayerMemoryManager.attach(
                     child_module, module._memory_manager
                 )
+                # attach to ARA as well
+                if hasattr(child_module, "ara_lora_ref"):
+                    ara = child_module.ara_lora_ref()
+                    if ara not in modules_processed:
+                        MemoryManager.attach(
+                            ara,
+                            device,
+                        )
+                        modules_processed.append(ara)
                 modules_processed.append(child_module)
             elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
                 inc in child_module.__class__.__name__
diff --git a/toolkit/memory_management/manager_modules.py b/toolkit/memory_management/manager_modules.py
index d116b49a..7dac4b59 100644
--- a/toolkit/memory_management/manager_modules.py
+++ b/toolkit/memory_management/manager_modules.py
@@ -583,6 +583,8 @@ class LinearLayerMemoryManager(BaseLayerMemoryManager):
            self.module.ara_lora_ref().org_forward = _mm_forward
        else:
            self.module.forward = _mm_forward
+
+        self.module._memory_management_device = self.manager.process_device
 
 
 class ConvLayerMemoryManager(BaseLayerMemoryManager):
@@ -638,3 +640,5 @@ class ConvLayerMemoryManager(BaseLayerMemoryManager):
            self.module.ara_lora_ref().org_forward = _mm_forward
        else:
            self.module.forward = _mm_forward
+
+        self.module._memory_management_device = self.manager.process_device
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 617f4baa..5421beb8 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -718,12 +718,18 @@ class ToolkitNetworkMixin:
         if hasattr(first_module, 'lora_down'):
             device = first_module.lora_down.weight.device
             dtype = first_module.lora_down.weight.dtype
+            if hasattr(first_module.lora_down, '_memory_management_device'):
+                device = first_module.lora_down._memory_management_device
         elif hasattr(first_module, 'lokr_w1'):
             device = first_module.lokr_w1.device
             dtype = first_module.lokr_w1.dtype
+            if hasattr(first_module.lokr_w1, '_memory_management_device'):
+                device = first_module.lokr_w1._memory_management_device
         elif hasattr(first_module, 'lokr_w1_a'):
             device = first_module.lokr_w1_a.device
             dtype = first_module.lokr_w1_a.dtype
+            if hasattr(first_module.lokr_w1_a, '_memory_management_device'):
+                device = first_module.lokr_w1_a._memory_management_device
         else:
             raise ValueError("Unknown module type")
         with torch.no_grad():
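For context on how the _memory_management_device attribute set in manager_modules.py is consumed by the network_mixins.py change above: a layer attached to the MemoryManager may keep its weights off the compute device, so the mixin prefers the recorded management device over the weight's current .device while still taking the dtype from the weight. Below is a minimal standalone sketch of that resolution pattern; the helper name and the plain nn.Linear setup are illustrative assumptions, not part of the toolkit.

import torch
import torch.nn as nn


def resolve_device_and_dtype(layer: nn.Module):
    # Mirrors the lookup added to ToolkitNetworkMixin: start from the weight's
    # own device/dtype, then prefer the device recorded by the memory manager
    # when the layer has been attached (its weights may sit on CPU while
    # compute happens on the process device).
    device = layer.weight.device
    dtype = layer.weight.dtype
    if hasattr(layer, "_memory_management_device"):
        device = layer._memory_management_device
    return device, dtype


# Illustrative only: simulate a managed layer whose weights stay on CPU.
layer = nn.Linear(8, 8)
layer._memory_management_device = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)
device, dtype = resolve_device_and_dtype(layer)
print(device, dtype)  # device comes from the manager, dtype from the weight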