diff --git a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py
index b3f4af3f..a32183ce 100644
--- a/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py
+++ b/extensions_built_in/diffusion_models/wan22/wan22_14b_model.py
@@ -27,6 +27,7 @@ from .wan22_5b_model import (
     scheduler_config,
     time_text_monkeypatch,
 )
+from toolkit.memory_management import MemoryManager
 from safetensors.torch import load_file, save_file
 
 
@@ -288,9 +289,12 @@ class Wan2214bModel(Wan21):
 
         flush()
 
-        if not self.model_config.low_vram:
+        if self.model_config.low_vram:
             # quantize on the device
-            transformer_1.to(self.quantize_device, dtype=dtype)
+            transformer_1.to('cpu', dtype=dtype)
+            flush()
+        else:
+            transformer_1.to(self.device_torch, dtype=dtype)
             flush()
 
         if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
@@ -315,9 +319,12 @@ class Wan2214bModel(Wan21):
 
         flush()
 
-        if not self.model_config.low_vram:
+        if self.model_config.low_vram:
             # quantize on the device
-            transformer_2.to(self.quantize_device, dtype=dtype)
+            transformer_2.to('cpu', dtype=dtype)
+            flush()
+        else:
+            transformer_2.to(self.device_torch, dtype=dtype)
             flush()
 
         if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
@@ -331,7 +338,8 @@ class Wan2214bModel(Wan21):
             transformer_2.to("cpu")
         else:
             transformer_2.to(self.device_torch)
-        
+
+        layer_offloading_transformer = self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0
         # make the combined model
         self.print_and_status_update("Creating DualWanTransformer3DModel")
         transformer = DualWanTransformer3DModel(
@@ -349,6 +357,21 @@ class Wan2214bModel(Wan21):
 
             quantize_model(self, transformer)
             flush()
+
+        if layer_offloading_transformer:
+            MemoryManager.attach(
+                transformer_1,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_transformer_percent,
+                ignore_modules=[transformer_1.scale_shift_table] + [block.scale_shift_table for block in transformer_1.blocks]
+            )
+            MemoryManager.attach(
+                transformer_2,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_transformer_percent,
+                ignore_modules=[transformer_2.scale_shift_table] + [block.scale_shift_table for block in transformer_2.blocks]
+            )
+
         return transformer
 
     def get_generation_pipeline(self):
@@ -380,24 +403,6 @@ class Wan2214bModel(Wan21):
     def get_base_model_version(self):
         return "wan_2.2_14b"
 
-    def generate_single_image(
-        self,
-        pipeline: Wan22Pipeline,
-        gen_config: GenerateImageConfig,
-        conditional_embeds: PromptEmbeds,
-        unconditional_embeds: PromptEmbeds,
-        generator: torch.Generator,
-        extra: dict,
-    ):
-        return super().generate_single_image(
-            pipeline=pipeline,
-            gen_config=gen_config,
-            conditional_embeds=conditional_embeds,
-            unconditional_embeds=unconditional_embeds,
-            generator=generator,
-            extra=extra,
-        )
-
     def get_noise_prediction(
         self,
         latent_model_input: torch.Tensor,
diff --git a/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py b/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py
index 8d7db242..d119343c 100644
--- a/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py
+++ b/extensions_built_in/diffusion_models/wan22/wan22_pipeline.py
@@ -197,6 +197,10 @@ class Wan22Pipeline(WanPipeline):
             boundary_timestep = None
 
         current_model = self.transformer
+
+        if self._aggressive_offload:
+            # we don't have one loaded yet in aggressive offload mode
+            current_model = None
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index d376b299..a5c842d4 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -2149,6 +2149,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if self.torch_profiler is not None:
                 self.torch_profiler.start()
             did_oom = False
+            loss_dict = None
             try:
                 with self.accelerator.accumulate(self.modules_being_trained):
                     loss_dict = self.hook_train_loop(batch_list)
@@ -2172,7 +2173,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     print_acc(f"# OOM during training step, skipping batch {self.num_consecutive_oom}/3 #")
                     print_acc("################################################")
                     print_acc("")
-            self.num_consecutive_oom = 0
+            else:
+                self.num_consecutive_oom = 0
             if self.torch_profiler is not None:
                 torch.cuda.synchronize()  # Make sure all CUDA ops are done
                 self.torch_profiler.stop()
@@ -2191,25 +2193,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
             with torch.no_grad():
                 # torch.cuda.empty_cache()
                 # if optimizer has get_lrs method, then use it
-                if hasattr(optimizer, 'get_avg_learning_rate'):
-                    learning_rate = optimizer.get_avg_learning_rate()
-                elif hasattr(optimizer, 'get_learning_rates'):
-                    learning_rate = optimizer.get_learning_rates()[0]
-                elif self.train_config.optimizer.lower().startswith('dadaptation') or \
-                        self.train_config.optimizer.lower().startswith('prodigy'):
-                    learning_rate = (
-                        optimizer.param_groups[0]["d"] *
-                        optimizer.param_groups[0]["lr"]
-                    )
-                else:
-                    learning_rate = optimizer.param_groups[0]['lr']
+                if not did_oom and loss_dict is not None:
+                    if hasattr(optimizer, 'get_avg_learning_rate'):
+                        learning_rate = optimizer.get_avg_learning_rate()
+                    elif hasattr(optimizer, 'get_learning_rates'):
+                        learning_rate = optimizer.get_learning_rates()[0]
+                    elif self.train_config.optimizer.lower().startswith('dadaptation') or \
+                            self.train_config.optimizer.lower().startswith('prodigy'):
+                        learning_rate = (
+                            optimizer.param_groups[0]["d"] *
+                            optimizer.param_groups[0]["lr"]
+                        )
+                    else:
+                        learning_rate = optimizer.param_groups[0]['lr']
 
-                prog_bar_string = f"lr: {learning_rate:.1e}"
-                for key, value in loss_dict.items():
-                    prog_bar_string += f" {key}: {value:.3e}"
+                    prog_bar_string = f"lr: {learning_rate:.1e}"
+                    for key, value in loss_dict.items():
+                        prog_bar_string += f" {key}: {value:.3e}"
 
-                if self.progress_bar is not None:
-                    self.progress_bar.set_postfix_str(prog_bar_string)
+                    if self.progress_bar is not None:
+                        self.progress_bar.set_postfix_str(prog_bar_string)
 
             # if the batch is a DataLoaderBatchDTO, then we need to clean it up
             if isinstance(batch, DataLoaderBatchDTO):
diff --git a/toolkit/memory_management/manager.py b/toolkit/memory_management/manager.py
index 59b63806..871fd405 100644
--- a/toolkit/memory_management/manager.py
+++ b/toolkit/memory_management/manager.py
@@ -31,7 +31,7 @@ UNMANAGED_MODULES = [
     "Conv3d"
 ]
 
-UNMANAGED_MODULES_INCLUDES = ["RotaryEmbedding", "Norm"]
+UNMANAGED_MODULES_INCLUDES = ["RotaryEmbedding", "Norm", "RotaryPosEmbed"]
 
 
 class MemoryManager:
@@ -47,7 +47,11 @@ class MemoryManager:
     def memory_managed_to(self, *args, **kwargs):
         # first move all the unmanaged modules
         for module in self.unmanaged_modules:
-            module.to(*args, **kwargs)
+            if isinstance(module, torch.nn.Parameter):
+                # Parameter cannot move this way
+                module.data = module.data.to(*args, **kwargs)
+            else:
+                module.to(*args, **kwargs)
         # check for a dtype argument
         dtype = None
         if "dtype" in kwargs:
@@ -63,7 +67,11 @@ class MemoryManager:
 
     @classmethod
     def attach(
-        cls, module: torch.nn.Module, device: torch.device, offload_percent: float = 1.0
+        cls,
+        module: torch.nn.Module,
+        device: torch.device,
+        offload_percent: float = 1.0,
+        ignore_modules: list[torch.nn.Module] = []
     ):
         if hasattr(module, "_memory_manager"):
             # already attached
@@ -75,7 +83,12 @@ class MemoryManager:
         module._mm_to = module.to
         module.to = module._memory_manager.memory_managed_to
 
-        modules_processed = []
+        # add ignore modules to unmanaged list
+        for im in ignore_modules:
+            module._memory_manager.unmanaged_modules.append(im)
+
+        # count ignore modules as processed
+        modules_processed = [x for x in ignore_modules]
         # attach to all modules
         for name, sub_module in module.named_modules():
             for child_name, child_module in sub_module.named_modules():
diff --git a/toolkit/models/wan21/wan21.py b/toolkit/models/wan21/wan21.py
index ecdc8f3f..71d3deca 100644
--- a/toolkit/models/wan21/wan21.py
+++ b/toolkit/models/wan21/wan21.py
@@ -6,6 +6,7 @@ from toolkit.accelerator import unwrap_model
 from toolkit.basic import flush
 from toolkit.config_modules import GenerateImageConfig, ModelConfig
 from toolkit.dequantize import patch_dequantization_on_save
+from toolkit.memory_management.manager import MemoryManager
 from toolkit.models.base_model import BaseModel
 from toolkit.prompt_utils import PromptEmbeds
 from transformers import AutoTokenizer, UMT5EncoderModel
@@ -353,9 +354,12 @@ class Wan21(BaseModel):
             raise ValueError(
                 "Splitting model over gpus is not supported for Wan2.1 models")
 
-        if not self.model_config.low_vram:
+        if self.model_config.low_vram:
             # quantize on the device
-            transformer.to(self.quantize_device, dtype=dtype)
+            transformer.to('cpu', dtype=dtype)
+            flush()
+        else:
+            transformer.to(self.device_torch, dtype=dtype)
             flush()
 
         if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
@@ -373,6 +377,13 @@ class Wan21(BaseModel):
             quantize_model(self, transformer)
             flush()
 
+        if self.model_config.layer_offloading and self.model_config.layer_offloading_transformer_percent > 0:
+            MemoryManager.attach(
+                transformer,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_transformer_percent
+            )
+
         if self.model_config.low_vram:
             self.print_and_status_update("Moving transformer to CPU")
             transformer.to('cpu')
@@ -423,6 +434,13 @@ class Wan21(BaseModel):
             quantize(text_encoder, weights=get_qtype(self.model_config.qtype))
             freeze(text_encoder)
             flush()
+
+        if self.model_config.layer_offloading and self.model_config.layer_offloading_text_encoder_percent > 0:
+            MemoryManager.attach(
+                text_encoder,
+                self.device_torch,
+                offload_percent=self.model_config.layer_offloading_text_encoder_percent
+            )
 
         if self.model_config.low_vram:
             print("Moving transformer back to GPU")
diff --git a/ui/src/app/jobs/new/options.ts b/ui/src/app/jobs/new/options.ts
index a5636f23..07596a0a 100644
--- a/ui/src/app/jobs/new/options.ts
+++ b/ui/src/app/jobs/new/options.ts
@@ -226,7 +226,7 @@ export const modelArchs: ModelArch[] = [
       ],
     },
     disableSections: ['network.conv'],
-    additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage'],
+    additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage', 'model.layer_offloading'],
    accuracyRecoveryAdapters: {
      // '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors',
      '4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint4.safetensors',
@@ -257,7 +257,7 @@ export const modelArchs: ModelArch[] = [
      ],
    },
    disableSections: ['network.conv'],
-    additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage'],
+    additionalSections: ['sample.ctrl_img', 'datasets.num_frames', 'model.low_vram', 'model.multistage', 'model.layer_offloading'],
    accuracyRecoveryAdapters: {
      '4 bit with ARA': 'uint4|ostris/accuracy_recovery_adapters/wan22_14b_i2v_torchao_uint4.safetensors',
    },
diff --git a/version.py b/version.py
index 30854d5d..bfbfd0f9 100644
--- a/version.py
+++ b/version.py
@@ -1 +1 @@
-VERSION = "0.7.0"
\ No newline at end of file
+VERSION = "0.7.1"
\ No newline at end of file