Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-05-11 08:20:35 +00:00
Added initial support for layer offloading with Wan 2.2 14B models.
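The change set wires a new MemoryManager into the Wan 2.1/2.2 load paths so that, when `model.layer_offloading` is enabled, a configurable percentage of transformer (and text encoder) layers can live in CPU memory instead of VRAM. The gate it uses is shown below as a minimal, runnable sketch; `OffloadConfig` is a hypothetical stand-in for the relevant ModelConfig fields (the field names match the diff), not the toolkit's real constructor.

from dataclasses import dataclass

@dataclass
class OffloadConfig:  # hypothetical stand-in; field names match the diff
    layer_offloading: bool = False
    layer_offloading_transformer_percent: float = 1.0
    layer_offloading_text_encoder_percent: float = 1.0

cfg = OffloadConfig(layer_offloading=True, layer_offloading_transformer_percent=0.5)

# the same gate the new code applies before MemoryManager.attach(...)
if cfg.layer_offloading and cfg.layer_offloading_transformer_percent > 0:
    print("layer offloading enabled for the transformer")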
@@ -27,6 +27,7 @@ from .wan22_5b_model import (
     scheduler_config,
     time_text_monkeypatch,
 )
+from toolkit.memory_management import MemoryManager
 from safetensors.torch import load_file, save_file
 
 
@@ -288,9 +289,12 @@ class Wan2214bModel(Wan21):
 
         flush()
 
-        if not self.model_config.low_vram:
+        if self.model_config.low_vram:
             # quantize on the device
-            transformer_1.to(self.quantize_device, dtype=dtype)
+            transformer_1.to('cpu', dtype=dtype)
+            flush()
+        else:
+            transformer_1.to(self.device_torch, dtype=dtype)
             flush()
 
         if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
@@ -315,9 +319,12 @@ class Wan2214bModel(Wan21):
 
         flush()
 
-        if not self.model_config.low_vram:
+        if self.model_config.low_vram:
             # quantize on the device
-            transformer_2.to(self.quantize_device, dtype=dtype)
+            transformer_2.to('cpu', dtype=dtype)
+            flush()
+        else:
+            transformer_2.to(self.device_torch, dtype=dtype)
             flush()
 
         if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
@@ -331,7 +338,8 @@ class Wan2214bModel(Wan21):
             transformer_2.to("cpu")
         else:
             transformer_2.to(self.device_torch)
 
+        layer_offloading_transformer = self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0
         # make the combined model
         self.print_and_status_update("Creating DualWanTransformer3DModel")
         transformer = DualWanTransformer3DModel(
@@ -349,6 +357,21 @@ class Wan2214bModel(Wan21):
             quantize_model(self, transformer)
             flush()
 
+        if layer_offloading_transformer:
+            MemoryManager.attach(
+                transformer_1,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_transformer_percent,
+                ignore_modules=[transformer_1.scale_shift_table] + [block.scale_shift_table for block in transformer_1.blocks]
+            )
+            MemoryManager.attach(
+                transformer_2,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_transformer_percent,
+                ignore_modules=[transformer_2.scale_shift_table] + [block.scale_shift_table for block in transformer_2.blocks]
+            )
+
         return transformer
 
     def get_generation_pipeline(self):
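The `ignore_modules` entries passed above are `scale_shift_table` tensors, which are nn.Parameters rather than nn.Modules; the manager moves them eagerly instead of page-managing them, and the Parameter special case added to `memory_managed_to` further down exists for exactly this reason. A toy sketch of the attach pattern, assuming nothing about the real MemoryManager beyond what this diff shows (the `_mm_to` swap and the Parameter handling):

import torch

class ToyMemoryManager:
    def __init__(self, module, unmanaged):
        self.module = module
        self.unmanaged = list(unmanaged)
        module._mm_to = module.to            # keep the original bound .to
        module.to = self.memory_managed_to   # monkeypatch, as the real manager does

    def memory_managed_to(self, *args, **kwargs):
        # eagerly move everything the manager was told not to manage;
        # the real manager handles the remaining layers lazily
        for item in self.unmanaged:
            if isinstance(item, torch.nn.Parameter):
                item.data = item.data.to(*args, **kwargs)  # Parameters need .data
            else:
                item.to(*args, **kwargs)
        return self.module

net = torch.nn.Linear(4, 4)
net.scale_shift_table = torch.nn.Parameter(torch.zeros(2, 4))  # mimic a Wan block
ToyMemoryManager(net, [net.scale_shift_table])
net.to(torch.float16)                        # routed through memory_managed_to
print(net.scale_shift_table.dtype)           # torch.float16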
@@ -380,24 +403,6 @@ class Wan2214bModel(Wan21):
     def get_base_model_version(self):
         return "wan_2.2_14b"
 
-    def generate_single_image(
-        self,
-        pipeline: Wan22Pipeline,
-        gen_config: GenerateImageConfig,
-        conditional_embeds: PromptEmbeds,
-        unconditional_embeds: PromptEmbeds,
-        generator: torch.Generator,
-        extra: dict,
-    ):
-        return super().generate_single_image(
-            pipeline=pipeline,
-            gen_config=gen_config,
-            conditional_embeds=conditional_embeds,
-            unconditional_embeds=unconditional_embeds,
-            generator=generator,
-            extra=extra,
-        )
-
     def get_noise_prediction(
         self,
         latent_model_input: torch.Tensor,
@@ -197,6 +197,10 @@ class Wan22Pipeline(WanPipeline):
         boundary_timestep = None
 
         current_model = self.transformer
 
+        if self._aggressive_offload:
+            # we don't have one loaded yet in aggressive offload mode
+            current_model = None
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
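For context: the Wan 2.2 14B pipeline runs two expert transformers and switches between them at a boundary timestep (the high-noise expert first, the low-noise expert after). With aggressive offloading neither expert is resident yet, so `current_model` starts as None and the denoising loop decides what to load. A standalone sketch of that selection logic with toy values (the boundary of 875 is hypothetical, not the pipeline's real number):

current_model = None                      # nothing loaded yet under aggressive offload
boundary_timestep = 875                   # hypothetical switch point
for t in [999, 950, 900, 600, 200]:       # toy timesteps, high to low noise
    wanted = "transformer_1" if t >= boundary_timestep else "transformer_2"
    if current_model != wanted:
        current_model = wanted            # the real loop would swap weights in/out here
    print(t, current_model)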
@@ -2149,6 +2149,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if self.torch_profiler is not None:
             self.torch_profiler.start()
         did_oom = False
+        loss_dict = None
         try:
             with self.accelerator.accumulate(self.modules_being_trained):
                 loss_dict = self.hook_train_loop(batch_list)
@@ -2172,7 +2173,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
             print_acc(f"# OOM during training step, skipping batch {self.num_consecutive_oom}/3 #")
             print_acc("################################################")
             print_acc("")
-            self.num_consecutive_oom = 0
+        else:
+            self.num_consecutive_oom = 0
         if self.torch_profiler is not None:
             torch.cuda.synchronize() # Make sure all CUDA ops are done
             self.torch_profiler.stop()
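The fix above relies on Python's try/except/else: the else branch runs only when the try body raised nothing, so `num_consecutive_oom` now accumulates across skipped batches and clears only after a clean step, instead of being reset inside the OOM path itself. A standalone illustration:

num_consecutive_oom = 0

def train_step(should_fail):
    if should_fail:
        raise MemoryError("simulated OOM")

for should_fail in [True, True, False, True]:
    try:
        train_step(should_fail)
    except MemoryError:
        num_consecutive_oom += 1
        print(f"skipping batch {num_consecutive_oom}/3")
    else:
        num_consecutive_oom = 0  # runs only when no exception was raised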
@@ -2191,25 +2193,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
         with torch.no_grad():
             # torch.cuda.empty_cache()
             # if optimizer has get_lrs method, then use it
-            if hasattr(optimizer, 'get_avg_learning_rate'):
-                learning_rate = optimizer.get_avg_learning_rate()
-            elif hasattr(optimizer, 'get_learning_rates'):
-                learning_rate = optimizer.get_learning_rates()[0]
-            elif self.train_config.optimizer.lower().startswith('dadaptation') or \
-                    self.train_config.optimizer.lower().startswith('prodigy'):
-                learning_rate = (
-                    optimizer.param_groups[0]["d"] *
-                    optimizer.param_groups[0]["lr"]
-                )
-            else:
-                learning_rate = optimizer.param_groups[0]['lr']
+            if not did_oom and loss_dict is not None:
+                if hasattr(optimizer, 'get_avg_learning_rate'):
+                    learning_rate = optimizer.get_avg_learning_rate()
+                elif hasattr(optimizer, 'get_learning_rates'):
+                    learning_rate = optimizer.get_learning_rates()[0]
+                elif self.train_config.optimizer.lower().startswith('dadaptation') or \
+                        self.train_config.optimizer.lower().startswith('prodigy'):
+                    learning_rate = (
+                        optimizer.param_groups[0]["d"] *
+                        optimizer.param_groups[0]["lr"]
+                    )
+                else:
+                    learning_rate = optimizer.param_groups[0]['lr']
 
-            prog_bar_string = f"lr: {learning_rate:.1e}"
-            for key, value in loss_dict.items():
-                prog_bar_string += f" {key}: {value:.3e}"
+                prog_bar_string = f"lr: {learning_rate:.1e}"
+                for key, value in loss_dict.items():
+                    prog_bar_string += f" {key}: {value:.3e}"
 
-            if self.progress_bar is not None:
-                self.progress_bar.set_postfix_str(prog_bar_string)
+                if self.progress_bar is not None:
+                    self.progress_bar.set_postfix_str(prog_bar_string)
 
         # if the batch is a DataLoaderBatchDTO, then we need to clean it up
         if isinstance(batch, DataLoaderBatchDTO):
@@ -31,7 +31,7 @@ UNMANAGED_MODULES = [
     "Conv3d"
 ]
 
-UNMANAGED_MODULES_INCLUDES = ["RotaryEmbedding", "Norm"]
+UNMANAGED_MODULES_INCLUDES = ["RotaryEmbedding", "Norm", "RotaryPosEmbed"]
 
 
 class MemoryManager:
@@ -47,7 +47,11 @@ class MemoryManager:
     def memory_managed_to(self, *args, **kwargs):
         # first move all the unmanaged modules
         for module in self.unmanaged_modules:
-            module.to(*args, **kwargs)
+            if isinstance(module, torch.nn.Parameter):
+                # Parameter cannot move this way
+                module.data = module.data.to(*args, **kwargs)
+            else:
+                module.to(*args, **kwargs)
         # check for a dtype argument
         dtype = None
         if "dtype" in kwargs:
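The Parameter branch is needed because `Module.to` moves a module's tensors in place, while `Tensor.to` (and therefore `Parameter.to`) returns a converted copy and leaves the original untouched; reassigning `.data` is how the manager moves a bare Parameter. A short illustration:

import torch

p = torch.nn.Parameter(torch.zeros(2))
p.to(torch.float16)                # returns a copy; p itself is unchanged
print(p.dtype)                     # torch.float32
p.data = p.data.to(torch.float16)  # what memory_managed_to now does
print(p.dtype)                     # torch.float16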
@@ -63,7 +67,11 @@ class MemoryManager:
 
     @classmethod
     def attach(
-        cls, module: torch.nn.Module, device: torch.device, offload_percent: float = 1.0
+        cls,
+        module: torch.nn.Module,
+        device: torch.device,
+        offload_percent: float = 1.0,
+        ignore_modules: list[torch.nn.Module] = []
     ):
         if hasattr(module, "_memory_manager"):
             # already attached
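One note on the new signature: `ignore_modules: list[torch.nn.Module] = []` is a mutable default, evaluated once at definition time. The body only reads from it (it appends to `unmanaged_modules`, not to `ignore_modules`), so it is harmless here, but the conventional spelling sidesteps the shared-list pitfall entirely; a hypothetical alternative, not part of this commit:

from typing import Optional
import torch

def attach(module: torch.nn.Module, ignore_modules: Optional[list] = None):
    if ignore_modules is None:
        ignore_modules = []  # fresh list per call, never shared between calls
    ...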
@@ -75,7 +83,12 @@ class MemoryManager:
         module._mm_to = module.to
         module.to = module._memory_manager.memory_managed_to
 
-        modules_processed = []
+        # add ignore modules to unmanaged list
+        for im in ignore_modules:
+            module._memory_manager.unmanaged_modules.append(im)
+
+        # count ignore modules as processed
+        modules_processed = [x for x in ignore_modules]
         # attach to all modules
         for name, sub_module in module.named_modules():
             for child_name, child_module in sub_module.named_modules():
@@ -6,6 +6,7 @@ from toolkit.accelerator import unwrap_model
 from toolkit.basic import flush
 from toolkit.config_modules import GenerateImageConfig, ModelConfig
 from toolkit.dequantize import patch_dequantization_on_save
+from toolkit.memory_management.manager import MemoryManager
 from toolkit.models.base_model import BaseModel
 from toolkit.prompt_utils import PromptEmbeds
 from transformers import AutoTokenizer, UMT5EncoderModel
@@ -353,9 +354,12 @@ class Wan21(BaseModel):
             raise ValueError(
                 "Splitting model over gpus is not supported for Wan2.1 models")
 
-        if not self.model_config.low_vram:
+        if self.model_config.low_vram:
             # quantize on the device
-            transformer.to(self.quantize_device, dtype=dtype)
+            transformer.to('cpu', dtype=dtype)
+            flush()
+        else:
+            transformer.to(self.device_torch, dtype=dtype)
             flush()
 
         if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
@@ -373,6 +377,13 @@ class Wan21(BaseModel):
             quantize_model(self, transformer)
             flush()
 
+        if self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0:
+            MemoryManager.attach(
+                transformer,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_transformer_percent
+            )
+
         if self.model_config.low_vram:
             self.print_and_status_update("Moving transformer to CPU")
             transformer.to('cpu')
@@ -423,6 +434,13 @@ class Wan21(BaseModel):
             quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
             freeze(text_encoder)
             flush()
 
+        if self.model_config.layer_offloading and self.model_config.layer_offloading_text_encoder_percent > 0:
+            MemoryManager.attach(
+                text_encoder,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_text_encoder_percent
+            )
+
         if self.model_config.low_vram:
             print("Moving transformer back to GPU")
@@ -226,7 +226,7 @@ export const modelArchs: ModelArch[] = [
       ],
     },
     disableSections: ['network.conv'],
-    additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage'],
+    additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage', 'model.layer_offloading'],
     accuracyRecoveryAdapters: {
       // '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors',
       '4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors',
@@ -257,7 +257,7 @@ export const modelArchs: ModelArch[] = [
       ],
     },
     disableSections: ['network.conv'],
-    additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage'],
+    additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage', 'model.layer_offloading'],
     accuracyRecoveryAdapters: {
       '4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_i2v_torchao_uint4.safetensors',
     },
@@ -1 +1 @@
-VERSION = "0.7.0"
+VERSION = "0.7.1"