mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added initial support for layer offloading wit Wan 2.2 14B models.
This commit is contained in:
@@ -6,6 +6,7 @@ from toolkit.accelerator import unwrap_model
|
||||
from toolkit.basic import flush
|
||||
from toolkit.config_modules import GenerateImageConfig, ModelConfig
|
||||
from toolkit.dequantize import patch_dequantization_on_save
|
||||
from toolkit.memory_management.manager import MemoryManager
|
||||
from toolkit.models.base_model import BaseModel
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
@@ -353,9 +354,12 @@ class Wan21(BaseModel):
|
||||
raise ValueError(
|
||||
"Splitting model over gpus is not supported for Wan2.1 models")
|
||||
|
||||
if not self.model_config.low_vram:
|
||||
if self.model_config.low_vram:
|
||||
# quantize on the device
|
||||
transformer.to(self.quantize_device, dtype=dtype)
|
||||
transformer.to('cpu', dtype=dtype)
|
||||
flush()
|
||||
else:
|
||||
transformer.to(self.device_torch, dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
|
||||
@@ -373,6 +377,13 @@ class Wan21(BaseModel):
|
||||
quantize_model(self, transformer)
|
||||
flush()
|
||||
|
||||
if self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0:
|
||||
MemoryManager.attach(
|
||||
transformer,
|
||||
self.device_torch,
|
||||
offload_percent=self.model_config.layer_offloading_transformer_percent
|
||||
)
|
||||
|
||||
if self.model_config.low_vram:
|
||||
self.print_and_status_update("Moving transformer to CPU")
|
||||
transformer.to('cpu')
|
||||
@@ -423,6 +434,13 @@ class Wan21(BaseModel):
|
||||
quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
|
||||
freeze(text_encoder)
|
||||
flush()
|
||||
|
||||
if self.model_config.layer_offloading and self.model_config.layer_offloading_text_encoder_percent > 0:
|
||||
MemoryManager.attach(
|
||||
text_encoder,
|
||||
self.device_torch,
|
||||
offload_percent=self.model_config.layer_offloading_text_encoder_percent
|
||||
)
|
||||
|
||||
if self.model_config.low_vram:
|
||||
print("Moving transformer back to GPU")
|
||||
|
||||
Reference in New Issue
Block a user