Files
ai-toolkit/toolkit/memory_management/manager.py
2025-10-05 13:03:26 -06:00

93 lines
2.8 KiB
Python

import torch
from .manager_modules import LinearLayerMemoryManager, ConvLayerMemoryManager
LINEAR_MODULES = [
"Linear",
"LoRACompatibleLinear",
"QLinear",
]
CONV_MODULES = [
"Conv2d",
"LoRACompatibleConv",
"QConv2d",
]
UNMANAGED_MODULES = [
"LayerNorm",
"BatchNorm1d",
"BatchNorm2d",
"BatchNorm3d",
"GroupNorm",
"InstanceNorm1d",
"InstanceNorm2d",
"InstanceNorm3d",
"Embedding",
"EmbeddingBag",
"RNNBase",
"LSTM",
"GRU",
"RNN",
]
UNMANAGED_MODULES_INCLUDES = ["RotaryEmbedding", "Norm"]
class MemoryManager:
def __init__(
self,
module: torch.nn.Module,
process_device: torch.device = torch.device("cpu"),
):
self.module: torch.nn.Module = module
self.process_device: torch.device = process_device
self.unmanaged_modules: list[torch.nn.Module] = []
def memory_managed_to(self, *args, **kwargs):
# first move all the unmanaged modules
for module in self.unmanaged_modules:
module.to(*args, **kwargs)
# check for a dtype argument
dtype = None
if "dtype" in kwargs:
dtype = kwargs["dtype"]
elif len(args) > 0:
for i, arg in enumerate(args):
if isinstance(arg, torch.dtype):
dtype = arg
break
if dtype is not None:
return self.module._mm_to(dtype=dtype)
return self.module
@classmethod
def attach(cls, module: torch.nn.Module, device: torch.device):
if hasattr(module, "_memory_manager"):
# already attached
return
module._memory_manager = cls(module, device)
# override the to method to handle memory management
module._mm_to = module.to
module.to = module._memory_manager.memory_managed_to
# attach to all modules
for name, sub_module in module.named_modules():
for child_name, child_module in sub_module.named_modules():
if child_module.__class__.__name__ in LINEAR_MODULES:
# linear
LinearLayerMemoryManager.attach(
child_module, module._memory_manager
)
elif child_module.__class__.__name__ in CONV_MODULES:
# conv
ConvLayerMemoryManager.attach(child_module, module._memory_manager)
elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
inc in child_module.__class__.__name__
for inc in UNMANAGED_MODULES_INCLUDES
):
# unmanaged
module._memory_manager.unmanaged_modules.append(child_module)
else:
continue