Change auto_memory to layer_offloading and allow setting the fraction of layers to offload

Jaret Burkett
2025-10-10 13:12:32 -06:00
parent 2c2fbf16ea
commit 1bc6dee127
11 changed files with 279 additions and 45 deletions


@@ -626,12 +626,18 @@ class ModelConfig:
         # auto memory management, only for some models
         self.auto_memory = kwargs.get("auto_memory", False)
-        if self.auto_memory and self.qtype == "qfloat8":
-            print(f"Auto memory is not compatible with qfloat8, switching to float8 for model")
+        # auto memory is deprecated, use layer offloading instead
+        if self.auto_memory:
+            print("auto_memory is deprecated, use layer_offloading instead")
+        self.layer_offloading = kwargs.get("layer_offloading", self.auto_memory )
+        if self.layer_offloading and self.qtype == "qfloat8":
             self.qtype = "float8"
-        if self.auto_memory and not self.qtype_te == "qfloat8":
-            print(f"Auto memory is not compatible with qfloat8, switching to float8 for te")
+        if self.layer_offloading and not self.qtype_te == "qfloat8":
             self.qtype_te = "float8"
+        # 0 is off and 1.0 is 100% of the layers
+        self.layer_offloading_transformer_percent = kwargs.get("layer_offloading_transformer_percent", 1.0)
+        self.layer_offloading_text_encoder_percent = kwargs.get("layer_offloading_text_encoder_percent", 1.0)
         # can be used to load the extras like text encoder or vae from here
         # only setup for some models but will prevent having to download the te for
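
For reference, a minimal usage sketch of the new options, assuming ModelConfig accepts them as keyword arguments; only the keys visible in the hunk above are taken from the diff, everything else here is hypothetical:

model_config = ModelConfig(
    # deprecated flag; still honored as the fallback for layer_offloading
    auto_memory=False,
    # new flag that replaces auto_memory
    layer_offloading=True,
    # fraction of eligible layers to offload: 0 is off, 1.0 is all layers
    layer_offloading_transformer_percent=0.5,
    layer_offloading_text_encoder_percent=1.0,
)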


@@ -1,5 +1,6 @@
 import torch
 from .manager_modules import LinearLayerMemoryManager, ConvLayerMemoryManager
+import random
 LINEAR_MODULES = [
     "Linear",
@@ -60,7 +61,9 @@ class MemoryManager:
         return self.module
     @classmethod
-    def attach(cls, module: torch.nn.Module, device: torch.device):
+    def attach(
+        cls, module: torch.nn.Module, device: torch.device, offload_percent: float = 1.0
+    ):
         if hasattr(module, "_memory_manager"):
             # already attached
             return
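
A hedged sketch of calling the widened attach() signature; model and device are placeholders, and only the parameters visible in the hunk above are assumed:

import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# offload_percent=1.0 keeps the previous behavior (manage every eligible layer);
# lower values leave a random subset of layers unmanaged.
MemoryManager.attach(model, device, offload_percent=0.5)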
@@ -71,17 +74,44 @@ class MemoryManager:
         module._mm_to = module.to
         module.to = module._memory_manager.memory_managed_to
+        modules_processed = []
         # attach to all modules
         for name, sub_module in module.named_modules():
             for child_name, child_module in sub_module.named_modules():
-                if child_module.__class__.__name__ in LINEAR_MODULES:
-                    # linear
-                    LinearLayerMemoryManager.attach(
-                        child_module, module._memory_manager
-                    )
-                elif child_module.__class__.__name__ in CONV_MODULES:
-                    # conv
-                    ConvLayerMemoryManager.attach(child_module, module._memory_manager)
+                if (
+                    child_module.__class__.__name__ in LINEAR_MODULES
+                    and child_module not in modules_processed
+                ):
+                    skip = False
+                    if offload_percent < 1.0:
+                        # randomly skip some modules
+                        if random.random() > offload_percent:
+                            skip = True
+                    if skip:
+                        module._memory_manager.unmanaged_modules.append(child_module)
+                    else:
+                        # linear
+                        LinearLayerMemoryManager.attach(
+                            child_module, module._memory_manager
+                        )
+                        modules_processed.append(child_module)
+                elif (
+                    child_module.__class__.__name__ in CONV_MODULES
+                    and child_module not in modules_processed
+                ):
+                    skip = False
+                    if offload_percent < 1.0:
+                        # randomly skip some modules
+                        if random.random() > offload_percent:
+                            skip = True
+                    if skip:
+                        module._memory_manager.unmanaged_modules.append(child_module)
+                    else:
+                        # conv
+                        ConvLayerMemoryManager.attach(
+                            child_module, module._memory_manager
+                        )
+                        modules_processed.append(child_module)
                 elif child_module.__class__.__name__ in UNMANAGED_MODULES or any(
                     inc in child_module.__class__.__name__
                     for inc in UNMANAGED_MODULES_INCLUDES
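
The new modules_processed list guards against attaching to the same child twice (the nested named_modules() loops can visit a module more than once), and the skip rule is a per-layer random draw, so offload_percent sets the expected rather than exact fraction of managed layers. A standalone sketch of that rule, with names not taken from the repo:

import random

def managed_count(num_layers: int, offload_percent: float, seed: int = 0) -> int:
    # Count how many of num_layers would end up memory-managed under the
    # rule above: skip a layer when a uniform draw exceeds offload_percent.
    rng = random.Random(seed)
    managed = 0
    for _ in range(num_layers):
        skip = offload_percent < 1.0 and rng.random() > offload_percent
        if not skip:
            managed += 1
    return managed

print(managed_count(1000, 0.3))  # roughly 300 of 1000 layers get managed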