From 25341c4613035f98f436ffd3fde7394d57c58f5b Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 7 Mar 2025 17:04:10 -0700 Subject: [PATCH] Got wan 14b training to work on 24GB card. --- toolkit/models/wan21.py | 345 +++++++++++++++++++++++++++++++++------ toolkit/util/quantize.py | 13 +- 2 files changed, 307 insertions(+), 51 deletions(-) diff --git a/toolkit/models/wan21.py b/toolkit/models/wan21.py index 023c38ed..dee07040 100644 --- a/toolkit/models/wan21.py +++ b/toolkit/models/wan21.py @@ -1,4 +1,5 @@ # WIP, coming soon ish +from functools import partial import torch import yaml from toolkit.accelerator import unwrap_model @@ -34,6 +35,13 @@ from typing import TYPE_CHECKING, List from toolkit.accelerator import unwrap_model from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler from torchvision.transforms import Resize, ToPILImage +from tqdm import tqdm + +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE +# from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from typing import Any, Callable, Dict, List, Optional, Union # for generation only? scheduler_configUniPC = { @@ -73,6 +81,199 @@ scheduler_config = { } +class AggressiveWanUnloadPipeline(WanPipeline): + def __call__( + self: WanPipeline, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, + List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], + PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # unload vae and transformer + vae_device = self.vae.device + transformer_device = self.transformer.device + text_encoder_device = self.text_encoder.device + print("Unloading vae") + self.vae.to("cpu") + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # unload text encoder + print("Unloading text encoder") + self.text_encoder.to("cpu") + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to( + transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - \ + num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * \ + (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end( + self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop( + "prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + # unload transformer + # load vae + print("Loading Vae") + self.vae.to(vae_device) + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video( + video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) + + class Wan21(BaseModel): def __init__( self, @@ -118,6 +319,76 @@ class Wan21(BaseModel): if os.path.exists(te_folder_path): base_model_path = model_path + self.print_and_status_update("Loading transformer") + transformer = WanTransformer3DModel.from_pretrained( + transformer_path, + subfolder=subfolder, + torch_dtype=dtype, + ) + + if self.model_config.split_model_over_gpus: + raise ValueError( + "Splitting model over gpus is not supported for Wan2.1 models") + + if not self.model_config.low_vram: + # quantize on the device + transformer.to(self.quantize_device, dtype=dtype) + flush() + + if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: + raise ValueError( + "Assistant LoRA is not supported for Wan2.1 models currently") + + if self.model_config.lora_path is not None: + raise ValueError( + "Loading LoRA is not supported for Wan2.1 models currently") + + flush() + + if self.model_config.quantize: + print("Quantizing Transformer") + quantization_args = self.model_config.quantize_kwargs + if 'exclude' not in quantization_args: + quantization_args['exclude'] = [] + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = qfloat8 + self.print_and_status_update("Quantizing transformer") + if self.model_config.low_vram: + print("Quantizing blocks") + orig_exclude = copy.deepcopy(quantization_args['exclude']) + # quantize each block + idx = 0 + for block in tqdm(transformer.blocks): + block.to(self.device_torch) + quantize(block, weights=quantization_type, + **quantization_args) + freeze(block) + idx += 1 + flush() + + print("Quantizing the rest") + low_vram_exclude = copy.deepcopy(quantization_args['exclude']) + low_vram_exclude.append('blocks.*') + quantization_args['exclude'] = low_vram_exclude + # quantize the rest + transformer.to(self.device_torch) + quantize(transformer, weights=quantization_type, + **quantization_args) + + quantization_args['exclude'] = orig_exclude + else: + # do it in one go + quantize(transformer, weights=quantization_type, + **quantization_args) + freeze(transformer) + # move it to the cpu for now + transformer.to("cpu") + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + self.print_and_status_update("Loading UMT5EncoderModel") tokenizer = AutoTokenizer.from_pretrained( base_model_path, subfolder="tokenizer", torch_dtype=dtype) @@ -133,46 +404,10 @@ class Wan21(BaseModel): freeze(text_encoder) flush() - self.print_and_status_update("Loading transformer") - transformer = WanTransformer3DModel.from_pretrained( - transformer_path, - subfolder=subfolder, - torch_dtype=dtype, - ) - - if self.model_config.split_model_over_gpus: - raise ValueError( - "Splitting model over gpus is not supported for Wan2.1 models") - - transformer.to(self.quantize_device, dtype=dtype) - flush() - - if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: - raise ValueError( - "Assistant LoRA is not supported for Wan2.1 models currently") - - if self.model_config.lora_path is not None: - raise ValueError( - "Loading LoRA is not supported for Wan2.1 models currently") - - flush() - - if self.model_config.quantize: - quantization_args = self.model_config.quantize_kwargs - if 'exclude' not in quantization_args: - quantization_args['exclude'] = [] - # patch the state dict method - patch_dequantization_on_save(transformer) - quantization_type = qfloat8 - self.print_and_status_update("Quantizing transformer") - quantize(transformer, weights=quantization_type, - **quantization_args) - freeze(transformer) + if self.model_config.low_vram: + print("Moving transformer back to GPU") + # we can move it back to the gpu now transformer.to(self.device_torch) - else: - transformer.to(self.device_torch, dtype=dtype) - - flush() scheduler = Wan21.get_train_scheduler() self.print_and_status_update("Loading VAE") @@ -213,13 +448,23 @@ class Wan21(BaseModel): def get_generation_pipeline(self): scheduler = UniPCMultistepScheduler(**scheduler_configUniPC) - pipeline = WanPipeline( - vae=self.vae, - transformer=self.unet, - text_encoder=self.text_encoder, - tokenizer=self.tokenizer, - scheduler=scheduler, - ) + if self.model_config.low_vram: + pipeline = AggressiveWanUnloadPipeline( + vae=self.vae, + transformer=self.model, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=scheduler, + ) + else: + pipeline = WanPipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=scheduler, + ) + return pipeline def generate_single_image( @@ -231,6 +476,8 @@ class Wan21(BaseModel): generator: torch.Generator, extra: dict, ): + # reactivate progress bar since this is slooooow + pipeline.set_progress_bar_config(disable=False) # todo, figure out how to do video output = pipeline( prompt_embeds=conditional_embeds.text_embeds.to( @@ -252,7 +499,7 @@ class Wan21(BaseModel): # shape = [1, frames, channels, height, width] batch_item = output[0] # list of pil images if gen_config.num_frames > 1: - return batch_item # return the frames. + return batch_item # return the frames. else: # get just the first image img = batch_item[0] @@ -328,7 +575,7 @@ class Wan21(BaseModel): images = torch.stack(image_list) images = images.unsqueeze(2) latents = self.vae.encode(images).latent_dist.sample() - + latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) @@ -338,7 +585,7 @@ class Wan21(BaseModel): latents.device, latents.dtype ) latents = (latents - latents_mean) * latents_std - + latents = latents.to(device, dtype=dtype) return latents diff --git a/toolkit/util/quantize.py b/toolkit/util/quantize.py index fd7b3178..9d81856b 100644 --- a/toolkit/util/quantize.py +++ b/toolkit/util/quantize.py @@ -7,6 +7,7 @@ from optimum.quanto.tensor import Optimizer, qtype # the quantize function in quanto had a bug where it was using exclude instead of include +Q_MODULES = ['QLinear', 'QConv2d', 'QEmbedding', 'QBatchNorm2d', 'QLayerNorm', 'QConvTranspose2d', 'QEmbeddingBag'] def quantize( model: torch.nn.Module, @@ -51,5 +52,13 @@ def quantize( continue if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude): continue - _quantize_submodule(model, name, m, weights=weights, - activations=activations, optimizer=optimizer) + try: + # check if m is QLinear or QConv2d + if m.__class__.__name__ in Q_MODULES: + continue + else: + _quantize_submodule(model, name, m, weights=weights, + activations=activations, optimizer=optimizer) + except Exception as e: + print(f"Failed to quantize {name}: {e}") + raise e