mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 10:11:14 +00:00
Added initial support for layer offloading wit Wan 2.2 14B models.
This commit is contained in:
@@ -27,6 +27,7 @@ from .wan22_5b_model import (
|
||||
scheduler_config,
|
||||
time_text_monkeypatch,
|
||||
)
|
||||
from toolkit.memory_management import MemoryManager
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
@@ -288,9 +289,12 @@ class Wan2214bModel(Wan21):
|
||||
|
||||
flush()
|
||||
|
||||
if not self.model_config.low_vram:
|
||||
if self.model_config.low_vram:
|
||||
# quantize on the device
|
||||
transformer_1.to(self.quantize_device, dtype=dtype)
|
||||
transformer_1.to('cpu', dtype=dtype)
|
||||
flush()
|
||||
else:
|
||||
transformer_1.to(self.device_torch, dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
|
||||
@@ -315,9 +319,12 @@ class Wan2214bModel(Wan21):
|
||||
|
||||
flush()
|
||||
|
||||
if not self.model_config.low_vram:
|
||||
if self.model_config.low_vram:
|
||||
# quantize on the device
|
||||
transformer_2.to(self.quantize_device, dtype=dtype)
|
||||
transformer_2.to('cpu', dtype=dtype)
|
||||
flush()
|
||||
else:
|
||||
transformer_2.to(self.device_torch, dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
|
||||
@@ -331,7 +338,8 @@ class Wan2214bModel(Wan21):
|
||||
transformer_2.to("cpu")
|
||||
else:
|
||||
transformer_2.to(self.device_torch)
|
||||
|
||||
|
||||
layer_offloading_transformer = self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0
|
||||
# make the combined model
|
||||
self.print_and_status_update("Creating DualWanTransformer3DModel")
|
||||
transformer = DualWanTransformer3DModel(
|
||||
@@ -349,6 +357,21 @@ class Wan2214bModel(Wan21):
|
||||
quantize_model(self, transformer)
|
||||
flush()
|
||||
|
||||
|
||||
if layer_offloading_transformer:
|
||||
MemoryManager.attach(
|
||||
transformer_1,
|
||||
self.device_torch,
|
||||
offload_percent=self.model_config.layer_offloading_transformer_percent,
|
||||
ignore_modules=[transformer_1.scale_shift_table] + [block.scale_shift_table for block in transformer_1.blocks]
|
||||
)
|
||||
MemoryManager.attach(
|
||||
transformer_2,
|
||||
self.device_torch,
|
||||
offload_percent=self.model_config.layer_offloading_transformer_percent,
|
||||
ignore_modules=[transformer_2.scale_shift_table] + [block.scale_shift_table for block in transformer_2.blocks]
|
||||
)
|
||||
|
||||
return transformer
|
||||
|
||||
def get_generation_pipeline(self):
|
||||
@@ -380,24 +403,6 @@ class Wan2214bModel(Wan21):
|
||||
def get_base_model_version(self):
|
||||
return "wan_2.2_14b"
|
||||
|
||||
def generate_single_image(
|
||||
self,
|
||||
pipeline: Wan22Pipeline,
|
||||
gen_config: GenerateImageConfig,
|
||||
conditional_embeds: PromptEmbeds,
|
||||
unconditional_embeds: PromptEmbeds,
|
||||
generator: torch.Generator,
|
||||
extra: dict,
|
||||
):
|
||||
return super().generate_single_image(
|
||||
pipeline=pipeline,
|
||||
gen_config=gen_config,
|
||||
conditional_embeds=conditional_embeds,
|
||||
unconditional_embeds=unconditional_embeds,
|
||||
generator=generator,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
def get_noise_prediction(
|
||||
self,
|
||||
latent_model_input: torch.Tensor,
|
||||
|
||||
@@ -197,6 +197,10 @@ class Wan22Pipeline(WanPipeline):
|
||||
boundary_timestep = None
|
||||
|
||||
current_model = self.transformer
|
||||
|
||||
if self._aggressive_offload:
|
||||
# we don't have one loaded yet in aggressive offload mode
|
||||
current_model = None
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
|
||||
Reference in New Issue
Block a user