import random

import torch

from .manager_modules import ConvLayerMemoryManager, LinearLayerMemoryManager

# Module class names that get a per-layer memory manager attached.
LINEAR_MODULES = [
    "Linear",
    "LoRACompatibleLinear",
    "QLinear",
]
CONV_MODULES = [
    "Conv2d",
    "LoRACompatibleConv",
    "QConv2d",
]

# Module class names that are left unmanaged and moved as a whole.
UNMANAGED_MODULES = [
    "LayerNorm",
    "BatchNorm1d",
    "BatchNorm2d",
    "BatchNorm3d",
    "GroupNorm",
    "InstanceNorm1d",
    "InstanceNorm2d",
    "InstanceNorm3d",
    "Embedding",
    "EmbeddingBag",
    "RNNBase",
    "LSTM",
    "GRU",
    "RNN",
    "Conv3d",
]

# Substrings that mark a class name as unmanaged (e.g. "RotaryEmbedding", "RMSNorm").
UNMANAGED_MODULES_INCLUDES = ["RotaryEmbedding", "Norm", "RotaryPosEmbed"]


class MemoryManager:
    def __init__(
        self,
        module: torch.nn.Module,
        process_device: torch.device = torch.device("cpu"),
    ):
        self.module: torch.nn.Module = module
        self.process_device: torch.device = process_device
        self.unmanaged_modules: list[torch.nn.Module] = []

    def memory_managed_to(self, *args, **kwargs):
        # First move all the unmanaged modules.
        for module in self.unmanaged_modules:
            if isinstance(module, torch.nn.Parameter):
                # A Parameter cannot be moved this way; move its data instead.
                module.data = module.data.to(*args, **kwargs)
            else:
                module.to(*args, **kwargs)

        # Check for a dtype argument. Only a dtype change is forwarded to the
        # original .to(); device placement of managed layers is handled by
        # their per-layer memory managers.
        dtype = None
        if "dtype" in kwargs:
            dtype = kwargs["dtype"]
        else:
            for arg in args:
                if isinstance(arg, torch.dtype):
                    dtype = arg
                    break

        if dtype is not None:
            return self.module._mm_to(dtype=dtype)
        return self.module

    @classmethod
    def attach(
        cls,
        module: torch.nn.Module,
        device: torch.device,
        offload_percent: float = 1.0,
        ignore_modules: list[torch.nn.Module] | None = None,
    ):
        if hasattr(module, "_memory_manager"):
            # Already attached.
            return
        ignore_modules = ignore_modules or []

        module._memory_manager = cls(module, device)

        # Override the .to() method to handle memory management.
        module._mm_to = module.to
        module.to = module._memory_manager.memory_managed_to

        # Add ignore modules to the unmanaged list.
        for im in ignore_modules:
            module._memory_manager.unmanaged_modules.append(im)

        # Count ignore modules as already processed.
        modules_processed = list(ignore_modules)

        # Attach to all modules.
        for name, sub_module in module.named_modules():
            for child_name, child_module in sub_module.named_modules():
                if (
                    child_module.__class__.__name__ in LINEAR_MODULES
                    and child_module not in modules_processed
                ):
                    skip = False
                    if offload_percent < 1.0:
                        # Randomly skip some modules to hit the target offload percentage.
                        if random.random() > offload_percent:
                            skip = True
                    if skip:
                        module._memory_manager.unmanaged_modules.append(child_module)
                    else:
                        # Linear layer.
                        LinearLayerMemoryManager.attach(
                            child_module, module._memory_manager
                        )
                        # Attach to the ARA as well.
                        if hasattr(child_module, "ara_lora_ref"):
                            ara = child_module.ara_lora_ref()
                            if ara not in modules_processed:
                                MemoryManager.attach(
                                    ara,
                                    device,
                                )
                                modules_processed.append(ara)
                    modules_processed.append(child_module)
                elif (
                    child_module.__class__.__name__ in CONV_MODULES
                    and child_module not in modules_processed
                ):
                    skip = False
                    if offload_percent < 1.0:
                        # Randomly skip some modules to hit the target offload percentage.
                        if random.random() > offload_percent:
                            skip = True
                    if skip:
                        module._memory_manager.unmanaged_modules.append(child_module)
                    else:
                        # Conv layer.
                        ConvLayerMemoryManager.attach(
                            child_module, module._memory_manager
                        )
                        # Attach to the ARA as well.
                        if hasattr(child_module, "ara_lora_ref"):
                            ara = child_module.ara_lora_ref()
                            if ara not in modules_processed:
                                MemoryManager.attach(
                                    ara,
                                    device,
                                )
                                modules_processed.append(ara)
                    modules_processed.append(child_module)
                elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
                    inc in child_module.__class__.__name__
                    for inc in UNMANAGED_MODULES_INCLUDES
                ):
                    # Unmanaged module: moved as a whole by memory_managed_to().
                    module._memory_manager.unmanaged_modules.append(child_module)
                else:
                    continue
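

# Usage sketch: a minimal, hedged example of how MemoryManager.attach is meant
# to be wired up. It assumes the companion LinearLayerMemoryManager /
# ConvLayerMemoryManager classes from .manager_modules handle per-layer device
# placement, and "model" stands in for any torch.nn.Module; build_model() is a
# hypothetical constructor, not part of this module.
#
#     model = build_model()
#     MemoryManager.attach(
#         model,
#         device=torch.device("cuda"),
#         offload_percent=0.5,       # manage roughly half of the eligible layers
#     )
#     model.to(torch.float16)        # routed through memory_managed_to: unmanaged
#                                    # modules are moved, managed layers only
#                                    # receive the dtype change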