Added training for Wan2.1. Not finalized yet; wait before using.

This commit is contained in:
Jaret Burkett
2025-03-07 13:53:44 -07:00
parent 4e3bda7c70
commit 391cf80fea
7 changed files with 393 additions and 50 deletions

View File

@@ -310,6 +310,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
refiner_start_at=sample_config.refiner_start_at,
extra_values=sample_config.extra_values,
logger=self.logger,
num_frames=sample_config.num_frames,
fps=sample_config.fps,
**extra_args
))
@@ -909,13 +911,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
raise ValueError("Batch must be provided for consistent noise")
noise = self.get_consistent_noise(latents, batch, dtype=dtype)
else:
# get noise
noise = self.sd.get_latent_noise(
height=latents.shape[2],
width=latents.shape[3],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
if hasattr(self.sd, 'get_latent_noise_from_latents'):
noise = self.sd.get_latent_noise_from_latents(latents).to(self.device_torch, dtype=dtype)
else:
# get noise
noise = self.sd.get_latent_noise(
height=latents.shape[2],
width=latents.shape[3],
batch_size=batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
if self.train_config.random_noise_shift > 0.0:
# get random noise -1 to 1
@@ -929,9 +934,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
noise += noise_shift
# standardize the noise
std = noise.std(dim=(2, 3), keepdim=True)
normalizer = 1 / (std + 1e-6)
noise = noise * normalizer
# shouldn't be needed?
# std = noise.std(dim=(2, 3), keepdim=True)
# normalizer = 1 / (std + 1e-6)
# noise = noise * normalizer
return noise
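Why the new branch exists: Wan2.1 latents are 5D (batch, channels, frames, height, width), so the spatial-only get_latent_noise(height, width, ...) helper cannot build matching noise, and for 5D tensors dims (2, 3) are frames and height rather than the spatial plane, which is presumably why the standardization above is commented out. A minimal sketch of the shape mirroring (illustrative shapes, not the toolkit's actual call path):

import torch

image_latents = torch.randn(2, 16, 64, 64)        # (B, C, H, W): what get_latent_noise covers
video_latents = torch.randn(2, 16, 21, 60, 104)   # (B, C, F, H, W): Wan2.1-style video latents

def get_latent_noise_from_latents(latents: torch.Tensor) -> torch.Tensor:
    return torch.randn_like(latents)              # mirrors whatever shape the latents have

assert get_latent_noise_from_latents(video_latents).shape == video_latents.shape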

View File

@@ -1,7 +1,7 @@
torch==2.5.1
torchvision==0.20.1
safetensors
git+https://github.com/huggingface/diffusers@24c062aaa19f5626d03d058daf8afffa2dfd49f7
git+https://github.com/huggingface/diffusers@363d1ab7e24c5ed6c190abb00df66d9edb74383b
transformers==4.49.0
lycoris-lora==1.8.3
flatten_json
@@ -32,4 +32,4 @@ sentencepiece
huggingface_hub
peft
gradio
python-slugify
python-slugify

View File

@@ -57,6 +57,11 @@ class SampleConfig:
self.refiner_start_at = kwargs.get('refiner_start_at',
0.5) # step to start using refiner on sample if it exists
self.extra_values = kwargs.get('extra_values', [])
self.num_frames = kwargs.get('num_frames', 1)
self.fps: int = kwargs.get('fps', 16)
if self.num_frames > 1 and self.ext not in ['webp']:
print("Changing sample extension to animated webp")
self.ext = 'webp'
class LormModuleSettingsConfig:
@@ -775,6 +780,8 @@ class GenerateImageConfig:
refiner_start_at: float = 0.5, # start at this percentage of a step. 0.0 to 1.0 . 1.0 is the end
extra_values: List[float] = None, # extra values to save with prompt file
logger: Optional[EmptyLogger] = None,
num_frames: int = 1,
fps: int = 15,
):
self.width: int = width
self.height: int = height
@@ -803,6 +810,9 @@ class GenerateImageConfig:
self.extra_kwargs = extra_kwargs if extra_kwargs is not None else {}
self.refiner_start_at = refiner_start_at
self.extra_values = extra_values if extra_values is not None else []
self.num_frames = num_frames
self.fps = fps
# prompt string will override any settings above
self._process_prompt_string()
@@ -869,11 +879,30 @@ class GenerateImageConfig:
# make parent dirs
os.makedirs(self.output_folder, exist_ok=True)
self.set_gen_time()
# TODO save image gen header info for A1111 and us, our seeds probably wont match
image.save(self.get_image_path(count, max_count))
# do prompt file
if self.add_prompt_file:
self.save_prompt_file(count, max_count)
if isinstance(image, list):
# video
if self.num_frames == 1:
raise ValueError(f"Expected 1 img but got a list {len(image)}")
if self.output_ext == 'webp':
# save as animated webp
duration = 1000 // self.fps # Convert fps to milliseconds per frame
image[0].save(
self.get_image_path(count, max_count),
format='WEBP',
append_images=image[1:],
save_all=True,
duration=duration, # Duration per frame in milliseconds
loop=0, # 0 means loop forever
quality=80 # Quality setting (0-100)
)
else:
raise ValueError(f"Unsupported video format {self.output_ext}")
else:
# TODO save image gen header info for A1111 and us, our seeds probably wont match
image.save(self.get_image_path(count, max_count))
# do prompt file
if self.add_prompt_file:
self.save_prompt_file(count, max_count)
def save_prompt_file(self, count: int = 0, max_count=0):
# save prompt file
@@ -972,6 +1001,10 @@ class GenerateImageConfig:
elif flag == 'extra_values':
# split by comma
self.extra_values = [float(val) for val in content.split(',')]
elif flag == 'frames':
self.num_frames = int(content)
elif flag == 'fps':
self.fps = int(content)
def post_process_embeddings(
self,
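Taken together, num_frames and fps flow from SampleConfig into GenerateImageConfig, frame lists are written as an animated WebP, and the frames/fps prompt-string flags override the defaults per prompt. Note that the per-frame duration comes from integer division, so playback speed is only approximately the requested fps. A standalone sketch of the save call using Pillow's animated-WebP support (dummy frames, not the toolkit's save path):

from PIL import Image

fps = 16
duration = 1000 // fps   # 62 ms per frame -> effectively ~16.1 fps, not exactly 16
frames = [Image.new("RGB", (256, 256), (3 * i % 256, 0, 0)) for i in range(81)]

frames[0].save(
    "sample.webp",
    format="WEBP",
    save_all=True,             # write every frame, not just the first
    append_images=frames[1:],  # remaining frames of the clip
    duration=duration,         # per-frame duration in milliseconds
    loop=0,                    # loop forever
    quality=80,
)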

View File

@@ -349,6 +349,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
if self.transformer_only and is_unet and hasattr(root_module, 'transformer_blocks'):
if "transformer_blocks" not in lora_name:
skip = True
if self.transformer_only and is_unet and hasattr(root_module, 'blocks'):
if "blocks" not in lora_name:
skip = True
if (is_linear or is_conv2d) and not skip:
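The added branch exists because WanTransformer3DModel keeps its attention/FFN layers under a blocks ModuleList rather than transformer_blocks, so the existing transformer_only filter would otherwise skip every module. A small sketch of the filtering idea (illustrative names only; the toolkit derives lora_name from the real module tree):

candidate_names = [
    "lora_transformer_patch_embedding",      # outside the blocks -> skipped when transformer_only
    "lora_transformer_blocks_0_attn1_to_q",  # inside a block -> kept
    "lora_transformer_blocks_29_ffn_net_0_proj",
]
transformer_only = True
kept = [n for n in candidate_names if not (transformer_only and "blocks" not in n)]
# kept -> only the "blocks" entries, mirroring the new hasattr(root_module, 'blocks') branch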

View File

@@ -625,6 +625,15 @@ class BaseModel:
)
noise = apply_noise_offset(noise, noise_offset)
return noise
def get_latent_noise_from_latents(
self,
latents: torch.Tensor,
noise_offset=0.0
):
noise = torch.randn_like(latents)
noise = apply_noise_offset(noise, noise_offset)
return noise
def add_noise(
self,

View File

@@ -1,26 +1,76 @@
# WIP, coming soon ish
import torch
import yaml
from toolkit.accelerator import unwrap_model
from toolkit.basic import flush
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
from toolkit.paths import REPOS_ROOT
import sys
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers import AutoencoderKLWan, WanPipeline, WanTransformer3DModel
import os
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial
import weakref
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
from tqdm import tqdm
import yaml
from toolkit.basic import flush
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
import os
import copy
from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch
import torch
from optimum.quanto import freeze, qfloat8, QTensor, qint4
from toolkit.util.quantize import quantize
from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
from typing import TYPE_CHECKING, List
from toolkit.accelerator import unwrap_model
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from torchvision.transforms import Resize, ToPILImage
# for generation only?
scheduler_configUniPC = {
"_class_name": "UniPCMultistepScheduler",
"_diffusers_version": "0.33.0.dev0",
"beta_end": 0.02,
"beta_schedule": "linear",
"beta_start": 0.0001,
"disable_corrector": [],
"dynamic_thresholding_ratio": 0.995,
"final_sigmas_type": "zero",
"flow_shift": 3.0,
"lower_order_final": True,
"num_train_timesteps": 1000,
"predict_x0": True,
"prediction_type": "flow_prediction",
"rescale_betas_zero_snr": False,
"sample_max_value": 1.0,
"solver_order": 2,
"solver_p": None,
"solver_type": "bh2",
"steps_offset": 0,
"thresholding": False,
"timestep_spacing": "linspace",
"trained_betas": None,
"use_beta_sigmas": False,
"use_exponential_sigmas": False,
"use_flow_sigmas": True,
"use_karras_sigmas": False
}
# for training. I think it is right
scheduler_config = {
"num_train_timesteps": 1000,
"shift": 3.0,
"use_dynamic_shifting": False
}
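Both configs are meant to describe the same shifted flow-matching schedule: "flow_shift" in the UniPC sampling config and "shift" in the training config are both 3.0. A short sketch of what that shift does to the sigmas, assuming the standard static shift used by diffusers' flow-match schedulers (illustrative, not the scheduler's internal code):

import torch

shift = 3.0  # matches "shift" above and "flow_shift" in scheduler_configUniPC
sigmas = torch.linspace(1.0, 1.0 / 1000, 1000)         # raw flow sigmas, 1 -> ~0
shifted = shift * sigmas / (1 + (shift - 1) * sigmas)  # shift > 1 spends more steps at high noise
timesteps = shifted * 1000                             # the 0-1000 timestep scale used for training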
class Wan21(BaseModel):
@@ -36,47 +86,286 @@ class Wan21(BaseModel):
super().__init__(device, model_config, dtype,
custom_pipeline, noise_scheduler, **kwargs)
self.is_flow_matching = True
raise NotImplementedError("Wan21 is not implemented yet")
# these must be implemented in child classes
self.is_transformer = True
self.target_lora_modules = ['WanTransformer3DModel']
# cache for holding noise
self.effective_noise = None
# static method to get the scheduler
@staticmethod
def get_train_scheduler():
scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
return scheduler
def load_model(self):
pass
dtype = self.torch_dtype
# todo, will this work with other Wan models?
base_model_path = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
model_path = self.model_config.name_or_path
self.print_and_status_update("Loading Wan2.1 model")
# base_model_path = "black-forest-labs/FLUX.1-schnell"
base_model_path = self.model_config.name_or_path_original
subfolder = 'transformer'
transformer_path = model_path
if os.path.exists(transformer_path):
subfolder = None
transformer_path = os.path.join(transformer_path, 'transformer')
# check if the path is a full checkpoint.
te_folder_path = os.path.join(model_path, 'text_encoder')
# if we have the te, this folder is a full checkpoint, use it as the base
if os.path.exists(te_folder_path):
base_model_path = model_path
self.print_and_status_update("Loading UMT5EncoderModel")
tokenizer = AutoTokenizer.from_pretrained(
base_model_path, subfolder="tokenizer", torch_dtype=dtype)
text_encoder = UMT5EncoderModel.from_pretrained(
base_model_path, subfolder="text_encoder", torch_dtype=dtype)
text_encoder.to(self.device_torch, dtype=dtype)
flush()
if self.model_config.quantize_te:
self.print_and_status_update("Quantizing UMT5EncoderModel")
quantize(text_encoder, weights=qfloat8)
freeze(text_encoder)
flush()
self.print_and_status_update("Loading transformer")
transformer = WanTransformer3DModel.from_pretrained(
transformer_path,
subfolder=subfolder,
torch_dtype=dtype,
)
if self.model_config.split_model_over_gpus:
raise ValueError(
"Splitting model over gpus is not supported for Wan2.1 models")
transformer.to(self.quantize_device, dtype=dtype)
flush()
if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
raise ValueError(
"Assistant LoRA is not supported for Wan2.1 models currently")
if self.model_config.lora_path is not None:
raise ValueError(
"Loading LoRA is not supported for Wan2.1 models currently")
flush()
if self.model_config.quantize:
quantization_args = self.model_config.quantize_kwargs
if 'exclude' not in quantization_args:
quantization_args['exclude'] = []
# patch the state dict method
patch_dequantization_on_save(transformer)
quantization_type = qfloat8
self.print_and_status_update("Quantizing transformer")
quantize(transformer, weights=quantization_type,
**quantization_args)
freeze(transformer)
transformer.to(self.device_torch)
else:
transformer.to(self.device_torch, dtype=dtype)
flush()
scheduler = Wan21.get_train_scheduler()
self.print_and_status_update("Loading VAE")
# todo, the reference example uses float32; check if quality suffers
vae = AutoencoderKLWan.from_pretrained(
base_model_path, subfolder="vae", torch_dtype=dtype)
flush()
self.print_and_status_update("Making pipe")
pipe: WanPipeline = WanPipeline(
scheduler=scheduler,
text_encoder=None,
tokenizer=tokenizer,
vae=vae,
transformer=None,
)
pipe.text_encoder = text_encoder
pipe.transformer = transformer
self.print_and_status_update("Preparing Model")
text_encoder = pipe.text_encoder
tokenizer = pipe.tokenizer
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
text_encoder.to(self.device_torch)
text_encoder.requires_grad_(False)
text_encoder.eval()
pipe.transformer = pipe.transformer.to(self.device_torch)
flush()
self.pipeline = pipe
self.model = transformer
self.vae = vae
self.text_encoder = text_encoder
self.tokenizer = tokenizer
def get_generation_pipeline(self):
# override this in child classes
raise NotImplementedError(
"get_generation_pipeline must be implemented in child classes")
scheduler = UniPCMultistepScheduler(**scheduler_configUniPC)
pipeline = WanPipeline(
vae=self.vae,
transformer=self.unet,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=scheduler,
)
return pipeline
def generate_single_image(
self,
pipeline,
pipeline: WanPipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
# override this in child classes
raise NotImplementedError(
"generate_single_image must be implemented in child classes")
# todo, figure out how to do video
output = pipeline(
prompt_embeds=conditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype),
negative_prompt_embeds=unconditional_embeds.text_embeds.to(
self.device_torch, dtype=self.torch_dtype),
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
guidance_scale=gen_config.guidance_scale,
latents=gen_config.latents,
num_frames=gen_config.num_frames,
generator=generator,
return_dict=False,
output_type="pil",
**extra
)[0]
# shape = [1, frames, channels, height, width]
batch_item = output[0] # list of pil images
if gen_config.num_frames > 1:
return batch_item # return the frames.
else:
# get just the first image
img = batch_item[0]
return img
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
**kwargs
):
raise NotImplementedError(
"get_noise_prediction must be implemented in child classes")
# vae_scale_factor_spatial = 8
# vae_scale_factor_temporal = 4
# num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
# shape = (
# batch_size,
# num_channels_latents, # 16
# num_latent_frames, # 81
# int(height) // self.vae_scale_factor_spatial,
# int(width) // self.vae_scale_factor_spatial,
# )
noise_pred = self.model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=text_embeddings.text_embeds,
return_dict=False,
**kwargs
)[0]
return noise_pred
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
raise NotImplementedError(
"get_prompt_embeds must be implemented in child classes")
def get_model_has_grad(self):
raise NotImplementedError(
"get_model_has_grad must be implemented in child classes")
prompt_embeds, _ = self.pipeline.encode_prompt(
prompt,
do_classifier_free_guidance=False,
max_sequence_length=512,
device=self.device_torch,
dtype=self.torch_dtype,
)
return PromptEmbeds(prompt_embeds)
@torch.no_grad()
def encode_images(
self,
image_list: List[torch.Tensor],
device=None,
dtype=None
):
if device is None:
device = self.vae_device_torch
if dtype is None:
dtype = self.vae_torch_dtype
latent_list = []
# Move to vae to device if on cpu
if self.vae.device == 'cpu':
self.vae.to(device)
self.vae.eval()
self.vae.requires_grad_(False)
# move to device and dtype
image_list = [image.to(device, dtype=dtype) for image in image_list]
VAE_SCALE_FACTOR = 8
# resize images if not divisible by 8
for i in range(len(image_list)):
image = image_list[i]
if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0:
image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR,
image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)
images = torch.stack(image_list)
images = images.unsqueeze(2)
latents = self.vae.encode(images).latent_dist.sample()
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = (latents - latents_mean) * latents_std
latents = latents.to(device, dtype=dtype)
return latents
def get_model_has_grad(self):
return self.model.proj_out.weight.requires_grad
def get_te_has_grad(self):
raise NotImplementedError(
"get_te_has_grad must be implemented in child classes")
return self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
def save_model(self, output_path, meta, save_dtype):
# only save the unet
transformer: WanTransformer3DModel = unwrap_model(self.model)
transformer.save_pretrained(
save_directory=os.path.join(output_path, 'transformer'),
safe_serialization=True,
)
meta_path = os.path.join(output_path, 'aitk_meta.yaml')
with open(meta_path, 'w') as f:
yaml.dump(meta, f)
def get_loss_target(self, *args, **kwargs):
noise = kwargs.get('noise')
batch = kwargs.get('batch')
if batch is None:
raise ValueError("Batch is not provided")
if noise is None:
raise ValueError("Noise is not provided")
return (noise - batch.latents).detach()
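The loss target here is the flow-matching velocity: with x_t = (1 - t) * latents + t * noise, the derivative with respect to t is noise - latents, which is what get_loss_target returns. A minimal sketch of how that target would pair with an MSE loss (illustrative only; the toolkit's actual loss wiring lives elsewhere):

import torch
import torch.nn.functional as F

def flow_matching_loss(model_pred: torch.Tensor, noise: torch.Tensor, latents: torch.Tensor):
    # x_t = (1 - t) * latents + t * noise, so the velocity target is noise - latents
    target = (noise - latents).detach()
    return F.mse_loss(model_pred.float(), target.float())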

View File

@@ -137,6 +137,8 @@ def match_noise_to_target_mean_offset(noise, target, mix=0.5, dim=None):
def apply_noise_offset(noise, noise_offset):
if noise_offset is None or (noise_offset < 0.000001 and noise_offset > -0.000001):
return noise
if len(noise.shape) > 4:
raise ValueError("Applying noise offset not supported for video models at this time.")
noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
return noise
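The guard exists because the offset tensor is built with shape (batch, channels, 1, 1), which only right-aligns against 4D image noise; against 5D video noise the broadcast would fail with a confusing error. A small shape sketch (illustrative shapes only):

import torch

noise_4d = torch.randn(2, 16, 64, 64)
offset = 0.05 * torch.randn(noise_4d.shape[0], noise_4d.shape[1], 1, 1)
_ = noise_4d + offset   # image case: the (B, C, 1, 1) offset broadcasts over H and W

noise_5d = torch.randn(2, 16, 21, 60, 104)   # video case: (B, C, F, H, W)
# noise_5d + offset    # (B, C, 1, 1) does not right-align with (B, C, F, H, W),
#                      # so the explicit guard raises a clear ValueError up front instead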