Merge pull request #377 from ostris/wan22_14b

Wan2.2 14B T2I support
Jaret Burkett
2025-08-16 14:25:23 -06:00
committed by GitHub
19 changed files with 852 additions and 49 deletions

View File

@@ -3,7 +3,7 @@ from .hidream import HidreamModel, HidreamE1Model
from .f_light import FLiteModel
from .omnigen2 import OmniGen2Model
from .flux_kontext import FluxKontextModel
from .wan22 import Wan225bModel
from .wan22 import Wan225bModel, Wan2214bModel
from .qwen_image import QwenImageModel
AI_TOOLKIT_MODELS = [
@@ -15,5 +15,6 @@ AI_TOOLKIT_MODELS = [
OmniGen2Model,
FluxKontextModel,
Wan225bModel,
Wan2214bModel,
QwenImageModel,
]

View File

@@ -1 +1,2 @@
from .wan22_5b_model import Wan225bModel
from .wan22_5b_model import Wan225bModel
from .wan22_14b_model import Wan2214bModel

View File

@@ -0,0 +1,540 @@
from functools import partial
import os
from typing import Any, Dict, Optional, Union, List
from typing_extensions import Self
import torch
import yaml
from toolkit.accelerator import unwrap_model
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from PIL import Image
from diffusers import UniPCMultistepScheduler
import torch
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.samplers.custom_flowmatch_sampler import (
CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.util.quantize import quantize_model
from .wan22_pipeline import Wan22Pipeline
from diffusers import WanTransformer3DModel
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from torchvision.transforms import functional as TF
from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline
from .wan22_5b_model import (
scheduler_config,
time_text_monkeypatch,
Wan225bModel,
)
from safetensors.torch import load_file, save_file
boundary_ratio_t2v = 0.875
boundary_ratio_i2v = 0.9
scheduler_configUniPC = {
"_class_name": "UniPCMultistepScheduler",
"_diffusers_version": "0.35.0.dev0",
"beta_end": 0.02,
"beta_schedule": "linear",
"beta_start": 0.0001,
"disable_corrector": [],
"dynamic_thresholding_ratio": 0.995,
"final_sigmas_type": "zero",
"flow_shift": 3.0,
"lower_order_final": True,
"num_train_timesteps": 1000,
"predict_x0": True,
"prediction_type": "flow_prediction",
"rescale_betas_zero_snr": False,
"sample_max_value": 1.0,
"solver_order": 2,
"solver_p": None,
"solver_type": "bh2",
"steps_offset": 0,
"thresholding": False,
"time_shift_type": "exponential",
"timestep_spacing": "linspace",
"trained_betas": None,
"use_beta_sigmas": False,
"use_dynamic_shifting": False,
"use_exponential_sigmas": False,
"use_flow_sigmas": True,
"use_karras_sigmas": False,
}
class DualWanTransformer3DModel(torch.nn.Module):
def __init__(
self,
transformer_1: WanTransformer3DModel,
transformer_2: WanTransformer3DModel,
torch_dtype: Optional[Union[str, torch.dtype]] = None,
device: Optional[Union[str, torch.device]] = None,
boundary_ratio: float = boundary_ratio_t2v,
low_vram: bool = False,
) -> None:
super().__init__()
self.transformer_1: WanTransformer3DModel = transformer_1
self.transformer_2: WanTransformer3DModel = transformer_2
self.torch_dtype: torch.dtype = torch_dtype
self.device_torch: torch.device = device
self.boundary_ratio: float = boundary_ratio
self.boundary: float = self.boundary_ratio * 1000
self.low_vram: bool = low_vram
self._active_transformer_name = "transformer_1" # default to transformer_1
@property
def device(self) -> torch.device:
return self.device_torch
@property
def dtype(self) -> torch.dtype:
return self.torch_dtype
@property
def config(self):
return self.transformer_1.config
@property
def transformer(self) -> WanTransformer3DModel:
return getattr(self, self._active_transformer_name)
def enable_gradient_checkpointing(self):
"""
Enable gradient checkpointing for both transformers.
"""
self.transformer_1.enable_gradient_checkpointing()
self.transformer_2.enable_gradient_checkpointing()
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
# determine whether this is a high-noise or low-noise step by taking the mean of the timesteps.
# timesteps are in the range of 0 to 1000, so we can use a threshold
with torch.no_grad():
if timestep.float().mean().item() >= self.boundary:
t_name = "transformer_1"
else:
t_name = "transformer_2"
# check if we are changing the active transformer; if so and low_vram is enabled,
# we need to swap which transformer is resident in VRAM
# todo swap the loras as well
if t_name != self._active_transformer_name:
if self.low_vram:
getattr(self, self._active_transformer_name).to("cpu")
getattr(self, t_name).to(self.device_torch)
torch.cuda.empty_cache()
self._active_transformer_name = t_name
if self.transformer.device != hidden_states.device:
if self.low_vram:
# move other transformer to cpu
other_tname = (
"transformer_1" if t_name == "transformer_2" else "transformer_2"
)
getattr(self, other_tname).to("cpu")
self.transformer.to(hidden_states.device)
return self.transformer(
hidden_states=hidden_states,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_image=encoder_hidden_states_image,
return_dict=return_dict,
attention_kwargs=attention_kwargs,
)
def to(self, *args, **kwargs) -> Self:
# intentionally a no-op; device placement is handled per-transformer in forward()
return self
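# Illustrative helper (hypothetical, not used anywhere in this file): a minimal
# sketch of the routing forward() performs above. With boundary_ratio_t2v of
# 0.875 the cutoff is 875 on the 0-1000 timestep scale.
def _demo_pick_expert(timestep: torch.Tensor, boundary_ratio: float = boundary_ratio_t2v) -> str:
    boundary = boundary_ratio * 1000  # 875.0 for t2v, 900.0 for i2v
    # high-noise steps (mean timestep >= boundary) go to transformer_1,
    # low-noise steps go to transformer_2
    return "transformer_1" if timestep.float().mean().item() >= boundary else "transformer_2"
# _demo_pick_expert(torch.tensor([950.0, 900.0]))  -> "transformer_1" (high noise)
# _demo_pick_expert(torch.tensor([400.0, 120.0]))  -> "transformer_2" (low noise)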
class Wan2214bModel(Wan225bModel):
arch = "wan22_14b"
_wan_generation_scheduler_config = scheduler_configUniPC
_wan_expand_timesteps = True
_wan_vae_path = "ai-toolkit/wan2.1-vae"
def __init__(
self,
device,
model_config: ModelConfig,
dtype="bf16",
custom_pipeline=None,
noise_scheduler=None,
**kwargs,
):
super().__init__(
device=device,
model_config=model_config,
dtype=dtype,
custom_pipeline=custom_pipeline,
noise_scheduler=noise_scheduler,
**kwargs,
)
# target the wrapper module so LoRAs can attach to both transformers
self.target_lora_modules = ["DualWanTransformer3DModel"]
self._wan_cache = None
self.is_multistage = True
# multistage boundaries split the models up when sampling timesteps.
# for Wan 2.2 14B, timesteps 1000-875 go to transformer 1 and 875-0 go to transformer 2
self.multistage_boundaries: List[float] = [0.875, 0.0]
self.train_high_noise = model_config.model_kwargs.get("train_high_noise", True)
self.train_low_noise = model_config.model_kwargs.get("train_low_noise", True)
self.trainable_multistage_boundaries: List[int] = []
if self.train_high_noise:
self.trainable_multistage_boundaries.append(0)
if self.train_low_noise:
self.trainable_multistage_boundaries.append(1)
if len(self.trainable_multistage_boundaries) == 0:
raise ValueError(
"At least one of train_high_noise or train_low_noise must be True in model.model_kwargs"
)
@property
def max_step_saves_to_keep_multiplier(self):
# the cleanup mechanism checks this to see how many saves to keep.
# when training a split multistage LoRA, set this to 2 so both the high noise and low noise LoRAs count toward saves to keep
if (
self.network is not None
and self.network.network_config.split_multistage_loras
):
return 2
return 1
def load_model(self):
# load the model from the parent's parent (Wan21), skipping the immediate parent Wan225bModel
# super().load_model()
super(Wan225bModel, self).load_model()
# we have to split up the model on the pipeline
self.pipeline.transformer = self.model.transformer_1
self.pipeline.transformer_2 = self.model.transformer_2
# patch the condition embedder
self.model.transformer_1.condition_embedder.forward = partial(
time_text_monkeypatch, self.model.transformer_1.condition_embedder
)
self.model.transformer_2.condition_embedder.forward = partial(
time_text_monkeypatch, self.model.transformer_2.condition_embedder
)
def get_bucket_divisibility(self):
# 16x compression and 2x2 patch size
return 32
def load_wan_transformer(self, transformer_path, subfolder=None):
if self.model_config.split_model_over_gpus:
raise ValueError(
"Splitting model over gpus is not supported for Wan2.2 models"
)
if (
self.model_config.assistant_lora_path is not None
or self.model_config.inference_lora_path is not None
):
raise ValueError(
"Assistant LoRA is not supported for Wan2.2 models currently"
)
if self.model_config.lora_path is not None:
raise ValueError(
"Loading LoRA is not supported for Wan2.2 models currently"
)
# the transformer path can be a local directory ending in /transformer or an HF repo path.
transformer_path_1 = transformer_path
subfolder_1 = subfolder
transformer_path_2 = transformer_path
subfolder_2 = subfolder
if subfolder_2 is None:
# we have a local path, replace it with transformer_2 folder
transformer_path_2 = os.path.join(
os.path.dirname(transformer_path_1), "transformer_2"
)
else:
# we have a hf path, replace it with transformer_2 subfolder
subfolder_2 = "transformer_2"
self.print_and_status_update("Loading transformer 1")
dtype = self.torch_dtype
transformer_1 = WanTransformer3DModel.from_pretrained(
transformer_path_1,
subfolder=subfolder_1,
torch_dtype=dtype,
).to(dtype=dtype)
flush()
if not self.model_config.low_vram:
# quantize on the device
transformer_1.to(self.quantize_device, dtype=dtype)
flush()
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
# todo handle two ARAs
self.print_and_status_update("Quantizing Transformer 1")
quantize_model(self, transformer_1)
flush()
if self.model_config.low_vram:
self.print_and_status_update("Moving transformer 1 to CPU")
transformer_1.to("cpu")
self.print_and_status_update("Loading transformer 2")
dtype = self.torch_dtype
transformer_2 = WanTransformer3DModel.from_pretrained(
transformer_path_2,
subfolder=subfolder_2,
torch_dtype=dtype,
).to(dtype=dtype)
flush()
if not self.model_config.low_vram:
# quantize on the device
transformer_2.to(self.quantize_device, dtype=dtype)
flush()
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
# todo handle two ARAs
self.print_and_status_update("Quantizing Transformer 2")
quantize_model(self, transformer_2)
flush()
if self.model_config.low_vram:
self.print_and_status_update("Moving transformer 2 to CPU")
transformer_2.to("cpu")
# make the combined model
self.print_and_status_update("Creating DualWanTransformer3DModel")
transformer = DualWanTransformer3DModel(
transformer_1=transformer_1,
transformer_2=transformer_2,
torch_dtype=self.torch_dtype,
device=self.device_torch,
boundary_ratio=boundary_ratio_t2v,
low_vram=self.model_config.low_vram,
)
if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is not None:
# apply the accuracy recovery adapter to both transformers
self.print_and_status_update("Applying Accuracy Recovery Adapter to Transformers")
quantize_model(self, transformer)
flush()
return transformer
def get_generation_pipeline(self):
scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
pipeline = Wan22Pipeline(
vae=self.vae,
transformer=self.model.transformer_1,
transformer_2=self.model.transformer_2,
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
scheduler=scheduler,
expand_timesteps=self._wan_expand_timesteps,
device=self.device_torch,
aggressive_offload=self.model_config.low_vram,
# todo detect if it is i2v or t2v
boundary_ratio=boundary_ratio_t2v,
)
# pipeline = pipeline.to(self.device_torch)
return pipeline
# static method to get the scheduler
@staticmethod
def get_train_scheduler():
scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
return scheduler
def get_base_model_version(self):
return "wan_2.2_14b"
def generate_single_image(
self,
pipeline: AggressiveWanUnloadPipeline,
gen_config: GenerateImageConfig,
conditional_embeds: PromptEmbeds,
unconditional_embeds: PromptEmbeds,
generator: torch.Generator,
extra: dict,
):
return super().generate_single_image(
pipeline=pipeline,
gen_config=gen_config,
conditional_embeds=conditional_embeds,
unconditional_embeds=unconditional_embeds,
generator=generator,
extra=extra,
)
def get_noise_prediction(
self,
latent_model_input: torch.Tensor,
timestep: torch.Tensor, # 0 to 1000 scale
text_embeddings: PromptEmbeds,
batch: DataLoaderBatchDTO,
**kwargs,
):
# todo do we need to override this? Adjust timesteps?
return super().get_noise_prediction(
latent_model_input=latent_model_input,
timestep=timestep,
text_embeddings=text_embeddings,
batch=batch,
**kwargs,
)
def get_model_has_grad(self):
return False
def get_te_has_grad(self):
return False
def save_model(self, output_path, meta, save_dtype):
transformer_combo: DualWanTransformer3DModel = unwrap_model(self.model)
transformer_combo.transformer_1.save_pretrained(
save_directory=os.path.join(output_path, "transformer"),
safe_serialization=True,
)
transformer_combo.transformer_2.save_pretrained(
save_directory=os.path.join(output_path, "transformer_2"),
safe_serialization=True,
)
meta_path = os.path.join(output_path, "aitk_meta.yaml")
with open(meta_path, "w") as f:
yaml.dump(meta, f)
def save_lora(
self,
state_dict: Dict[str, torch.Tensor],
output_path: str,
metadata: Optional[Dict[str, Any]] = None,
):
if not self.network.network_config.split_multistage_loras:
# just save as a combo lora
save_file(state_dict, output_path, metadata=metadata)
return
# we need to build out both dictionaries for high and low noise LoRAs
high_noise_lora = {}
low_noise_lora = {}
only_train_high_noise = self.train_high_noise and not self.train_low_noise
only_train_low_noise = self.train_low_noise and not self.train_high_noise
for key in state_dict:
if ".transformer_1." in key or only_train_high_noise:
# this is a high noise LoRA
new_key = key.replace(".transformer_1.", ".")
high_noise_lora[new_key] = state_dict[key]
elif ".transformer_2." in key or only_train_low_noise:
# this is a low noise LoRA
new_key = key.replace(".transformer_2.", ".")
low_noise_lora[new_key] = state_dict[key]
# LoRA filenames are either LORA_MODEL_NAME_000005000.safetensors or LORA_MODEL_NAME.safetensors
if len(high_noise_lora.keys()) > 0:
# save the high noise LoRA
high_noise_lora_path = output_path.replace(
".safetensors", "_high_noise.safetensors"
)
save_file(high_noise_lora, high_noise_lora_path, metadata=metadata)
if len(low_noise_lora.keys()) > 0:
# save the low noise LoRA
low_noise_lora_path = output_path.replace(
".safetensors", "_low_noise.safetensors"
)
save_file(low_noise_lora, low_noise_lora_path, metadata=metadata)
def load_lora(self, file: str):
# if the name doesn't contain high_noise or low_noise, it is a combo LoRA
if (
"_high_noise.safetensors" not in file
and "_low_noise.safetensors" not in file
):
# this is a combined LoRA; we don't need to split it up
sd = load_file(file)
return sd
# we may have been passed the high_noise or the low_noise LoRA path, but we need to load both
high_noise_lora_path = file.replace(
"_low_noise.safetensors", "_high_noise.safetensors"
)
low_noise_lora_path = file.replace(
"_high_noise.safetensors", "_low_noise.safetensors"
)
combined_dict = {}
if os.path.exists(high_noise_lora_path) and self.train_high_noise:
# load the high noise LoRA
high_noise_lora = load_file(high_noise_lora_path)
for key in high_noise_lora:
new_key = key.replace(
"diffusion_model.", "diffusion_model.transformer_1."
)
combined_dict[new_key] = high_noise_lora[key]
if os.path.exists(low_noise_lora_path) and self.train_low_noise:
# load the low noise LoRA
low_noise_lora = load_file(low_noise_lora_path)
for key in low_noise_lora:
new_key = key.replace(
"diffusion_model.", "diffusion_model.transformer_2."
)
combined_dict[new_key] = low_noise_lora[key]
# if we are not training both stages, the keys won't have transformer designations
if not (self.train_high_noise and self.train_low_noise):
new_dict = {}
for key in combined_dict:
if ".transformer_1." in key:
new_key = key.replace(".transformer_1.", ".")
elif ".transformer_2." in key:
new_key = key.replace(".transformer_2.", ".")
else:
new_key = key
new_dict[new_key] = combined_dict[key]
combined_dict = new_dict
return combined_dict
def get_model_to_train(self):
# todo: LoRAs won't load right unless they have transformer_1 or transformer_2 in the key.
# called when setting up the LoRA. We only need to get the model for the stages we want to train.
if self.train_high_noise and self.train_low_noise:
# we are training both stages, return the unified model
return self.model
elif self.train_high_noise:
# we are only training the high noise stage, return transformer_1
return self.model.transformer_1
elif self.train_low_noise:
# we are only training the low noise stage, return transformer_2
return self.model.transformer_2
else:
raise ValueError(
"At least one of train_high_noise or train_low_noise must be True in model.model_kwargs"
)
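A minimal sketch of the LoRA split/merge round trip implemented in save_lora and load_lora above; the key and file names below are illustrative, not taken from the PR.
keys = [
    "diffusion_model.transformer_1.blocks.0.attn.to_q.lora_A.weight",
    "diffusion_model.transformer_2.blocks.0.attn.to_q.lora_A.weight",
]
output_path = "my_lora_000005000.safetensors"  # example file name only
high_noise_keys = [k.replace(".transformer_1.", ".") for k in keys if ".transformer_1." in k]
low_noise_keys = [k.replace(".transformer_2.", ".") for k in keys if ".transformer_2." in k]
print(output_path.replace(".safetensors", "_high_noise.safetensors"), high_noise_keys)
print(output_path.replace(".safetensors", "_low_noise.safetensors"), low_noise_keys)
# load_lora reverses this: keys from the high-noise file are re-prefixed with
# "diffusion_model." -> "diffusion_model.transformer_1.", keys from the low-noise file
# with "diffusion_model." -> "diffusion_model.transformer_2.", and the dicts are merged.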

View File

@@ -12,7 +12,6 @@ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from typing import Any, Callable, Dict, List, Optional, Union
class Wan22Pipeline(WanPipeline):
def __init__(
self,
@@ -52,6 +51,7 @@ class Wan22Pipeline(WanPipeline):
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
guidance_scale_2: Optional[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator,
List[torch.Generator]]] = None,
@@ -77,11 +77,15 @@ class Wan22Pipeline(WanPipeline):
vae_device = self.vae.device
transformer_device = self.transformer.device
text_encoder_device = self.text_encoder.device
device = self.transformer.device
device = self._exec_device
if self._aggressive_offload:
print("Unloading vae")
self.vae.to("cpu")
print("Unloading transformer")
self.transformer.to("cpu")
if self.transformer_2 is not None:
self.transformer_2.to("cpu")
self.text_encoder.to(device)
flush()
@@ -95,9 +99,14 @@ class Wan22Pipeline(WanPipeline):
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
guidance_scale_2
)
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
guidance_scale_2 = guidance_scale
self._guidance_scale = guidance_scale
self._guidance_scale_2 = guidance_scale_2
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
@@ -160,6 +169,13 @@ class Wan22Pipeline(WanPipeline):
num_warmup_steps = len(timesteps) - \
num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
if self.config.boundary_ratio is not None:
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
else:
boundary_timestep = None
current_model = self.transformer
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -167,6 +183,25 @@ class Wan22Pipeline(WanPipeline):
continue
self._current_timestep = t
if boundary_timestep is None or t >= boundary_timestep:
if self._aggressive_offload and current_model != self.transformer:
if self.transformer_2 is not None:
self.transformer_2.to("cpu")
self.transformer.to(device)
# wan2.1 or high-noise stage in wan2.2
current_model = self.transformer
current_guidance_scale = guidance_scale
else:
if self._aggressive_offload and current_model != self.transformer_2:
if self.transformer is not None:
self.transformer.to("cpu")
if self.transformer_2 is not None:
self.transformer_2.to(device)
# low-noise stage in wan2.2
current_model = self.transformer_2
current_guidance_scale = guidance_scale_2
latent_model_input = latents.to(device, transformer_dtype)
if self.config.expand_timesteps:
# seq_len: num_latent_frames * latent_height//2 * latent_width//2
@@ -176,7 +211,7 @@ class Wan22Pipeline(WanPipeline):
else:
timestep = t.expand(latents.shape[0])
noise_pred = self.transformer(
noise_pred = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
@@ -185,14 +220,14 @@ class Wan22Pipeline(WanPipeline):
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
noise_uncond = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * \
noise_pred = noise_uncond + current_guidance_scale * \
(noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
@@ -259,10 +294,8 @@ class Wan22Pipeline(WanPipeline):
# move transformer back to device
if self._aggressive_offload:
print("Moving transformer back to device")
self.transformer.to(self._execution_device)
if self.transformer_2 is not None:
self.transformer_2.to(self._execution_device)
# print("Moving transformer back to device")
# self.transformer.to(self._execution_device)
flush()
if not return_dict:
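For reference, a small numeric sketch of the stage hand-off added above, assuming the default t2v boundary ratio of 0.875 and 1000 training timesteps: steps at or above the boundary run on the high-noise transformer with guidance_scale, the rest on transformer_2 with guidance_scale_2 (which falls back to guidance_scale when unset). The values below are illustrative.
boundary_ratio = 0.875                       # t2v default used in the pipeline
boundary_timestep = boundary_ratio * 1000    # 875.0
guidance_scale, guidance_scale_2 = 5.0, 3.5  # illustrative CFG values
for t in (1000, 900, 875, 600, 50):
    if t >= boundary_timestep:
        print(t, "-> transformer (high noise), cfg =", guidance_scale)
    else:
        print(t, "-> transformer_2 (low noise), cfg =", guidance_scale_2)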

View File

@@ -1862,7 +1862,20 @@ class SDTrainer(BaseSDTrainProcess):
total_loss = None
self.optimizer.zero_grad()
for batch in batch_list:
if self.sd.is_multistage:
# handle multistage switching
if self.steps_this_boundary >= self.train_config.switch_boundary_every:
# iterate until we land on a boundary listed in trainable_multistage_boundaries
while True:
self.steps_this_boundary = 0
self.current_boundary_index += 1
if self.current_boundary_index >= len(self.sd.multistage_boundaries):
self.current_boundary_index = 0
if self.current_boundary_index in self.sd.trainable_multistage_boundaries:
# if this boundary is trainable, we can stop looking
break
loss = self.train_single_accumulation(batch)
self.steps_this_boundary += 1
if total_loss is None:
total_loss = loss
else:
@@ -1907,7 +1920,7 @@ class SDTrainer(BaseSDTrainProcess):
self.adapter.restore_embeddings()
loss_dict = OrderedDict(
{'loss': loss.item()}
{'loss': (total_loss / len(batch_list)).item()}
)
self.end_of_training_loop()
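A standalone sketch (illustrative values, not trainer code) of the boundary rotation added above: with switch_boundary_every = 2 and both stages trainable, the trainer alternates two batches on the high-noise stage and two on the low-noise stage.
multistage_boundaries = [0.875, 0.0]  # wan2.2 14b: high-noise stage, low-noise stage
trainable_boundaries = [0, 1]         # both stages trainable
switch_boundary_every = 2

current_boundary_index, steps_this_boundary = 0, 0
schedule = []
for _batch in range(8):
    if steps_this_boundary >= switch_boundary_every:
        while True:
            steps_this_boundary = 0
            current_boundary_index = (current_boundary_index + 1) % len(multistage_boundaries)
            if current_boundary_index in trainable_boundaries:
                break
    schedule.append(current_boundary_index)
    steps_this_boundary += 1
print(schedule)  # [0, 0, 1, 1, 0, 0, 1, 1]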

View File

@@ -260,6 +260,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
torch.profiler.ProfilerActivity.CUDA,
],
)
self.current_boundary_index = 0
self.steps_this_boundary = 0
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
@@ -437,19 +440,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
# Combine and sort the lists
combined_items = safetensors_files + directories + pt_files
combined_items.sort(key=os.path.getctime)
num_saves_to_keep = self.save_config.max_step_saves_to_keep
if hasattr(self.sd, 'max_step_saves_to_keep_multiplier'):
num_saves_to_keep *= self.sd.max_step_saves_to_keep_multiplier
# Use slicing with a check to avoid 'NoneType' error
safetensors_to_remove = safetensors_files[
:-self.save_config.max_step_saves_to_keep] if safetensors_files else []
pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else []
directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else []
embeddings_to_remove = embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else []
critic_to_remove = critic_items[:-self.save_config.max_step_saves_to_keep] if critic_items else []
:-num_saves_to_keep] if safetensors_files else []
pt_files_to_remove = pt_files[:-num_saves_to_keep] if pt_files else []
directories_to_remove = directories[:-num_saves_to_keep] if directories else []
embeddings_to_remove = embed_files[:-num_saves_to_keep] if embed_files else []
critic_to_remove = critic_items[:-num_saves_to_keep] if critic_items else []
items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove + critic_to_remove
# remove all but the latest max_step_saves_to_keep
# items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
# items_to_remove = combined_items[:-num_saves_to_keep]
# remove duplicates
items_to_remove = list(dict.fromkeys(items_to_remove))
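# Worked example with illustrative numbers: when the 14B LoRA is split, each save
# writes a _high_noise and a _low_noise file and max_step_saves_to_keep_multiplier
# is 2, so cleanup keeps twice as many files to retain the same number of checkpoints.
num_saves_to_keep_example = 4 * 2  # save.max_step_saves_to_keep=4 -> keep 8 files (4 per stage)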
@@ -1166,6 +1174,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.noise_scheduler.set_timesteps(
num_train_timesteps, device=self.device_torch
)
if self.sd.is_multistage:
with self.timer('adjust_multistage_timesteps'):
# get our current sample range
boundaries = [1000] + self.sd.multistage_boundaries
boundary_max, boundary_min = boundaries[self.current_boundary_index], boundaries[self.current_boundary_index + 1]
lo = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_max, device=self.sd.noise_scheduler.timesteps.device), right=False)
hi = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_min, device=self.sd.noise_scheduler.timesteps.device), right=True)
first_idx = lo.item() if hi > lo else 0
last_idx = (hi - 1).item() if hi > lo else 999
min_noise_steps = first_idx
max_noise_steps = last_idx
# clip min/max indices
min_noise_steps = max(min_noise_steps, 0)
max_noise_steps = min(max_noise_steps, num_train_timesteps - 1)
with self.timer('prepare_timesteps_indices'):
content_or_style = self.train_config.content_or_style
@@ -1204,11 +1230,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
0,
self.train_config.num_train_timesteps - 1,
min_noise_steps,
max_noise_steps - 1
max_noise_steps
)
timestep_indices = timestep_indices.long().clamp(
min_noise_steps + 1,
max_noise_steps - 1
min_noise_steps,
max_noise_steps
)
elif content_or_style == 'balanced':
@@ -1221,7 +1247,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.train_config.noise_scheduler == 'flowmatch':
# flowmatch uses indices, so we need to use indices
min_idx = 0
max_idx = max_noise_steps - 1
max_idx = max_noise_steps
timestep_indices = torch.randint(
min_idx,
max_idx,
@@ -1671,7 +1697,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.network = NetworkClass(
text_encoder=text_encoder,
unet=unet,
unet=self.sd.get_model_to_train(),
lora_dim=self.network_config.linear,
multiplier=1.0,
alpha=self.network_config.linear_alpha,

View File

@@ -185,6 +185,9 @@ class NetworkConfig:
self.conv_alpha = 9999999999
# -1 automatically finds the largest factor
self.lokr_factor = kwargs.get('lokr_factor', -1)
# for multi stage models
self.split_multistage_loras = kwargs.get('split_multistage_loras', True)
AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v']
@@ -332,7 +335,7 @@ class TrainConfig:
self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 999)
self.batch_size: int = kwargs.get('batch_size', 1)
self.orig_batch_size: int = self.batch_size
self.dtype: str = kwargs.get('dtype', 'fp32')
@@ -512,6 +515,9 @@ class TrainConfig:
self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
if isinstance(self.guidance_loss_target, tuple):
self.guidance_loss_target = list(self.guidance_loss_target)
# for multi stage models, how often to switch the boundary
self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1)
ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']

View File

@@ -172,6 +172,11 @@ class BaseModel:
self.sample_prompts_cache = None
self.accuracy_recovery_adapter: Union[None, 'LoRASpecialNetwork'] = None
self.is_multistage = False
# a list of multistage boundaries, in descending order starting from train step 1000
self.multistage_boundaries: List[float] = [0.0]
# a list of trainable multistage boundaries
self.trainable_multistage_boundaries: List[int] = [0]
# properties for old arch for backwards compatibility
@property
@@ -1502,3 +1507,7 @@ class BaseModel:
def get_base_model_version(self) -> str:
# override in child classes to get the base model version
return "unknown"
def get_model_to_train(self):
# called to get model to attach LoRAs to. Can be overridden in child classes
return self.unet

View File

@@ -310,6 +310,7 @@ class Wan21(BaseModel):
arch = 'wan21'
_wan_generation_scheduler_config = scheduler_configUniPC
_wan_expand_timesteps = False
_wan_vae_path = None
_comfy_te_file = ['text_encoders/umt5_xxl_fp16.safetensors', 'text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors']
def __init__(
@@ -431,8 +432,14 @@ class Wan21(BaseModel):
scheduler = Wan21.get_train_scheduler()
self.print_and_status_update("Loading VAE")
# todo, example does float 32? check if quality suffers
vae = AutoencoderKLWan.from_pretrained(
vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype)
if self._wan_vae_path is not None:
# load the vae from individual repo
vae = AutoencoderKLWan.from_pretrained(
self._wan_vae_path, torch_dtype=dtype).to(dtype=dtype)
else:
vae = AutoencoderKLWan.from_pretrained(
vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype)
flush()
self.print_and_status_update("Making pipe")

View File

@@ -565,6 +565,13 @@ class ToolkitNetworkMixin:
if metadata is None:
metadata = OrderedDict()
metadata = add_model_hash_to_meta(save_dict, metadata)
# let the model handle the saving
if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'save_lora'):
# call the base model save lora method
self.base_model_ref().save_lora(save_dict, file, metadata)
return
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(save_dict, file, metadata)
@@ -577,12 +584,15 @@ class ToolkitNetworkMixin:
keymap = {} if keymap is None else keymap
if isinstance(file, str):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'load_lora'):
# call the base model load lora method
weights_sd = self.base_model_ref().load_lora(file)
else:
weights_sd = torch.load(file, map_location="cpu")
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
else:
# probably a state dict
weights_sd = file

View File

@@ -211,6 +211,12 @@ class StableDiffusion:
self.sample_prompts_cache = None
self.is_multistage = False
# a list of multistage boundaries, in descending order starting from train step 1000
self.multistage_boundaries: List[float] = [0.0]
# a list of trainable multistage boundaries
self.trainable_multistage_boundaries: List[int] = [0]
# properties for old arch for backwards compatibility
@property
def is_xl(self):
@@ -3123,3 +3129,6 @@ class StableDiffusion:
if self.is_v2:
return 'sd_2.1'
return 'sd_1.5'
def get_model_to_train(self):
return self.unet

View File

@@ -40,10 +40,25 @@ export default function SimpleJob({
const isVideoModel = !!(modelArch?.group === 'video');
let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
const numTopCards = useMemo(() => {
let count = 4; // job settings, model config, target config, save config
if (modelArch?.additionalSections?.includes('model.multistage')) {
count += 1; // add multistage card
}
if (!modelArch?.disableSections?.includes('model.quantize')) {
count += 1; // add quantization card
}
return count;
}, [modelArch]);
if (modelArch?.disableSections?.includes('model.quantize')) {
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
if (numTopCards == 5) {
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
}
if (numTopCards == 6) {
topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6';
}
const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
@@ -91,7 +106,7 @@ export default function SimpleJob({
<>
<form onSubmit={handleSubmit} className="space-y-8">
<div className={topBarClass}>
<Card title="Job Settings">
<Card title="Job">
<TextInput
label="Training Name"
value={jobConfig.config.name}
@@ -124,7 +139,7 @@ export default function SimpleJob({
</Card>
{/* Model Configuration Section */}
<Card title="Model Configuration">
<Card title="Model">
<SelectInput
label="Model Architecture"
value={jobConfig.config.process[0].model.arch}
@@ -239,7 +254,32 @@ export default function SimpleJob({
/>
</Card>
)}
<Card title="Target Configuration">
{modelArch?.additionalSections?.includes('model.multistage') && (
<Card title="Multistage">
<FormGroup label="Stages to Train" docKey={'model.multistage'}>
<Checkbox
label="High Noise"
checked={jobConfig.config.process[0].model.model_kwargs?.train_high_noise || false}
onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.train_high_noise')}
/>
<Checkbox
label="Low Noise"
checked={jobConfig.config.process[0].model.model_kwargs?.train_low_noise || false}
onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.train_low_noise')}
/>
</FormGroup>
<NumberInput
label="Switch Every"
value={jobConfig.config.process[0].train.switch_boundary_every}
onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
placeholder="eg. 1"
docKey={'train.switch_boundary_every'}
min={1}
required
/>
</Card>
)}
<Card title="Target">
<SelectInput
label="Target Type"
value={jobConfig.config.process[0].network?.type ?? 'lora'}
@@ -295,7 +335,7 @@ export default function SimpleJob({
</>
)}
</Card>
<Card title="Save Configuration">
<Card title="Save">
<SelectInput
label="Data Type"
value={jobConfig.config.process[0].save.dtype}
@@ -325,7 +365,7 @@ export default function SimpleJob({
</Card>
</div>
<div>
<Card title="Training Configuration">
<Card title="Training">
<div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6">
<div>
<NumberInput
@@ -645,7 +685,7 @@ export default function SimpleJob({
</Card>
</div>
<div>
<Card title="Sample Configuration">
<Card title="Sample">
<div
className={
isVideoModel

View File

@@ -78,6 +78,7 @@ export const defaultJobConfig: JobConfig = {
diff_output_preservation: false,
diff_output_preservation_multiplier: 1.0,
diff_output_preservation_class: 'person',
switch_boundary_every: 1,
},
model: {
name_or_path: 'ostris/Flex.1-alpha',

View File

@@ -3,7 +3,13 @@ import { GroupedSelectOption, SelectOption } from '@/types';
type Control = 'depth' | 'line' | 'pose' | 'inpaint';
type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
type AdditionalSections = 'datasets.control_path' | 'datasets.do_i2v' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
type AdditionalSections =
| 'datasets.control_path'
| 'datasets.do_i2v'
| 'sample.ctrl_img'
| 'datasets.num_frames'
| 'model.multistage'
| 'model.low_vram';
type ModelGroup = 'image' | 'video';
export interface ModelArch {
@@ -121,7 +127,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
'config.process[0].sample.fps': [16, 1],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames', 'model.low_vram'],
@@ -139,7 +145,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
'config.process[0].sample.fps': [16, 1],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
},
disableSections: ['network.conv'],
@@ -158,7 +164,7 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
'config.process[0].sample.fps': [16, 1],
'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
},
disableSections: ['network.conv'],
@@ -177,11 +183,41 @@ export const modelArchs: ModelArch[] = [
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [15, 1],
'config.process[0].sample.fps': [16, 1],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames', 'model.low_vram'],
},
{
name: 'wan22_14b:t2v',
label: 'Wan 2.2 (14B)',
group: 'video',
isVideoModel: true,
defaults: {
// default updates when [selected, unselected] in the UI
'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16', defaultNameOrPath],
'config.process[0].model.quantize': [true, false],
'config.process[0].model.quantize_te': [true, false],
'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
'config.process[0].sample.num_frames': [41, 1],
'config.process[0].sample.fps': [16, 1],
'config.process[0].model.low_vram': [true, false],
'config.process[0].train.timestep_type': ['linear', 'sigmoid'],
'config.process[0].model.model_kwargs': [
{
train_high_noise: true,
train_low_noise: true,
},
{},
],
},
disableSections: ['network.conv'],
additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage'],
// accuracyRecoveryAdapters: {
// '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors',
// },
},
{
name: 'wan22_5b',
label: 'Wan 2.2 TI2V (5B)',

View File

@@ -27,7 +27,7 @@ export default function SampleImages({ job }: SampleImagesProps) {
// This way Tailwind can properly generate the class
// I hate this, but it's the only way to make it work
const gridColsClass = useMemo(() => {
const cols = Math.min(numSamples, 20);
const cols = Math.min(numSamples, 40);
switch (cols) {
case 1:
@@ -70,6 +70,46 @@ export default function SampleImages({ job }: SampleImagesProps) {
return 'grid-cols-19';
case 20:
return 'grid-cols-20';
case 21:
return 'grid-cols-21';
case 22:
return 'grid-cols-22';
case 23:
return 'grid-cols-23';
case 24:
return 'grid-cols-24';
case 25:
return 'grid-cols-25';
case 26:
return 'grid-cols-26';
case 27:
return 'grid-cols-27';
case 28:
return 'grid-cols-28';
case 29:
return 'grid-cols-29';
case 30:
return 'grid-cols-30';
case 31:
return 'grid-cols-31';
case 32:
return 'grid-cols-32';
case 33:
return 'grid-cols-33';
case 34:
return 'grid-cols-34';
case 35:
return 'grid-cols-35';
case 36:
return 'grid-cols-36';
case 37:
return 'grid-cols-37';
case 38:
return 'grid-cols-38';
case 39:
return 'grid-cols-39';
case 40:
return 'grid-cols-40';
default:
return 'grid-cols-1';
}

View File

@@ -283,7 +283,7 @@ export const FormGroup: React.FC<FormGroupProps> = props => {
return (
<div className={classNames(className)}>
{label && (
<label className={labelClasses}>
<label className={classNames(labelClasses, 'mb-2')}>
{label}{' '}
{doc && (
<div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
@@ -292,7 +292,7 @@ export const FormGroup: React.FC<FormGroupProps> = props => {
)}
</label>
)}
<div className="px-4 space-y-2">{children}</div>
<div className="space-y-2">{children}</div>
</div>
);
};

View File

@@ -111,6 +111,36 @@ const docs: { [key: string]: ConfigDoc } = {
</>
),
},
'model.multistage': {
title: 'Stages to Train',
description: (
<>
Some models use multi-stage networks that are trained and applied separately during the denoising process.
Most commonly there are two stages: one for high noise and one for low noise. You can train both stages at
once or train them separately. When trained together, the trainer alternates between the two models every
set number of steps and outputs two different LoRAs. If you train only one stage, the trainer trains just
that stage and outputs a single LoRA.
</>
),
},
'train.switch_boundary_every': {
title: 'Switch Boundary Every',
description: (
<>
When training a model with multiple stages, this setting controls how often the trainer switches between
training each stage.
<br />
<br />
With low VRAM enabled, the stage not being trained is unloaded from the GPU to save memory. This takes some
time, so it is best to alternate less often when using low VRAM; a value of 10 or 20 works well.
<br />
<br />
The swap happens at the batch level, meaning the trainer can switch between gradient accumulation steps. To
train both stages within a single optimizer step, set this to switch every 1 step and set gradient
accumulation to 2.
</>
),
},
};
export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
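For reference, a hedged sketch of how these options appear in a job config, shown as a Python dict; the key names are taken from TrainConfig, NetworkConfig, and the model defaults in this PR, while the values are illustrative.
process_overrides = {
    "model": {
        "arch": "wan22_14b",
        "model_kwargs": {
            "train_high_noise": True,   # train the 1000-875 timestep stage
            "train_low_noise": True,    # train the 875-0 timestep stage
        },
        "low_vram": True,               # offload the idle transformer to CPU
    },
    "train": {
        "switch_boundary_every": 10,    # switch stages less often when offloading
    },
    "network": {
        "split_multistage_loras": True, # save _high_noise / _low_noise LoRA files
    },
}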

View File

@@ -119,6 +119,7 @@ export interface TrainConfig {
diff_output_preservation: boolean;
diff_output_preservation_multiplier: number;
diff_output_preservation_class: string;
switch_boundary_every: number;
}
export interface QuantizeKwargsConfig {

View File

@@ -1 +1 @@
VERSION = "0.4.1"
VERSION = "0.5.0"