Added initial support for layer offloading wit Wan 2.2 14B models.

This commit is contained in:
Jaret Burkett
2025-10-20 14:54:30 -06:00
parent 8bbaa4e224
commit 76ce757e0c
7 changed files with 93 additions and 50 deletions

View File

@@ -31,7 +31,7 @@ UNMANAGED_MODULES = [
"Conv3d"
]
UNMANAGED_MODULES_INCLUDES = ["RotaryEmbedding", "Norm"]
UNMANAGED_MODULES_INCLUDES = ["RotaryEmbedding", "Norm", "RotaryPosEmbed"]
class MemoryManager:
@@ -47,7 +47,11 @@ class MemoryManager:
def memory_managed_to(self, *args, **kwargs):
# first move all the unmanaged modules
for module in self.unmanaged_modules:
module.to(*args, **kwargs)
if isinstance(module, torch.nn.Parameter):
# Parameter cannot move this way
module.data = module.data.to(*args, **kwargs)
else:
module.to(*args, **kwargs)
# check for a dtype argument
dtype = None
if "dtype" in kwargs:
@@ -63,7 +67,11 @@ class MemoryManager:
@classmethod
def attach(
cls, module: torch.nn.Module, device: torch.device, offload_percent: float = 1.0
cls,
module: torch.nn.Module,
device: torch.device,
offload_percent: float = 1.0,
ignore_modules: list[torch.nn.Module] = []
):
if hasattr(module, "_memory_manager"):
# already attached
@@ -75,7 +83,12 @@ class MemoryManager:
module._mm_to = module.to
module.to = module._memory_manager.memory_managed_to
modules_processed = []
# add ignore modules to unmanaged list
for im in ignore_modules:
module._memory_manager.unmanaged_modules.append(im)
# count ignore modules as processed
modules_processed = [x for x in ignore_modules]
# attach to all modules
for name, sub_module in module.named_modules():
for child_name, child_module in sub_module.named_modules():

View File

@@ -6,6 +6,7 @@ from toolkit.accelerator import unwrap_model
from toolkit.basic import flush
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.memory_management.manager import MemoryManager
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
from transformers import AutoTokenizer, UMT5EncoderModel
@@ -353,9 +354,12 @@ class Wan21(BaseModel):
raise ValueError(
"Splitting model over gpus is not supported for Wan2.1 models")
if not self.model_config.low_vram:
if self.model_config.low_vram:
# quantize on the device
transformer.to(self.quantize_device, dtype=dtype)
transformer.to('cpu', dtype=dtype)
flush()
else:
transformer.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
@@ -373,6 +377,13 @@ class Wan21(BaseModel):
quantize_model(self, transformer)
flush()
if self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0:
MemoryManager.attach(
transformer,
self.device_torch,
offload_percent=self.model_config.layer_offloading_transformer_percent
)
if self.model_config.low_vram:
self.print_and_status_update("Moving transformer to CPU")
transformer.to('cpu')
@@ -423,6 +434,13 @@ class Wan21(BaseModel):
quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
freeze(text_encoder)
flush()
if self.model_config.layer_offloading and self.model_config.layer_offloading_text_encoder_percent > 0:
MemoryManager.attach(
text_encoder,
self.device_torch,
offload_percent=self.model_config.layer_offloading_text_encoder_percent
)
if self.model_config.low_vram:
print("Moving transformer back to GPU")