mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Got wan 14b training to work on 24GB card.
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# WIP, coming soon ish
|
||||
from functools import partial
|
||||
import torch
|
||||
import yaml
|
||||
from toolkit.accelerator import unwrap_model
|
||||
@@ -34,6 +35,13 @@ from typing import TYPE_CHECKING, List
|
||||
from toolkit.accelerator import unwrap_model
|
||||
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
|
||||
from torchvision.transforms import Resize, ToPILImage
|
||||
from tqdm import tqdm
|
||||
|
||||
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
|
||||
from diffusers.pipelines.wan.pipeline_wan import XLA_AVAILABLE
|
||||
# from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
# for generation only?
|
||||
scheduler_configUniPC = {
|
||||
@@ -73,6 +81,199 @@ scheduler_config = {
|
||||
}
|
||||
|
||||
|
||||
class AggressiveWanUnloadPipeline(WanPipeline):
|
||||
def __call__(
|
||||
self: WanPipeline,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator,
|
||||
List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None],
|
||||
PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# unload vae and transformer
|
||||
vae_device = self.vae.device
|
||||
transformer_device = self.transformer.device
|
||||
text_encoder_device = self.text_encoder.device
|
||||
print("Unloading vae")
|
||||
self.vae.to("cpu")
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# unload text encoder
|
||||
print("Unloading text encoder")
|
||||
self.text_encoder.to("cpu")
|
||||
|
||||
transformer_dtype = self.transformer.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(
|
||||
transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - \
|
||||
num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * \
|
||||
(noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(
|
||||
self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop(
|
||||
"prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop(
|
||||
"negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
# unload transformer
|
||||
# load vae
|
||||
print("Loading Vae")
|
||||
self.vae.to(vae_device)
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(
|
||||
video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return WanPipelineOutput(frames=video)
|
||||
|
||||
|
||||
class Wan21(BaseModel):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -118,6 +319,76 @@ class Wan21(BaseModel):
|
||||
if os.path.exists(te_folder_path):
|
||||
base_model_path = model_path
|
||||
|
||||
self.print_and_status_update("Loading transformer")
|
||||
transformer = WanTransformer3DModel.from_pretrained(
|
||||
transformer_path,
|
||||
subfolder=subfolder,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
|
||||
if self.model_config.split_model_over_gpus:
|
||||
raise ValueError(
|
||||
"Splitting model over gpus is not supported for Wan2.1 models")
|
||||
|
||||
if not self.model_config.low_vram:
|
||||
# quantize on the device
|
||||
transformer.to(self.quantize_device, dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
|
||||
raise ValueError(
|
||||
"Assistant LoRA is not supported for Wan2.1 models currently")
|
||||
|
||||
if self.model_config.lora_path is not None:
|
||||
raise ValueError(
|
||||
"Loading LoRA is not supported for Wan2.1 models currently")
|
||||
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize:
|
||||
print("Quantizing Transformer")
|
||||
quantization_args = self.model_config.quantize_kwargs
|
||||
if 'exclude' not in quantization_args:
|
||||
quantization_args['exclude'] = []
|
||||
# patch the state dict method
|
||||
patch_dequantization_on_save(transformer)
|
||||
quantization_type = qfloat8
|
||||
self.print_and_status_update("Quantizing transformer")
|
||||
if self.model_config.low_vram:
|
||||
print("Quantizing blocks")
|
||||
orig_exclude = copy.deepcopy(quantization_args['exclude'])
|
||||
# quantize each block
|
||||
idx = 0
|
||||
for block in tqdm(transformer.blocks):
|
||||
block.to(self.device_torch)
|
||||
quantize(block, weights=quantization_type,
|
||||
**quantization_args)
|
||||
freeze(block)
|
||||
idx += 1
|
||||
flush()
|
||||
|
||||
print("Quantizing the rest")
|
||||
low_vram_exclude = copy.deepcopy(quantization_args['exclude'])
|
||||
low_vram_exclude.append('blocks.*')
|
||||
quantization_args['exclude'] = low_vram_exclude
|
||||
# quantize the rest
|
||||
transformer.to(self.device_torch)
|
||||
quantize(transformer, weights=quantization_type,
|
||||
**quantization_args)
|
||||
|
||||
quantization_args['exclude'] = orig_exclude
|
||||
else:
|
||||
# do it in one go
|
||||
quantize(transformer, weights=quantization_type,
|
||||
**quantization_args)
|
||||
freeze(transformer)
|
||||
# move it to the cpu for now
|
||||
transformer.to("cpu")
|
||||
else:
|
||||
transformer.to(self.device_torch, dtype=dtype)
|
||||
|
||||
flush()
|
||||
|
||||
self.print_and_status_update("Loading UMT5EncoderModel")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
base_model_path, subfolder="tokenizer", torch_dtype=dtype)
|
||||
@@ -133,46 +404,10 @@ class Wan21(BaseModel):
|
||||
freeze(text_encoder)
|
||||
flush()
|
||||
|
||||
self.print_and_status_update("Loading transformer")
|
||||
transformer = WanTransformer3DModel.from_pretrained(
|
||||
transformer_path,
|
||||
subfolder=subfolder,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
|
||||
if self.model_config.split_model_over_gpus:
|
||||
raise ValueError(
|
||||
"Splitting model over gpus is not supported for Wan2.1 models")
|
||||
|
||||
transformer.to(self.quantize_device, dtype=dtype)
|
||||
flush()
|
||||
|
||||
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
|
||||
raise ValueError(
|
||||
"Assistant LoRA is not supported for Wan2.1 models currently")
|
||||
|
||||
if self.model_config.lora_path is not None:
|
||||
raise ValueError(
|
||||
"Loading LoRA is not supported for Wan2.1 models currently")
|
||||
|
||||
flush()
|
||||
|
||||
if self.model_config.quantize:
|
||||
quantization_args = self.model_config.quantize_kwargs
|
||||
if 'exclude' not in quantization_args:
|
||||
quantization_args['exclude'] = []
|
||||
# patch the state dict method
|
||||
patch_dequantization_on_save(transformer)
|
||||
quantization_type = qfloat8
|
||||
self.print_and_status_update("Quantizing transformer")
|
||||
quantize(transformer, weights=quantization_type,
|
||||
**quantization_args)
|
||||
freeze(transformer)
|
||||
if self.model_config.low_vram:
|
||||
print("Moving transformer back to GPU")
|
||||
# we can move it back to the gpu now
|
||||
transformer.to(self.device_torch)
|
||||
else:
|
||||
transformer.to(self.device_torch, dtype=dtype)
|
||||
|
||||
flush()
|
||||
|
||||
scheduler = Wan21.get_train_scheduler()
|
||||
self.print_and_status_update("Loading VAE")
|
||||
@@ -213,13 +448,23 @@ class Wan21(BaseModel):
|
||||
|
||||
def get_generation_pipeline(self):
|
||||
scheduler = UniPCMultistepScheduler(**scheduler_configUniPC)
|
||||
pipeline = WanPipeline(
|
||||
vae=self.vae,
|
||||
transformer=self.unet,
|
||||
text_encoder=self.text_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
if self.model_config.low_vram:
|
||||
pipeline = AggressiveWanUnloadPipeline(
|
||||
vae=self.vae,
|
||||
transformer=self.model,
|
||||
text_encoder=self.text_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
else:
|
||||
pipeline = WanPipeline(
|
||||
vae=self.vae,
|
||||
transformer=self.unet,
|
||||
text_encoder=self.text_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
def generate_single_image(
|
||||
@@ -231,6 +476,8 @@ class Wan21(BaseModel):
|
||||
generator: torch.Generator,
|
||||
extra: dict,
|
||||
):
|
||||
# reactivate progress bar since this is slooooow
|
||||
pipeline.set_progress_bar_config(disable=False)
|
||||
# todo, figure out how to do video
|
||||
output = pipeline(
|
||||
prompt_embeds=conditional_embeds.text_embeds.to(
|
||||
@@ -252,7 +499,7 @@ class Wan21(BaseModel):
|
||||
# shape = [1, frames, channels, height, width]
|
||||
batch_item = output[0] # list of pil images
|
||||
if gen_config.num_frames > 1:
|
||||
return batch_item # return the frames.
|
||||
return batch_item # return the frames.
|
||||
else:
|
||||
# get just the first image
|
||||
img = batch_item[0]
|
||||
@@ -328,7 +575,7 @@ class Wan21(BaseModel):
|
||||
images = torch.stack(image_list)
|
||||
images = images.unsqueeze(2)
|
||||
latents = self.vae.encode(images).latent_dist.sample()
|
||||
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
@@ -338,7 +585,7 @@ class Wan21(BaseModel):
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = (latents - latents_mean) * latents_std
|
||||
|
||||
|
||||
latents = latents.to(device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
|
||||
@@ -7,6 +7,7 @@ from optimum.quanto.tensor import Optimizer, qtype
|
||||
|
||||
# the quantize function in quanto had a bug where it was using exclude instead of include
|
||||
|
||||
Q_MODULES = ['QLinear', 'QConv2d', 'QEmbedding', 'QBatchNorm2d', 'QLayerNorm', 'QConvTranspose2d', 'QEmbeddingBag']
|
||||
|
||||
def quantize(
|
||||
model: torch.nn.Module,
|
||||
@@ -51,5 +52,13 @@ def quantize(
|
||||
continue
|
||||
if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
|
||||
continue
|
||||
_quantize_submodule(model, name, m, weights=weights,
|
||||
activations=activations, optimizer=optimizer)
|
||||
try:
|
||||
# check if m is QLinear or QConv2d
|
||||
if m.__class__.__name__ in Q_MODULES:
|
||||
continue
|
||||
else:
|
||||
_quantize_submodule(model, name, m, weights=weights,
|
||||
activations=activations, optimizer=optimizer)
|
||||
except Exception as e:
|
||||
print(f"Failed to quantize {name}: {e}")
|
||||
raise e
|
||||
|
||||
Reference in New Issue
Block a user