Got Wan 14B training to work on a 24GB card.

Jaret Burkett
2025-03-07 17:04:10 -07:00
parent 391cf80fea
commit 25341c4613
2 changed files with 307 additions and 51 deletions
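
In short: the Wan2.1 loader gains a low-VRAM path that quantizes the 14B transformer to qfloat8 one block at a time, and sampling gets a new AggressiveWanUnloadPipeline that keeps the VAE and text encoder on the CPU while the transformer denoises, reloading the VAE only for decoding. A rough sketch of the flags the new code branches on (illustrative only; in the diff these are read from self.model_config, and the literal dict layout here is an assumption, not the toolkit's config schema):

# Illustrative only: the model_config flags the new code paths check.
# In the toolkit these come from the YAML model config, not a literal dict.
model_config = {
    "low_vram": True,                # block-wise quantization + aggressive CPU offload at sample time
    "quantize": True,                # quantize transformer weights to qfloat8 via optimum.quanto
    "quantize_kwargs": {},           # may carry an 'exclude' list of glob patterns
    "split_model_over_gpus": False,  # raises an error: not supported for Wan2.1
}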


@@ -1,4 +1,5 @@
# WIP, coming soon ish
from functools import partial
import torch
import yaml
from toolkit.accelerator import unwrap_model
@@ -34,6 +35,13 @@ from typing import TYPE_CHECKING, List
from toolkit.accelerator import unwrap_model
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from torchvision.transforms import Resize, ToPILImage
from tqdm import tqdm
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
# from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from typing import Any, Callable, Dict, List, Optional, Union
# for generation only?
scheduler_configUniPC = {
@@ -73,6 +81,199 @@ scheduler_config = {
}
class AggressiveWanUnloadPipeline(WanPipeline):
def __call__(
self: WanPipeline,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
height: int = 480,
width: int = 832,
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator,
List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None],
PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# unload the vae; the text encoder is unloaded after prompt encoding below
vae_device = self.vae.device
transformer_device = self.transformer.device
text_encoder_device = self.text_encoder.device
print("Unloading vae")
self.vae.to("cpu")
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
negative_prompt,
height,
width,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
device = self._execution_device
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
# unload text encoder
print("Unloading text encoder")
self.text_encoder.to("cpu")
transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(
transformer_dtype)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - \
num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * \
(noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(
self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop(
"prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop(
"negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
# reload the vae for decoding
print("Loading Vae")
self.vae.to(vae_device)
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(
video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return WanPipelineOutput(frames=video)
class Wan21(BaseModel):
def __init__(
self,
@@ -118,6 +319,76 @@ class Wan21(BaseModel):
if os.path.exists(te_folder_path):
base_model_path = model_path
self.print_and_status_update("Loading transformer")
transformer = WanTransformer3DModel.from_pretrained(
transformer_path,
subfolder=subfolder,
torch_dtype=dtype,
)
if self.model_config.split_model_over_gpus:
raise ValueError(
"Splitting model over gpus is not supported for Wan2.1 models")
if not self.model_config.low_vram:
# quantize on the device
transformer.to(self.quantize_device, dtype=dtype)
flush()
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
raise ValueError(
"Assistant LoRA is not supported for Wan2.1 models currently")
if self.model_config.lora_path is not None:
raise ValueError(
"Loading LoRA is not supported for Wan2.1 models currently")
flush()
if self.model_config.quantize:
print("Quantizing Transformer")
quantization_args = self.model_config.quantize_kwargs
if 'exclude' not in quantization_args:
quantization_args['exclude'] = []
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = qfloat8
self.print_and_status_update("Quantizing transformer")
if self.model_config.low_vram:
print("Quantizing blocks")
orig_exclude = copy.deepcopy(quantization_args['exclude'])
# quantize each block
idx = 0
for block in tqdm(transformer.blocks):
block.to(self.device_torch)
quantize(block, weights=quantization_type,
**quantization_args)
freeze(block)
idx += 1
flush()
print("Quantizing the rest")
low_vram_exclude = copy.deepcopy(quantization_args['exclude'])
low_vram_exclude.append('blocks.*')
quantization_args['exclude'] = low_vram_exclude
# quantize the rest
transformer.to(self.device_torch)
quantize(transformer, weights=quantization_type,
**quantization_args)
quantization_args['exclude'] = orig_exclude
else:
# do it in one go
quantize(transformer, weights=quantization_type,
**quantization_args)
freeze(transformer)
# move it to the cpu for now
transformer.to("cpu")
else:
transformer.to(self.device_torch, dtype=dtype)
flush()
self.print_and_status_update("Loading UMT5EncoderModel")
tokenizer = AutoTokenizer.from_pretrained(
base_model_path, subfolder="tokenizer", torch_dtype=dtype)
@@ -133,46 +404,10 @@ class Wan21(BaseModel):
freeze(text_encoder)
flush()
self.print_and_status_update("Loading transformer")
transformer = WanTransformer3DModel.from_pretrained(
transformer_path,
subfolder=subfolder,
torch_dtype=dtype,
)
if self.model_config.split_model_over_gpus:
raise ValueError(
"Splitting model over gpus is not supported for Wan2.1 models")
transformer.to(self.quantize_device, dtype=dtype)
flush()
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
raise ValueError(
"Assistant LoRA is not supported for Wan2.1 models currently")
if self.model_config.lora_path is not None:
raise ValueError(
"Loading LoRA is not supported for Wan2.1 models currently")
flush()
if self.model_config.quantize:
quantization_args = self.model_config.quantize_kwargs
if 'exclude' not in quantization_args:
quantization_args['exclude'] = []
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = qfloat8
self.print_and_status_update("Quantizing transformer")
quantize(transformer, weights=quantization_type,
**quantization_args)
freeze(transformer)
if self.model_config.low_vram:
print("Moving transformer back to GPU")
# we can move it back to the gpu now
transformer.to(self.device_torch)
else:
transformer.to(self.device_torch, dtype=dtype)
flush()
scheduler = Wan21.get_train_scheduler()
self.print_and_status_update("Loading VAE")
@@ -213,13 +448,23 @@ class Wan21(BaseModel):
def get_generation_pipeline(self):
scheduler = UniPCMultistepScheduler(**scheduler_configUniPC)
pipeline = WanPipeline(
vae=self.vae,
transformer=self.unet,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=scheduler,
)
if self.model_config.low_vram:
pipeline = AggressiveWanUnloadPipeline(
vae=self.vae,
transformer=self.model,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=scheduler,
)
else:
pipeline = WanPipeline(
vae=self.vae,
transformer=self.unet,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=scheduler,
)
return pipeline
def generate_single_image(
@@ -231,6 +476,8 @@ class Wan21(BaseModel):
generator: torch.Generator,
extra: dict,
):
# re-enable the progress bar since video generation is slow
pipeline.set_progress_bar_config(disable=False)
# todo, figure out how to do video
output = pipeline(
prompt_embeds=conditional_embeds.text_embeds.to(
@@ -252,7 +499,7 @@ class Wan21(BaseModel):
# shape = [1, frames, channels, height, width]
batch_item = output[0] # list of pil images
if gen_config.num_frames > 1:
return batch_item  # return the frames.
else:
# get just the first image
img = batch_item[0]
@@ -328,7 +575,7 @@ class Wan21(BaseModel):
images = torch.stack(image_list)
images = images.unsqueeze(2)
latents = self.vae.encode(images).latent_dist.sample()
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
@@ -338,7 +585,7 @@ class Wan21(BaseModel):
latents.device, latents.dtype
)
latents = (latents - latents_mean) * latents_std
latents = latents.to(device, dtype=dtype)
return latents
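
The key memory saver in the loader above is quantizing the transformer block by block: each block is moved to the GPU, quantized to qfloat8, frozen, and the cache flushed, and only afterwards is the rest of the model quantized with the blocks excluded, so the full-precision 14B weights never have to be resident in VRAM at once. A minimal standalone sketch of that idea (the helper name is made up, and it calls optimum.quanto's quantize/freeze directly rather than the toolkit's patched wrapper):

import gc
import torch
from optimum.quanto import quantize, freeze, qfloat8

def quantize_transformer_blockwise(transformer, device="cuda"):
    # Quantize each block on the GPU individually so the bf16 weights of the
    # whole model never need to sit in VRAM at the same time.
    for block in transformer.blocks:
        block.to(device)
        quantize(block, weights=qfloat8)
        freeze(block)
        gc.collect()
        torch.cuda.empty_cache()
    # Quantize everything outside the blocks, skipping the blocks handled above.
    transformer.to(device)
    quantize(transformer, weights=qfloat8, exclude=["blocks.*"])
    freeze(transformer)
    return transformer

The second changed file, below, patches the toolkit's quantize wrapper so that a two-pass scheme like this does not try to re-quantize modules that are already quantized.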


@@ -7,6 +7,7 @@ from optimum.quanto.tensor import Optimizer, qtype
# the quantize function in quanto had a bug where it was using exclude instead of include
Q_MODULES = ['QLinear', 'QConv2d', 'QEmbedding', 'QBatchNorm2d', 'QLayerNorm', 'QConvTranspose2d', 'QEmbeddingBag']
def quantize(
model: torch.nn.Module,
@@ -51,5 +52,13 @@ def quantize(
continue
if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
continue
_quantize_submodule(model, name, m, weights=weights,
activations=activations, optimizer=optimizer)
try:
# check if m is QLinear or QConv2d
if m.__class__.__name__ in Q_MODULES:
continue
else:
_quantize_submodule(model, name, m, weights=weights,
activations=activations, optimizer=optimizer)
except Exception as e:
print(f"Failed to quantize {name}: {e}")
raise e
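
With the Q_MODULES guard above, the patched quantize is safe to call more than once over overlapping parts of a model: any submodule whose class is already a quanto Q-module (QLinear, QConv2d, ...) is skipped instead of being re-wrapped, which is exactly what the block-then-rest quantization order in the model loader relies on. A small self-contained sketch of the same guard built on stock optimum.quanto (the helper below is illustrative, not the toolkit's function):

import torch
from optimum.quanto import quantize, freeze, qfloat8

# Mirrors the Q_MODULES list added above.
Q_MODULES = ['QLinear', 'QConv2d', 'QEmbedding', 'QBatchNorm2d',
             'QLayerNorm', 'QConvTranspose2d', 'QEmbeddingBag']

def quantize_skipping_quantized(model, **kwargs):
    # Add every already-quantized submodule to the exclude patterns so a
    # second pass over the same model leaves earlier passes untouched.
    already = [name for name, m in model.named_modules()
               if m.__class__.__name__ in Q_MODULES]
    exclude = list(kwargs.pop('exclude', []) or []) + already
    quantize(model, exclude=exclude, **kwargs)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
quantize_skipping_quantized(model, weights=qfloat8, exclude=['1'])  # first pass: only layer 0
quantize_skipping_quantized(model, weights=qfloat8)                 # second pass: layer 0 is skipped, layer 1 is quantized
freeze(model)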