Added initial support for layer offloading wit Wan 2.2 14B models.

This commit is contained in:
Jaret Burkett
2025-10-20 14:54:30 -06:00
parent 8bbaa4e224
commit 76ce757e0c
7 changed files with 93 additions and 50 deletions

View File

@@ -27,6 +27,7 @@ from .wan22_5b_model import (
scheduler_config,
time_text_monkeypatch,
)
from toolkit.memory_management import MemoryManager
from safetensors.torch import load_file, save_file
@@ -288,9 +289,12 @@ class Wan2214bModel(Wan21):
flush()
if not self.model_config.low_vram:
if self.model_config.low_vram:
# quantize on the device
transformer_1.to(self.quantize_device, dtype=dtype)
transformer_1.to('cpu', dtype=dtype)
flush()
else:
transformer_1.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
@@ -315,9 +319,12 @@ class Wan2214bModel(Wan21):
flush()
if not self.model_config.low_vram:
if self.model_config.low_vram:
# quantize on the device
transformer_2.to(self.quantize_device, dtype=dtype)
transformer_2.to('cpu', dtype=dtype)
flush()
else:
transformer_2.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
@@ -331,7 +338,8 @@ class Wan2214bModel(Wan21):
transformer_2.to("cpu")
else:
transformer_2.to(self.device_torch)
layer_offloading_transformer = self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0
# make the combined model
self.print_and_status_update("Creating DualWanTransformer3DModel")
transformer = DualWanTransformer3DModel(
@@ -349,6 +357,21 @@ class Wan2214bModel(Wan21):
quantize_model(self, transformer)
flush()
if layer_offloading_transformer:
MemoryManager.attach(
transformer_1,
self.device_torch,
offload_percent=self.model_config.layer_offloading_transformer_percent,
ignore_modules=[transformer_1.scale_shift_table] + [block.scale_shift_table for block in transformer_1.blocks]
)
MemoryManager.attach(
transformer_2,
self.device_torch,
offload_percent=self.model_config.layer_offloading_transformer_percent,
ignore_modules=[transformer_2.scale_shift_table] + [block.scale_shift_table for block in transformer_2.blocks]
)
return transformer
def get_generation_pipeline(self):
@@ -380,24 +403,6 @@ class Wan2214bModel(Wan21):
def get_base_model_version(self):
return "wan_2.2_14b"
def generate_single_image(
self,
pipeline: Wan22Pipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
return super().generate_single_image(
pipeline=pipeline,
gen_config=gen_config,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
generator=generator,
extra=extra,
)
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,

View File

@@ -197,6 +197,10 @@ class Wan22Pipeline(WanPipeline):
boundary_timestep = None
current_model = self.transformer
if self._aggressive_offload:
# we don't have one loaded yet in aggressive offload mode
current_model = None
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):

View File

@@ -2149,6 +2149,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.torch_profiler is not None:
self.torch_profiler.start()
did_oom = False
loss_dict = None
try:
with self.accelerator.accumulate(self.modules_being_trained):
loss_dict = self.hook_train_loop(batch_list)
@@ -2172,7 +2173,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
print_acc(f"# OOM during training step, skipping batch {self.num_consecutive_oom}/3 #")
print_acc("################################################")
print_acc("")
self.num_consecutive_oom = 0
else:
self.num_consecutive_oom = 0
if self.torch_profiler is not None:
torch.cuda.synchronize() # Make sure all CUDA ops are done
self.torch_profiler.stop()
@@ -2191,25 +2193,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
with torch.no_grad():
# torch.cuda.empty_cache()
# if optimizer has get_lrs method, then use it
if hasattr(optimizer, 'get_avg_learning_rate'):
learning_rate = optimizer.get_avg_learning_rate()
elif hasattr(optimizer, 'get_learning_rates'):
learning_rate = optimizer.get_learning_rates()[0]
elif self.train_config.optimizer.lower().startswith('dadaptation') or \
self.train_config.optimizer.lower().startswith('prodigy'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']
if not did_oom and loss_dict is not None:
if hasattr(optimizer, 'get_avg_learning_rate'):
learning_rate = optimizer.get_avg_learning_rate()
elif hasattr(optimizer, 'get_learning_rates'):
learning_rate = optimizer.get_learning_rates()[0]
elif self.train_config.optimizer.lower().startswith('dadaptation') or \
self.train_config.optimizer.lower().startswith('prodigy'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
)
else:
learning_rate = optimizer.param_groups[0]['lr']
prog_bar_string = f"lr: {learning_rate:.1e}"
for key, value in loss_dict.items():
prog_bar_string += f" {key}: {value:.3e}"
prog_bar_string = f"lr: {learning_rate:.1e}"
for key, value in loss_dict.items():
prog_bar_string += f" {key}: {value:.3e}"
if self.progress_bar is not None:
self.progress_bar.set_postfix_str(prog_bar_string)
if self.progress_bar is not None:
self.progress_bar.set_postfix_str(prog_bar_string)
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
if isinstance(batch, DataLoaderBatchDTO):

View File

@@ -31,7 +31,7 @@ UNMANAGED_MODULES = [
"Conv3d"
]
UNMANAGED_MODULES_INCLUDES = ["RotaryEmbedding", "Norm"]
UNMANAGED_MODULES_INCLUDES = ["RotaryEmbedding", "Norm", "RotaryPosEmbed"]
class MemoryManager:
@@ -47,7 +47,11 @@ class MemoryManager:
def memory_managed_to(self, *args, **kwargs):
# first move all the unmanaged modules
for module in self.unmanaged_modules:
module.to(*args, **kwargs)
if isinstance(module, torch.nn.Parameter):
# Parameter cannot move this way
module.data = module.data.to(*args, **kwargs)
else:
module.to(*args, **kwargs)
# check for a dtype argument
dtype = None
if "dtype" in kwargs:
@@ -63,7 +67,11 @@ class MemoryManager:
@classmethod
def attach(
cls, module: torch.nn.Module, device: torch.device, offload_percent: float = 1.0
cls,
module: torch.nn.Module,
device: torch.device,
offload_percent: float = 1.0,
ignore_modules: list[torch.nn.Module] = []
):
if hasattr(module, "_memory_manager"):
# already attached
@@ -75,7 +83,12 @@ class MemoryManager:
module._mm_to = module.to
module.to = module._memory_manager.memory_managed_to
modules_processed = []
# add ignore modules to unmanaged list
for im in ignore_modules:
module._memory_manager.unmanaged_modules.append(im)
# count ignore modules as processed
modules_processed = [x for x in ignore_modules]
# attach to all modules
for name, sub_module in module.named_modules():
for child_name, child_module in sub_module.named_modules():

View File

@@ -6,6 +6,7 @@ from toolkit.accelerator import unwrap_model
from toolkit.basic import flush
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.memory_management.manager import MemoryManager
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
from transformers import AutoTokenizer, UMT5EncoderModel
@@ -353,9 +354,12 @@ class Wan21(BaseModel):
raise ValueError(
"Splitting model over gpus is not supported for Wan2.1 models")
if not self.model_config.low_vram:
if self.model_config.low_vram:
# quantize on the device
transformer.to(self.quantize_device, dtype=dtype)
transformer.to('cpu', dtype=dtype)
flush()
else:
transformer.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
@@ -373,6 +377,13 @@ class Wan21(BaseModel):
quantize_model(self, transformer)
flush()
if self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0:
MemoryManager.attach(
transformer,
self.device_torch,
offload_percent=self.model_config.layer_offloading_transformer_percent
)
if self.model_config.low_vram:
self.print_and_status_update("Moving transformer to CPU")
transformer.to('cpu')
@@ -423,6 +434,13 @@ class Wan21(BaseModel):
quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
freeze(text_encoder)
flush()
if self.model_config.layer_offloading and self.model_config.layer_offloading_text_encoder_percent > 0:
MemoryManager.attach(
text_encoder,
self.device_torch,
offload_percent=self.model_config.layer_offloading_text_encoder_percent
)
if self.model_config.low_vram:
print("Moving transformer back to GPU")

View File

@@ -226,7 +226,7 @@ export const modelArchs: ModelArch[] = [
],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage'],
additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage', 'model.layer_offloading'],
accuracyRecoveryAdapters: {
// '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors',
'4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors',
@@ -257,7 +257,7 @@ export const modelArchs: ModelArch[] = [
],
},
disableSections: ['network.conv'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage'],
additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage', 'model.layer_offloading'],
accuracyRecoveryAdapters: {
'4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_i2v_torchao_uint4.safetensors',
},

View File

@@ -1 +1 @@
VERSION = "0.7.0"
VERSION = "0.7.1"