Mirror of https://github.com/ostris/ai-toolkit.git
Offload the ARA together with its layer when doing layer offloading. Add support for offloading the LoRA. Still needs optimizer support.
@@ -21,6 +21,7 @@ import torch
 import torch.backends.cuda
 from huggingface_hub import HfApi, Repository, interpreter_login
 from huggingface_hub.utils import HfFolder
+from toolkit.memory_management import MemoryManager

 from toolkit.basic import value_map
 from toolkit.clip_vision_adapter import ClipVisionAdapter
@@ -1811,6 +1812,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 print_acc(f"Loading from {latest_save_path}")
                 extra_weights = self.load_weights(latest_save_path)
                 self.network.multiplier = 1.0
+
+                if self.network_config.layer_offloading:
+                    MemoryManager.attach(
+                        self.network,
+                        self.device_torch
+                    )

         if self.embed_config is not None:
             # we are doing embedding training as well
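Condensed, this resume-path hook reduces to the outline below. It is a sketch, not the toolkit's actual method: `network`, `network_config` and `device_torch` stand in for the trainer's `self.*` attributes, and only the call shape `MemoryManager.attach(module, device)` is taken from this commit.

```python
from toolkit.memory_management import MemoryManager  # import added by this commit


def attach_lora_offloading(network, network_config, device_torch):
    """Sketch of the guarded attach added above (not the real trainer code)."""
    network.multiplier = 1.0
    if network_config.layer_offloading:
        # page the LoRA network's own layers off the GPU alongside the base model
        MemoryManager.attach(network, device_torch)
```

As the commit message notes, optimizer support for the offloaded LoRA weights is still missing.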

@@ -208,6 +208,9 @@ class NetworkConfig:

         # for multi stage models
         self.split_multistage_loras = kwargs.get('split_multistage_loras', True)
+
+        # ramtorch, doesn't work yet
+        self.layer_offloading = kwargs.get('layer_offloading', False)


 AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v']
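The config classes here read their options out of `kwargs` with defaults, so the new flag is opt-in and defaults to `False`. A minimal stand-in (not the real `NetworkConfig`) showing how it would be set:

```python
class NetworkConfigSketch:
    """Illustrative stand-in for toolkit.config_modules.NetworkConfig."""

    def __init__(self, **kwargs):
        # existing multi-stage option, unchanged by this commit
        self.split_multistage_loras = kwargs.get('split_multistage_loras', True)
        # new: ramtorch-style layer offloading, flagged as not working yet
        self.layer_offloading = kwargs.get('layer_offloading', False)


cfg = NetworkConfigSketch(layer_offloading=True)
assert cfg.layer_offloading is True
```

Since the inline comment marks the option as not working yet, leaving it at the default keeps current behavior.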

@@ -108,6 +108,14 @@ class MemoryManager:
                 LinearLayerMemoryManager.attach(
                     child_module, module._memory_manager
                 )
+                # attach to ARA as well
+                if hasattr(child_module, "ara_lora_ref"):
+                    ara = child_module.ara_lora_ref()
+                    if ara not in modules_processed:
+                        MemoryManager.attach(
+                            ara,
+                            device,
+                        )
                 modules_processed.append(child_module)
             elif (
                 child_module.__class__.__name__ in CONV_MODULES

@@ -125,6 +133,15 @@ class MemoryManager:
                 ConvLayerMemoryManager.attach(
                     child_module, module._memory_manager
                 )
+                # attach to ARA as well
+                if hasattr(child_module, "ara_lora_ref"):
+                    ara = child_module.ara_lora_ref()
+                    if ara not in modules_processed:
+                        MemoryManager.attach(
+                            ara,
+                            device,
+                        )
+                        modules_processed.append(ara)
                 modules_processed.append(child_module)
             elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
                 inc in child_module.__class__.__name__
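Both branches dereference `child_module.ara_lora_ref()` before checking `modules_processed`, which reads like a weak reference from the wrapped base layer back to its ARA adapter being called to get the live object. A toy sketch of that pattern under that assumption; `ToyARA` and the bare `nn.Linear` are stand-ins, and the real walk calls `MemoryManager.attach(ara, device)` where the comment marks it:

```python
import weakref

import torch.nn as nn


class ToyARA(nn.Module):
    """Stand-in for an ARA/LoRA adapter hanging off a base layer."""

    def __init__(self, dim=8, rank=2):
        super().__init__()
        self.lora_down = nn.Linear(dim, rank, bias=False)
        self.lora_up = nn.Linear(rank, dim, bias=False)


base = nn.Linear(8, 8)
adapter = ToyARA()
base.ara_lora_ref = weakref.ref(adapter)  # assumed: a weakref, hence the () call

modules_processed = []
if hasattr(base, "ara_lora_ref"):
    ara = base.ara_lora_ref()  # dereference the weakref to the adapter
    if ara is not None and ara not in modules_processed:
        # the real walk attaches the memory manager here:
        # MemoryManager.attach(ara, device)
        modules_processed.append(ara)  # dedup if several layers share one adapter
modules_processed.append(base)
```

Recording the ARA in `modules_processed` keeps a shared adapter from being attached twice; in the hunks above, only the conv branch appends the ARA itself.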

@@ -583,6 +583,8 @@ class LinearLayerMemoryManager(BaseLayerMemoryManager):
             self.module.ara_lora_ref().org_forward = _mm_forward
         else:
             self.module.forward = _mm_forward
+
+        self.module._memory_management_device = self.manager.process_device


 class ConvLayerMemoryManager(BaseLayerMemoryManager):
@@ -638,3 +640,5 @@ class ConvLayerMemoryManager(BaseLayerMemoryManager):
             self.module.ara_lora_ref().org_forward = _mm_forward
         else:
             self.module.forward = _mm_forward
+
+        self.module._memory_management_device = self.manager.process_device
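Why stash `_memory_management_device`: once a managed layer's parameters are parked on the CPU, `module.weight.device` no longer says where the layer actually computes, so later code needs a separate record of the process device. A toy offloaded linear layer (not the toolkit's implementation) illustrating the idea:

```python
import torch
import torch.nn as nn


class ToyOffloadedLinear(nn.Module):
    """Weights live on CPU; work happens on the recorded compute device."""

    def __init__(self, in_f, out_f, compute_device):
        super().__init__()
        self.inner = nn.Linear(in_f, out_f).to("cpu")     # parked parameters
        self._memory_management_device = compute_device   # where forward runs

    def forward(self, x):
        w = self.inner.weight.to(self._memory_management_device, non_blocking=True)
        b = self.inner.bias.to(self._memory_management_device, non_blocking=True)
        return nn.functional.linear(x.to(self._memory_management_device), w, b)


dev = "cuda" if torch.cuda.is_available() else "cpu"
layer = ToyOffloadedLinear(4, 4, dev)
x = torch.randn(2, 4)
print(layer.inner.weight.device)        # cpu: parameters are parked off-GPU
print(layer._memory_management_device)  # device callers should actually target
print(layer(x).device)                  # output lands on the compute device
```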

@@ -718,12 +718,18 @@ class ToolkitNetworkMixin:
         if hasattr(first_module, 'lora_down'):
             device = first_module.lora_down.weight.device
             dtype = first_module.lora_down.weight.dtype
+            if hasattr(first_module.lora_down, '_memory_management_device'):
+                device = first_module.lora_down._memory_management_device
         elif hasattr(first_module, 'lokr_w1'):
             device = first_module.lokr_w1.device
             dtype = first_module.lokr_w1.dtype
+            if hasattr(first_module.lokr_w1, '_memory_management_device'):
+                device = first_module.lokr_w1._memory_management_device
         elif hasattr(first_module, 'lokr_w1_a'):
             device = first_module.lokr_w1_a.device
             dtype = first_module.lokr_w1_a.dtype
+            if hasattr(first_module.lokr_w1_a, '_memory_management_device'):
+                device = first_module.lokr_w1_a._memory_management_device
         else:
             raise ValueError("Unknown module type")
         with torch.no_grad():
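With offloading active, `lora_down.weight.device` (and the lokr equivalents) reports the CPU parking location, so the mixin now prefers the `_memory_management_device` tag when deciding which device to work on. A sketch of that resolution for the `lora_down` branch; the lokr branches follow the same shape:

```python
import torch


def resolve_device_dtype(first_module):
    """Pick the device/dtype a LoRA module should be handled on (sketch)."""
    if hasattr(first_module, 'lora_down'):
        device = first_module.lora_down.weight.device
        dtype = first_module.lora_down.weight.dtype
        # offloaded layers park their weights on CPU, so prefer the compute
        # device the memory manager stashed on the submodule
        if hasattr(first_module.lora_down, '_memory_management_device'):
            device = first_module.lora_down._memory_management_device
    else:
        raise ValueError("Unknown module type")
    return device, dtype
```

Note that the diff only overrides the device; dtype is still read from the weight itself, since offloading moves tensors between devices without changing their precision.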