mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
Change auto_memory to be layer_offloading and allow you to set the amount to unload
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
from .manager_modules import LinearLayerMemoryManager, ConvLayerMemoryManager
|
||||
import random
|
||||
|
||||
LINEAR_MODULES = [
|
||||
"Linear",
|
||||
@@ -60,7 +61,9 @@ class MemoryManager:
|
||||
return self.module
|
||||
|
||||
@classmethod
|
||||
def attach(cls, module: torch.nn.Module, device: torch.device):
|
||||
def attach(
|
||||
cls, module: torch.nn.Module, device: torch.device, offload_percent: float = 1.0
|
||||
):
|
||||
if hasattr(module, "_memory_manager"):
|
||||
# already attached
|
||||
return
|
||||
@@ -71,17 +74,44 @@ class MemoryManager:
|
||||
module._mm_to = module.to
|
||||
module.to = module._memory_manager.memory_managed_to
|
||||
|
||||
modules_processed = []
|
||||
# attach to all modules
|
||||
for name, sub_module in module.named_modules():
|
||||
for child_name, child_module in sub_module.named_modules():
|
||||
if child_module.__class__.__name__ in LINEAR_MODULES:
|
||||
# linear
|
||||
LinearLayerMemoryManager.attach(
|
||||
child_module, module._memory_manager
|
||||
)
|
||||
elif child_module.__class__.__name__ in CONV_MODULES:
|
||||
# conv
|
||||
ConvLayerMemoryManager.attach(child_module, module._memory_manager)
|
||||
if (
|
||||
child_module.__class__.__name__ in LINEAR_MODULES
|
||||
and child_module not in modules_processed
|
||||
):
|
||||
skip = False
|
||||
if offload_percent < 1.0:
|
||||
# randomly skip some modules
|
||||
if random.random() > offload_percent:
|
||||
skip = True
|
||||
if skip:
|
||||
module._memory_manager.unmanaged_modules.append(child_module)
|
||||
else:
|
||||
# linear
|
||||
LinearLayerMemoryManager.attach(
|
||||
child_module, module._memory_manager
|
||||
)
|
||||
modules_processed.append(child_module)
|
||||
elif (
|
||||
child_module.__class__.__name__ in CONV_MODULES
|
||||
and child_module not in modules_processed
|
||||
):
|
||||
skip = False
|
||||
if offload_percent < 1.0:
|
||||
# randomly skip some modules
|
||||
if random.random() > offload_percent:
|
||||
skip = True
|
||||
if skip:
|
||||
module._memory_manager.unmanaged_modules.append(child_module)
|
||||
else:
|
||||
# conv
|
||||
ConvLayerMemoryManager.attach(
|
||||
child_module, module._memory_manager
|
||||
)
|
||||
modules_processed.append(child_module)
|
||||
elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
|
||||
inc in child_module.__class__.__name__
|
||||
for inc in UNMANAGED_MODULES_INCLUDES
|
||||
|
||||
Reference in New Issue
Block a user