import random

import torch

from .manager_modules import LinearLayerMemoryManager, ConvLayerMemoryManager

LINEAR_MODULES = [
    "Linear",
    "LoRACompatibleLinear",
    "QLinear",
]

CONV_MODULES = [
    "Conv2d",
    "LoRACompatibleConv",
    "QConv2d",
]

UNMANAGED_MODULES = [
    "LayerNorm",
    "BatchNorm1d",
    "BatchNorm2d",
    "BatchNorm3d",
    "GroupNorm",
    "InstanceNorm1d",
    "InstanceNorm2d",
    "InstanceNorm3d",
    "Embedding",
    "EmbeddingBag",
    "RNNBase",
    "LSTM",
    "GRU",
    "RNN",
    "Conv3d",
]

UNMANAGED_MODULES_INCLUDES = ["RotaryEmbedding", "Norm", "RotaryPosEmbed"]
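
# How these lists are used: MemoryManager.attach() wraps children whose class
# name is in LINEAR_MODULES / CONV_MODULES with a layer memory manager, while
# classes listed in UNMANAGED_MODULES (or whose name contains one of the
# UNMANAGED_MODULES_INCLUDES substrings) stay on the normal .to() path.
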
class MemoryManager:
    def __init__(
        self,
        module: torch.nn.Module,
        process_device: torch.device = torch.device("cpu"),
    ):
        self.module: torch.nn.Module = module
        self.process_device: torch.device = process_device
        self.unmanaged_modules: list[torch.nn.Module] = []
    def memory_managed_to(self, *args, **kwargs):
        # first move all the unmanaged modules
        for module in self.unmanaged_modules:
            if isinstance(module, torch.nn.Parameter):
                # a Parameter cannot be moved this way; swap its data instead
                module.data = module.data.to(*args, **kwargs)
            else:
                module.to(*args, **kwargs)
        # check for a dtype argument
        dtype = None
        if "dtype" in kwargs:
            dtype = kwargs["dtype"]
        else:
            for arg in args:
                if isinstance(arg, torch.dtype):
                    dtype = arg
                    break
        if dtype is not None:
            return self.module._mm_to(dtype=dtype)
        return self.module
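    # attach() is the entry point: it swaps the root module's .to() for
    # memory_managed_to(), wraps Linear/Conv-type children with their layer
    # managers, and records normalization/embedding-style layers as unmanaged
    # so a plain .to() still moves them.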
    @classmethod
    def attach(
        cls,
        module: torch.nn.Module,
        device: torch.device,
        offload_percent: float = 1.0,
        ignore_modules: list[torch.nn.Module] | None = None,
    ):
        if hasattr(module, "_memory_manager"):
            # already attached
            return
        # avoid a shared mutable default argument
        if ignore_modules is None:
            ignore_modules = []
        module._memory_manager = cls(module, device)
        # override the to method to handle memory management
        module._mm_to = module.to
        module.to = module._memory_manager.memory_managed_to
        # add ignore modules to unmanaged list
        for im in ignore_modules:
            module._memory_manager.unmanaged_modules.append(im)
        # count ignore modules as processed
        modules_processed = list(ignore_modules)
        # attach to all child modules (named_modules() already walks the whole tree)
        for _, child_module in module.named_modules():
            if (
                child_module.__class__.__name__ in LINEAR_MODULES
                and child_module not in modules_processed
            ):
                skip = False
                if offload_percent < 1.0:
                    # randomly skip some modules
                    if random.random() > offload_percent:
                        skip = True
                if skip:
                    module._memory_manager.unmanaged_modules.append(child_module)
                else:
                    # linear
                    LinearLayerMemoryManager.attach(
                        child_module, module._memory_manager
                    )
                    # attach to ARA as well
                    if hasattr(child_module, "ara_lora_ref"):
                        ara = child_module.ara_lora_ref()
                        if ara not in modules_processed:
                            MemoryManager.attach(
                                ara,
                                device,
                            )
                            modules_processed.append(ara)
                modules_processed.append(child_module)
            elif (
                child_module.__class__.__name__ in CONV_MODULES
                and child_module not in modules_processed
            ):
                skip = False
                if offload_percent < 1.0:
                    # randomly skip some modules
                    if random.random() > offload_percent:
                        skip = True
                if skip:
                    module._memory_manager.unmanaged_modules.append(child_module)
                else:
                    # conv
                    ConvLayerMemoryManager.attach(
                        child_module, module._memory_manager
                    )
                    # attach to ARA as well
                    if hasattr(child_module, "ara_lora_ref"):
                        ara = child_module.ara_lora_ref()
                        if ara not in modules_processed:
                            MemoryManager.attach(
                                ara,
                                device,
                            )
                            modules_processed.append(ara)
                modules_processed.append(child_module)
            elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
                inc in child_module.__class__.__name__
                for inc in UNMANAGED_MODULES_INCLUDES
            ):
                # unmanaged: moved with a plain .to() by memory_managed_to()
                if child_module not in module._memory_manager.unmanaged_modules:
                    module._memory_manager.unmanaged_modules.append(child_module)
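

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of this module).
# `build_model` and `final_layer` below are hypothetical stand-ins; any
# torch.nn.Module containing Linear / Conv2d children would work the same way.
#
#   model = build_model()
#   MemoryManager.attach(
#       model,
#       device=torch.device("cpu"),          # process/offload device
#       offload_percent=0.75,                # ~75% of eligible layers managed
#       ignore_modules=[model.final_layer],  # left on the normal .to() path
#   )
#   # .to() is now memory-managed: unmanaged layers move as usual, while only
#   # the dtype is forwarded to the wrapped module's original .to().
#   model.to(torch.bfloat16)
# ---------------------------------------------------------------------------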