Added initial support for layer offloading with Wan 2.2 14B models.

This commit is contained in:
Jaret Burkett
2025-10-20 14:54:30 -06:00
parent 8bbaa4e224
commit 76ce757e0c
7 changed files with 93 additions and 50 deletions

View File

@@ -27,6 +27,7 @@ from .wan22_5b_model import (
scheduler_config,
time_text_monkeypatch,
)
from toolkit.memory_management import MemoryManager
from safetensors.torch import load_file, save_file
@@ -288,9 +289,12 @@ class Wan2214bModel(Wan21):
flush()
if not self.model_config.low_vram:
if self.model_config.low_vram:
# quantize on the device
transformer_1.to(self.quantize_device, dtype=dtype)
transformer_1.to('cpu', dtype=dtype)
flush()
else:
transformer_1.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
@@ -315,9 +319,12 @@ class Wan2214bModel(Wan21):
flush()
if not self.model_config.low_vram:
if self.model_config.low_vram:
# quantize on the device
transformer_2.to(self.quantize_device, dtype=dtype)
transformer_2.to('cpu', dtype=dtype)
flush()
else:
transformer_2.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
@@ -331,7 +338,8 @@ class Wan2214bModel(Wan21):
transformer_2.to("cpu")
else:
transformer_2.to(self.device_torch)
layer_offloading_transformer = self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0
# make the combined model
self.print_and_status_update("Creating DualWanTransformer3DModel")
transformer = DualWanTransformer3DModel(
@@ -349,6 +357,21 @@ class Wan2214bModel(Wan21):
quantize_model(self, transformer)
flush()
if layer_offloading_transformer:
MemoryManager.attach(
transformer_1,
self.device_torch,
offload_percent=self.model_config.layer_offloading_transformer_percent,
ignore_modules=[transformer_1.scale_shift_table] + [block.scale_shift_table for block in transformer_1.blocks]
)
MemoryManager.attach(
transformer_2,
self.device_torch,
offload_percent=self.model_config.layer_offloading_transformer_percent,
ignore_modules=[transformer_2.scale_shift_table] + [block.scale_shift_table for block in transformer_2.blocks]
)
return transformer
def get_generation_pipeline(self):
@@ -380,24 +403,6 @@ class Wan2214bModel(Wan21):
def get_base_model_version(self):
    """Return the identifier string for this base model version."""
    version_tag = "wan_2.2_14b"
    return version_tag
def generate_single_image(
    self,
    pipeline: Wan22Pipeline,
    gen_config: GenerateImageConfig,
    conditional_embeds: PromptEmbeds,
    unconditional_embeds: PromptEmbeds,
    generator: torch.Generator,
    extra: dict,
):
    """Generate a single image by delegating unchanged to the parent class.

    Pure pass-through override: every argument is forwarded by keyword to
    ``super().generate_single_image`` and the parent's result is returned
    as-is. NOTE(review): this adds no behavior over the parent
    implementation — presumably kept as an explicit extension point or for
    debugging; confirm before relying on it existing.
    """
    return super().generate_single_image(
        pipeline=pipeline,
        gen_config=gen_config,
        conditional_embeds=conditional_embeds,
        unconditional_embeds=unconditional_embeds,
        generator=generator,
        extra=extra,
    )
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,

View File

@@ -197,6 +197,10 @@ class Wan22Pipeline(WanPipeline):
boundary_timestep = None
current_model = self.transformer
if self._aggressive_offload:
# we don't have one loaded yet in aggressive offload mode
current_model = None
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):