mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-09 15:09:57 +00:00
@@ -3,7 +3,7 @@ from .hidream import HidreamModel, HidreamE1Model
 from .f_light import FLiteModel
 from .omnigen2 import OmniGen2Model
 from .flux_kontext import FluxKontextModel
-from .wan22 import Wan225bModel
+from .wan22 import Wan225bModel, Wan2214bModel
 from .qwen_image import QwenImageModel
 
 AI_TOOLKIT_MODELS = [
@@ -15,5 +15,6 @@ AI_TOOLKIT_MODELS = [
     OmniGen2Model,
     FluxKontextModel,
     Wan225bModel,
+    Wan2214bModel,
     QwenImageModel,
 ]
@@ -1 +1,2 @@
 from .wan22_5b_model import Wan225bModel
+from .wan22_14b_model import Wan2214bModel
extensions_built_in/diffusion_models/wan22/wan22_14b_model.py (new file, +540 lines)
@@ -0,0 +1,540 @@
from functools import partial
import os
from typing import Any, Dict, Optional, Union, List
from typing_extensions import Self
import torch
import yaml
from toolkit.accelerator import unwrap_model
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from PIL import Image
from diffusers import UniPCMultistepScheduler
import torch
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.util.quantize import quantize_model
from .wan22_pipeline import Wan22Pipeline
from diffusers import WanTransformer3DModel

from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from torchvision.transforms import functional as TF

from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline
from .wan22_5b_model import (
    scheduler_config,
    time_text_monkeypatch,
    Wan225bModel,
)
from safetensors.torch import load_file, save_file


boundary_ratio_t2v = 0.875
boundary_ratio_i2v = 0.9

scheduler_configUniPC = {
    "_class_name": "UniPCMultistepScheduler",
    "_diffusers_version": "0.35.0.dev0",
    "beta_end": 0.02,
    "beta_schedule": "linear",
    "beta_start": 0.0001,
    "disable_corrector": [],
    "dynamic_thresholding_ratio": 0.995,
    "final_sigmas_type": "zero",
    "flow_shift": 3.0,
    "lower_order_final": True,
    "num_train_timesteps": 1000,
    "predict_x0": True,
    "prediction_type": "flow_prediction",
    "rescale_betas_zero_snr": False,
    "sample_max_value": 1.0,
    "solver_order": 2,
    "solver_p": None,
    "solver_type": "bh2",
    "steps_offset": 0,
    "thresholding": False,
    "time_shift_type": "exponential",
    "timestep_spacing": "linspace",
    "trained_betas": None,
    "use_beta_sigmas": False,
    "use_dynamic_shifting": False,
    "use_exponential_sigmas": False,
    "use_flow_sigmas": True,
    "use_karras_sigmas": False,
}
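For orientation, an illustrative sketch (not part of the diff) of turning the dict above into a ready sampler. from_config is assumed here because it tolerates the private "_class_name"/"_diffusers_version" bookkeeping keys; the pipeline code further down instead unpacks the dict directly into the constructor.

from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler.from_config(scheduler_configUniPC)
print(scheduler.config.flow_shift)       # 3.0
print(scheduler.config.use_flow_sigmas)  # True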
class DualWanTransformer3DModel(torch.nn.Module):
    def __init__(
        self,
        transformer_1: WanTransformer3DModel,
        transformer_2: WanTransformer3DModel,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        device: Optional[Union[str, torch.device]] = None,
        boundary_ratio: float = boundary_ratio_t2v,
        low_vram: bool = False,
    ) -> None:
        super().__init__()
        self.transformer_1: WanTransformer3DModel = transformer_1
        self.transformer_2: WanTransformer3DModel = transformer_2
        self.torch_dtype: torch.dtype = torch_dtype
        self.device_torch: torch.device = device
        self.boundary_ratio: float = boundary_ratio
        self.boundary: float = self.boundary_ratio * 1000
        self.low_vram: bool = low_vram
        self._active_transformer_name = "transformer_1"  # default to transformer_1

    @property
    def device(self) -> torch.device:
        return self.device_torch

    @property
    def dtype(self) -> torch.dtype:
        return self.torch_dtype

    @property
    def config(self):
        return self.transformer_1.config

    @property
    def transformer(self) -> WanTransformer3DModel:
        return getattr(self, self._active_transformer_name)

    def enable_gradient_checkpointing(self):
        """
        Enable gradient checkpointing for both transformers.
        """
        self.transformer_1.enable_gradient_checkpointing()
        self.transformer_2.enable_gradient_checkpointing()

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        # determine if doing high noise or low noise by taking the mean of the timestep.
        # timesteps are in the range of 0 to 1000, so we can use a threshold
        with torch.no_grad():
            if timestep.float().mean().item() >= self.boundary:
                t_name = "transformer_1"
            else:
                t_name = "transformer_2"

        # check if we are changing the active transformer; if so, we need to swap the one in
        # vram if low_vram is enabled
        # todo swap the loras as well
        if t_name != self._active_transformer_name:
            if self.low_vram:
                getattr(self, self._active_transformer_name).to("cpu")
                getattr(self, t_name).to(self.device_torch)
                torch.cuda.empty_cache()
            self._active_transformer_name = t_name

        if self.transformer.device != hidden_states.device:
            if self.low_vram:
                # move other transformer to cpu
                other_tname = (
                    "transformer_1" if t_name == "transformer_2" else "transformer_2"
                )
                getattr(self, other_tname).to("cpu")

            self.transformer.to(hidden_states.device)

        return self.transformer(
            hidden_states=hidden_states,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_states_image=encoder_hidden_states_image,
            return_dict=return_dict,
            attention_kwargs=attention_kwargs,
        )

    def to(self, *args, **kwargs) -> Self:
        # do not do .to() here; device placement is handled separately
        return self
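A minimal sketch (not part of the diff) of the routing rule in forward above, assuming the default t2v boundary ratio of 0.875, i.e. a boundary of 875 on the 0-1000 timestep scale:

import torch

boundary = 0.875 * 1000  # boundary_ratio_t2v * num_train_timesteps

def pick_transformer(timestep: torch.Tensor) -> str:
    # high-noise steps (mean t >= 875) go to transformer_1, the rest to transformer_2
    if timestep.float().mean().item() >= boundary:
        return "transformer_1"
    return "transformer_2"

print(pick_transformer(torch.tensor([950.0])))  # transformer_1 (high noise)
print(pick_transformer(torch.tensor([300.0])))  # transformer_2 (low noise)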
class Wan2214bModel(Wan225bModel):
    arch = "wan22_14b"
    _wan_generation_scheduler_config = scheduler_configUniPC
    _wan_expand_timesteps = True
    _wan_vae_path = "ai-toolkit/wan2.1-vae"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device=device,
            model_config=model_config,
            dtype=dtype,
            custom_pipeline=custom_pipeline,
            noise_scheduler=noise_scheduler,
            **kwargs,
        )
        # target the wrapper module so we can target both transformers
        self.target_lora_modules = ["DualWanTransformer3DModel"]
        self._wan_cache = None

        self.is_multistage = True
        # multistage boundaries split the models up when sampling timesteps.
        # for wan 2.2 14b, the timesteps are 1000-875 for transformer 1 and 875-0 for transformer 2
        self.multistage_boundaries: List[float] = [0.875, 0.0]

        self.train_high_noise = model_config.model_kwargs.get("train_high_noise", True)
        self.train_low_noise = model_config.model_kwargs.get("train_low_noise", True)

        self.trainable_multistage_boundaries: List[int] = []
        if self.train_high_noise:
            self.trainable_multistage_boundaries.append(0)
        if self.train_low_noise:
            self.trainable_multistage_boundaries.append(1)

        if len(self.trainable_multistage_boundaries) == 0:
            raise ValueError(
                "At least one of train_high_noise or train_low_noise must be True in model.model_kwargs"
            )
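Illustratively (a hypothetical config excerpt, not part of the diff), the model_kwargs read above would be set like this; the key names match the UI defaults added later in this change:

model_section = {
    "arch": "wan22_14b",
    "model_kwargs": {
        "train_high_noise": True,   # stage 0: transformer_1, timesteps 1000-875
        "train_low_noise": False,   # stage 1: transformer_2, timesteps 875-0
    },
}
# -> trainable_multistage_boundaries == [0]; only the high-noise LoRA is trained and saved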
    @property
    def max_step_saves_to_keep_multiplier(self):
        # the cleanup mechanism checks this to see how many saves to keep.
        # if we are training a LoRA, we need to set this to 2 so we keep both the
        # high noise and low noise LoRAs at saves to keep
        if (
            self.network is not None
            and self.network.network_config.split_multistage_loras
        ):
            return 2
        return 1

    def load_model(self):
        # load model from the parent's parent; Wan21 is not the immediate parent
        # super().load_model()
        super(Wan225bModel, self).load_model()

        # we have to split up the model on the pipeline
        self.pipeline.transformer = self.model.transformer_1
        self.pipeline.transformer_2 = self.model.transformer_2

        # patch the condition embedder
        self.model.transformer_1.condition_embedder.forward = partial(
            time_text_monkeypatch, self.model.transformer_1.condition_embedder
        )
        self.model.transformer_2.condition_embedder.forward = partial(
            time_text_monkeypatch, self.model.transformer_2.condition_embedder
        )

    def get_bucket_divisibility(self):
        # 16x compression and 2x2 patch size
        return 32
    def load_wan_transformer(self, transformer_path, subfolder=None):
        if self.model_config.split_model_over_gpus:
            raise ValueError(
                "Splitting model over gpus is not supported for Wan2.2 models"
            )

        if (
            self.model_config.assistant_lora_path is not None
            or self.model_config.inference_lora_path is not None
        ):
            raise ValueError(
                "Assistant LoRA is not supported for Wan2.2 models currently"
            )

        if self.model_config.lora_path is not None:
            raise ValueError(
                "Loading LoRA is not supported for Wan2.2 models currently"
            )

        # transformer path can be a directory that ends with /transformer or a hf path.

        transformer_path_1 = transformer_path
        subfolder_1 = subfolder

        transformer_path_2 = transformer_path
        subfolder_2 = subfolder

        if subfolder_2 is None:
            # we have a local path, replace it with transformer_2 folder
            transformer_path_2 = os.path.join(
                os.path.dirname(transformer_path_1), "transformer_2"
            )
        else:
            # we have a hf path, replace it with transformer_2 subfolder
            subfolder_2 = "transformer_2"

        self.print_and_status_update("Loading transformer 1")
        dtype = self.torch_dtype
        transformer_1 = WanTransformer3DModel.from_pretrained(
            transformer_path_1,
            subfolder=subfolder_1,
            torch_dtype=dtype,
        ).to(dtype=dtype)

        flush()

        if not self.model_config.low_vram:
            # quantize on the device
            transformer_1.to(self.quantize_device, dtype=dtype)
            flush()

        if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
            # todo handle two ARAs
            self.print_and_status_update("Quantizing Transformer 1")
            quantize_model(self, transformer_1)
            flush()

        if self.model_config.low_vram:
            self.print_and_status_update("Moving transformer 1 to CPU")
            transformer_1.to("cpu")

        self.print_and_status_update("Loading transformer 2")
        dtype = self.torch_dtype
        transformer_2 = WanTransformer3DModel.from_pretrained(
            transformer_path_2,
            subfolder=subfolder_2,
            torch_dtype=dtype,
        ).to(dtype=dtype)

        flush()

        if not self.model_config.low_vram:
            # quantize on the device
            transformer_2.to(self.quantize_device, dtype=dtype)
            flush()

        if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is None:
            # todo handle two ARAs
            self.print_and_status_update("Quantizing Transformer 2")
            quantize_model(self, transformer_2)
            flush()

        if self.model_config.low_vram:
            self.print_and_status_update("Moving transformer 2 to CPU")
            transformer_2.to("cpu")

        # make the combined model
        self.print_and_status_update("Creating DualWanTransformer3DModel")
        transformer = DualWanTransformer3DModel(
            transformer_1=transformer_1,
            transformer_2=transformer_2,
            torch_dtype=self.torch_dtype,
            device=self.device_torch,
            boundary_ratio=boundary_ratio_t2v,
            low_vram=self.model_config.low_vram,
        )

        if self.model_config.quantize and self.model_config.accuracy_recovery_adapter is not None:
            # apply the accuracy recovery adapter to both transformers
            self.print_and_status_update("Applying Accuracy Recovery Adapter to Transformers")
            quantize_model(self, transformer)
            flush()

        return transformer
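A small sketch (assumed example paths, not part of the diff) of the transformer_2 resolution above:

import os

# local checkout: a .../transformer directory gets a sibling transformer_2 directory
local_path = "/models/Wan2.2-T2V-A14B/transformer"
print(os.path.join(os.path.dirname(local_path), "transformer_2"))
# /models/Wan2.2-T2V-A14B/transformer_2

# hub path: the repo id stays the same, only the subfolder changes
repo_id, subfolder_2 = "ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16", "transformer_2"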
    def get_generation_pipeline(self):
        scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
        pipeline = Wan22Pipeline(
            vae=self.vae,
            transformer=self.model.transformer_1,
            transformer_2=self.model.transformer_2,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            scheduler=scheduler,
            expand_timesteps=self._wan_expand_timesteps,
            device=self.device_torch,
            aggressive_offload=self.model_config.low_vram,
            # todo detect if it is i2v or t2v
            boundary_ratio=boundary_ratio_t2v,
        )

        # pipeline = pipeline.to(self.device_torch)

        return pipeline

    # static method to get the scheduler
    @staticmethod
    def get_train_scheduler():
        scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
        return scheduler

    def get_base_model_version(self):
        return "wan_2.2_14b"
    def generate_single_image(
        self,
        pipeline: AggressiveWanUnloadPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        return super().generate_single_image(
            pipeline=pipeline,
            gen_config=gen_config,
            conditional_embeds=conditional_embeds,
            unconditional_embeds=unconditional_embeds,
            generator=generator,
            extra=extra,
        )

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        batch: DataLoaderBatchDTO,
        **kwargs,
    ):
        # todo do we need to override this? Adjust timesteps?
        return super().get_noise_prediction(
            latent_model_input=latent_model_input,
            timestep=timestep,
            text_embeddings=text_embeddings,
            batch=batch,
            **kwargs,
        )

    def get_model_has_grad(self):
        return False

    def get_te_has_grad(self):
        return False
    def save_model(self, output_path, meta, save_dtype):
        transformer_combo: DualWanTransformer3DModel = unwrap_model(self.model)
        transformer_combo.transformer_1.save_pretrained(
            save_directory=os.path.join(output_path, "transformer"),
            safe_serialization=True,
        )
        transformer_combo.transformer_2.save_pretrained(
            save_directory=os.path.join(output_path, "transformer_2"),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, "aitk_meta.yaml")
        with open(meta_path, "w") as f:
            yaml.dump(meta, f)

    def save_lora(
        self,
        state_dict: Dict[str, torch.Tensor],
        output_path: str,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        if not self.network.network_config.split_multistage_loras:
            # just save as a combo lora
            save_file(state_dict, output_path, metadata=metadata)
            return

        # we need to build out both dictionaries for high and low noise LoRAs
        high_noise_lora = {}
        low_noise_lora = {}

        only_train_high_noise = self.train_high_noise and not self.train_low_noise
        only_train_low_noise = self.train_low_noise and not self.train_high_noise

        for key in state_dict:
            if ".transformer_1." in key or only_train_high_noise:
                # this is a high noise LoRA
                new_key = key.replace(".transformer_1.", ".")
                high_noise_lora[new_key] = state_dict[key]
            elif ".transformer_2." in key or only_train_low_noise:
                # this is a low noise LoRA
                new_key = key.replace(".transformer_2.", ".")
                low_noise_lora[new_key] = state_dict[key]

        # loras are named either LORA_MODEL_NAME_000005000.safetensors or LORA_MODEL_NAME.safetensors
        if len(high_noise_lora.keys()) > 0:
            # save the high noise LoRA
            high_noise_lora_path = output_path.replace(
                ".safetensors", "_high_noise.safetensors"
            )
            save_file(high_noise_lora, high_noise_lora_path, metadata=metadata)

        if len(low_noise_lora.keys()) > 0:
            # save the low noise LoRA
            low_noise_lora_path = output_path.replace(
                ".safetensors", "_low_noise.safetensors"
            )
            save_file(low_noise_lora, low_noise_lora_path, metadata=metadata)
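The resulting file naming, sketched with a hypothetical LoRA name (not part of the diff):

output_path = "my_lora_000005000.safetensors"
print(output_path.replace(".safetensors", "_high_noise.safetensors"))
# my_lora_000005000_high_noise.safetensors
print(output_path.replace(".safetensors", "_low_noise.safetensors"))
# my_lora_000005000_low_noise.safetensors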
    def load_lora(self, file: str):
        # if it doesn't have high_noise or low_noise in the name, it is a combo LoRA
        if (
            "_high_noise.safetensors" not in file
            and "_low_noise.safetensors" not in file
        ):
            # this is a combined LoRA, we don't need to split it up
            sd = load_file(file)
            return sd

        # we may have been passed the high_noise or the low_noise LoRA path, but we need to load both
        high_noise_lora_path = file.replace(
            "_low_noise.safetensors", "_high_noise.safetensors"
        )
        low_noise_lora_path = file.replace(
            "_high_noise.safetensors", "_low_noise.safetensors"
        )

        combined_dict = {}

        if os.path.exists(high_noise_lora_path) and self.train_high_noise:
            # load the high noise LoRA
            high_noise_lora = load_file(high_noise_lora_path)
            for key in high_noise_lora:
                new_key = key.replace(
                    "diffusion_model.", "diffusion_model.transformer_1."
                )
                combined_dict[new_key] = high_noise_lora[key]
        if os.path.exists(low_noise_lora_path) and self.train_low_noise:
            # load the low noise LoRA
            low_noise_lora = load_file(low_noise_lora_path)
            for key in low_noise_lora:
                new_key = key.replace(
                    "diffusion_model.", "diffusion_model.transformer_2."
                )
                combined_dict[new_key] = low_noise_lora[key]

        # if we are not training both stages, we won't have transformer designations in the keys
        if not self.train_high_noise and not self.train_low_noise:
            new_dict = {}
            for key in combined_dict:
                if ".transformer_1." in key:
                    new_key = key.replace(".transformer_1.", ".")
                elif ".transformer_2." in key:
                    new_key = key.replace(".transformer_2.", ".")
                else:
                    new_key = key
                new_dict[new_key] = combined_dict[key]
            combined_dict = new_dict

        return combined_dict
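A sketch of the key remapping above with a hypothetical LoRA key (not part of the diff):

key = "diffusion_model.blocks.0.attn1.to_q.lora_A.weight"  # hypothetical key
print(key.replace("diffusion_model.", "diffusion_model.transformer_1."))
# diffusion_model.transformer_1.blocks.0.attn1.to_q.lora_A.weight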
    def get_model_to_train(self):
        # todo: loras won't load right unless they have transformer_1 or transformer_2 in the key.
        # called when setting up the LoRA. We only need to get the model for the stages we want to train.
        if self.train_high_noise and self.train_low_noise:
            # we are training both stages, return the unified model
            return self.model
        elif self.train_high_noise:
            # we are only training the high noise stage, return transformer_1
            return self.model.transformer_1
        elif self.train_low_noise:
            # we are only training the low noise stage, return transformer_2
            return self.model.transformer_2
        else:
            raise ValueError(
                "At least one of train_high_noise or train_low_noise must be True in model.model_kwargs"
            )
@@ -12,7 +12,6 @@ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from typing import Any, Callable, Dict, List, Optional, Union
 
 
-
 class Wan22Pipeline(WanPipeline):
     def __init__(
         self,
@@ -52,6 +51,7 @@ class Wan22Pipeline(WanPipeline):
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator,
                                   List[torch.Generator]]] = None,
@@ -77,11 +77,15 @@ class Wan22Pipeline(WanPipeline):
         vae_device = self.vae.device
         transformer_device = self.transformer.device
         text_encoder_device = self.text_encoder.device
-        device = self.transformer.device
+        device = self._exec_device
 
         if self._aggressive_offload:
             print("Unloading vae")
             self.vae.to("cpu")
             print("Unloading transformer")
             self.transformer.to("cpu")
+            if self.transformer_2 is not None:
+                self.transformer_2.to("cpu")
+            self.text_encoder.to(device)
+            flush()
 
@@ -95,9 +99,14 @@ class Wan22Pipeline(WanPipeline):
             prompt_embeds,
             negative_prompt_embeds,
             callback_on_step_end_tensor_inputs,
+            guidance_scale_2
         )
 
+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
@@ -160,6 +169,13 @@ class Wan22Pipeline(WanPipeline):
         num_warmup_steps = len(timesteps) - \
             num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
+        current_model = self.transformer
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -167,6 +183,25 @@ class Wan22Pipeline(WanPipeline):
                     continue
 
                 self._current_timestep = t
+
+                if boundary_timestep is None or t >= boundary_timestep:
+                    if self._aggressive_offload and current_model != self.transformer:
+                        if self.transformer_2 is not None:
+                            self.transformer_2.to("cpu")
+                        self.transformer.to(device)
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    if self._aggressive_offload and current_model != self.transformer_2:
+                        if self.transformer is not None:
+                            self.transformer.to("cpu")
+                        if self.transformer_2 is not None:
+                            self.transformer_2.to(device)
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
                 latent_model_input = latents.to(device, transformer_dtype)
                 if self.config.expand_timesteps:
                     # seq_len: num_latent_frames * latent_height//2 * latent_width//2
@@ -176,7 +211,7 @@ class Wan22Pipeline(WanPipeline):
                 else:
                     timestep = t.expand(latents.shape[0])
 
-                noise_pred = self.transformer(
+                noise_pred = current_model(
                     hidden_states=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
@@ -185,14 +220,14 @@ class Wan22Pipeline(WanPipeline):
                 )[0]
 
                 if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                    noise_uncond = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
                         encoder_hidden_states=negative_prompt_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * \
+                    noise_pred = noise_uncond + current_guidance_scale * \
                         (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
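A quick numeric check (not part of the diff) of the per-stage CFG combine above, with assumed values:

import torch

noise_uncond = torch.tensor([0.1])
noise_cond = torch.tensor([0.5])
for g in (5.0, 3.0):  # e.g. guidance_scale vs. guidance_scale_2
    print(noise_uncond + g * (noise_cond - noise_uncond))
# tensor([2.1000]) for the high-noise stage, tensor([1.3000]) for the low-noise stage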
@@ -259,10 +294,8 @@ class Wan22Pipeline(WanPipeline):
 
         # move transformer back to device
         if self._aggressive_offload:
-            print("Moving transformer back to device")
-            self.transformer.to(self._execution_device)
-            if self.transformer_2 is not None:
-                self.transformer_2.to(self._execution_device)
+            # print("Moving transformer back to device")
+            # self.transformer.to(self._execution_device)
             flush()
 
         if not return_dict:
@@ -1862,7 +1862,20 @@ class SDTrainer(BaseSDTrainProcess):
         total_loss = None
         self.optimizer.zero_grad()
         for batch in batch_list:
+            if self.sd.is_multistage:
+                # handle multistage switching
+                if self.steps_this_boundary >= self.train_config.switch_boundary_every:
+                    # iterate to make sure we only train trainable_multistage_boundaries
+                    while True:
+                        self.steps_this_boundary = 0
+                        self.current_boundary_index += 1
+                        if self.current_boundary_index >= len(self.sd.multistage_boundaries):
+                            self.current_boundary_index = 0
+                        if self.current_boundary_index in self.sd.trainable_multistage_boundaries:
+                            # if this boundary is trainable, we can stop looking
+                            break
             loss = self.train_single_accumulation(batch)
+            self.steps_this_boundary += 1
             if total_loss is None:
                 total_loss = loss
             else:
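A toy trace (not part of the diff) of the switching logic above, assuming both stages are trainable and switch_boundary_every == 1:

trainable = [0, 1]                      # stage 0 = high noise, stage 1 = low noise
num_boundaries, switch_every = 2, 1
current, steps_this_boundary = 0, 0
schedule = []
for _ in range(6):                      # six batches
    if steps_this_boundary >= switch_every:
        while True:
            steps_this_boundary = 0
            current = (current + 1) % num_boundaries
            if current in trainable:
                break
    schedule.append(current)
    steps_this_boundary += 1
print(schedule)  # [0, 1, 0, 1, 0, 1] -- the trainer alternates stages every batch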
@@ -1907,7 +1920,7 @@ class SDTrainer(BaseSDTrainProcess):
             self.adapter.restore_embeddings()
 
         loss_dict = OrderedDict(
-            {'loss': loss.item()}
+            {'loss': (total_loss / len(batch_list)).item()}
         )
 
         self.end_of_training_loop()
@@ -260,6 +260,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 torch.profiler.ProfilerActivity.CUDA,
             ],
         )
 
+        self.current_boundary_index = 0
+        self.steps_this_boundary = 0
+
     def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
         # override in subclass
@@ -437,19 +440,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # Combine and sort the lists
         combined_items = safetensors_files + directories + pt_files
         combined_items.sort(key=os.path.getctime)
 
+        num_saves_to_keep = self.save_config.max_step_saves_to_keep
+
+        if hasattr(self.sd, 'max_step_saves_to_keep_multiplier'):
+            num_saves_to_keep *= self.sd.max_step_saves_to_keep_multiplier
+
         # Use slicing with a check to avoid 'NoneType' error
         safetensors_to_remove = safetensors_files[
-            :-self.save_config.max_step_saves_to_keep] if safetensors_files else []
-        pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else []
-        directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else []
-        embeddings_to_remove = embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else []
-        critic_to_remove = critic_items[:-self.save_config.max_step_saves_to_keep] if critic_items else []
+            :-num_saves_to_keep] if safetensors_files else []
+        pt_files_to_remove = pt_files[:-num_saves_to_keep] if pt_files else []
+        directories_to_remove = directories[:-num_saves_to_keep] if directories else []
+        embeddings_to_remove = embed_files[:-num_saves_to_keep] if embed_files else []
+        critic_to_remove = critic_items[:-num_saves_to_keep] if critic_items else []
 
         items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove + critic_to_remove
 
         # remove all but the latest max_step_saves_to_keep
-        # items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
+        # items_to_remove = combined_items[:-num_saves_to_keep]
 
         # remove duplicates
         items_to_remove = list(dict.fromkeys(items_to_remove))
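Example of the effect (assumed values, not part of the diff): with max_step_saves_to_keep = 4 and split multistage LoRAs, the cleanup keeps the 4 newest _high_noise plus the 4 newest _low_noise checkpoints:

max_step_saves_to_keep = 4
multiplier = 2  # from Wan2214bModel.max_step_saves_to_keep_multiplier
num_saves_to_keep = max_step_saves_to_keep * multiplier
print(num_saves_to_keep)  # 8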
@@ -1166,6 +1174,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.sd.noise_scheduler.set_timesteps(
                 num_train_timesteps, device=self.device_torch
             )
+            if self.sd.is_multistage:
+                with self.timer('adjust_multistage_timesteps'):
+                    # get our current sample range
+                    boundaries = [1000] + self.sd.multistage_boundaries
+                    boundary_max, boundary_min = boundaries[self.current_boundary_index], boundaries[self.current_boundary_index + 1]
+                    lo = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_max, device=self.sd.noise_scheduler.timesteps.device), right=False)
+                    hi = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_min, device=self.sd.noise_scheduler.timesteps.device), right=True)
+                    first_idx = lo.item() if hi > lo else 0
+                    last_idx = (hi - 1).item() if hi > lo else 999
+
+                    min_noise_steps = first_idx
+                    max_noise_steps = last_idx
+
+                    # clip min max indices
+                    min_noise_steps = max(min_noise_steps, 0)
+                    max_noise_steps = min(max_noise_steps, num_train_timesteps - 1)
+
 
         with self.timer('prepare_timesteps_indices'):
 
             content_or_style = self.train_config.content_or_style
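A standalone sketch (not part of the diff) of the boundary-to-index lookup: a flowmatch schedule runs from high to low timesteps, and one way to make torch.searchsorted applicable to such a descending sequence — the trick assumed here — is to negate both the sequence and the query:

import torch

timesteps = torch.linspace(1000, 1, 1000)  # stand-in descending schedule
boundary_max, boundary_min = 1000, 875     # high-noise stage of wan 2.2 14b
lo = torch.searchsorted(-timesteps, -torch.tensor(float(boundary_max)), right=False)
hi = torch.searchsorted(-timesteps, -torch.tensor(float(boundary_min)), right=True)
print(lo.item(), (hi - 1).item())  # 0 125 -> indices whose timesteps lie in [875, 1000]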
@@ -1204,11 +1230,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     0,
                     self.train_config.num_train_timesteps - 1,
                     min_noise_steps,
-                    max_noise_steps - 1
+                    max_noise_steps
                 )
                 timestep_indices = timestep_indices.long().clamp(
-                    min_noise_steps + 1,
-                    max_noise_steps - 1
+                    min_noise_steps,
+                    max_noise_steps
                 )
 
             elif content_or_style == 'balanced':
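A small check (assumed stage indices, not part of the diff) of why the clamp bounds changed: the old bounds silently excluded both endpoints of the stage's index range:

import torch

idx = torch.tensor([100, 125, 500, 999])
min_noise_steps, max_noise_steps = 125, 999   # low-noise stage, assumed indices
print(idx.clamp(min_noise_steps + 1, max_noise_steps - 1))  # tensor([126, 126, 500, 998])
print(idx.clamp(min_noise_steps, max_noise_steps))          # tensor([125, 125, 500, 999])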
@@ -1221,7 +1247,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 if self.train_config.noise_scheduler == 'flowmatch':
                     # flowmatch uses indices, so we need to use indices
                     min_idx = 0
-                    max_idx = max_noise_steps - 1
+                    max_idx = max_noise_steps
                     timestep_indices = torch.randint(
                         min_idx,
                         max_idx,
@@ -1671,7 +1697,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
         self.network = NetworkClass(
             text_encoder=text_encoder,
-            unet=unet,
+            unet=self.sd.get_model_to_train(),
             lora_dim=self.network_config.linear,
             multiplier=1.0,
             alpha=self.network_config.linear_alpha,
@@ -185,6 +185,9 @@ class NetworkConfig:
         self.conv_alpha = 9999999999
         # -1 automatically finds the largest factor
         self.lokr_factor = kwargs.get('lokr_factor', -1)
 
+        # for multi stage models
+        self.split_multistage_loras = kwargs.get('split_multistage_loras', True)
+
 
 AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v']
@@ -332,7 +335,7 @@ class TrainConfig:
         self.lr_scheduler = kwargs.get('lr_scheduler', 'constant')
         self.lr_scheduler_params = kwargs.get('lr_scheduler_params', {})
         self.min_denoising_steps: int = kwargs.get('min_denoising_steps', 0)
-        self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 1000)
+        self.max_denoising_steps: int = kwargs.get('max_denoising_steps', 999)
         self.batch_size: int = kwargs.get('batch_size', 1)
         self.orig_batch_size: int = self.batch_size
         self.dtype: str = kwargs.get('dtype', 'fp32')
@@ -512,6 +515,9 @@ class TrainConfig:
         self.unconditional_prompt: str = kwargs.get('unconditional_prompt', '')
         if isinstance(self.guidance_loss_target, tuple):
             self.guidance_loss_target = list(self.guidance_loss_target)
 
+        # for multi stage models, how often to switch the boundary
+        self.switch_boundary_every: int = kwargs.get('switch_boundary_every', 1)
+
 
 ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex1', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21']
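Illustratively (not part of the diff, and assuming TrainConfig can be constructed from bare kwargs), the new knob is read like any other kwarg:

cfg = TrainConfig(switch_boundary_every=10)  # e.g. for low_vram, to swap less often
print(cfg.switch_boundary_every)  # 10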
@@ -172,6 +172,11 @@ class BaseModel:
         self.sample_prompts_cache = None
 
         self.accuracy_recovery_adapter: Union[None, 'LoRASpecialNetwork'] = None
+        self.is_multistage = False
+        # a list of multistage boundaries starting with train step 1000 to first idx
+        self.multistage_boundaries: List[float] = [0.0]
+        # a list of trainable multistage boundaries
+        self.trainable_multistage_boundaries: List[int] = [0]
 
     # properties for old arch for backwards compatibility
     @property
@@ -1502,3 +1507,7 @@ class BaseModel:
     def get_base_model_version(self) -> str:
         # override in child classes to get the base model version
         return "unknown"
+
+    def get_model_to_train(self):
+        # called to get the model to attach LoRAs to. Can be overridden in child classes
+        return self.unet
@@ -310,6 +310,7 @@ class Wan21(BaseModel):
     arch = 'wan21'
     _wan_generation_scheduler_config = scheduler_configUniPC
     _wan_expand_timesteps = False
+    _wan_vae_path = None
 
     _comfy_te_file = ['text_encoders/umt5_xxl_fp16.safetensors', 'text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors']
     def __init__(
@@ -431,8 +432,14 @@ class Wan21(BaseModel):
         scheduler = Wan21.get_train_scheduler()
         self.print_and_status_update("Loading VAE")
         # todo, example does float 32? check if quality suffers
-        vae = AutoencoderKLWan.from_pretrained(
-            vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype)
+        if self._wan_vae_path is not None:
+            # load the vae from individual repo
+            vae = AutoencoderKLWan.from_pretrained(
+                self._wan_vae_path, torch_dtype=dtype).to(dtype=dtype)
+        else:
+            vae = AutoencoderKLWan.from_pretrained(
+                vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype)
         flush()
 
         self.print_and_status_update("Making pipe")
@@ -565,6 +565,13 @@ class ToolkitNetworkMixin:
         if metadata is None:
            metadata = OrderedDict()
         metadata = add_model_hash_to_meta(save_dict, metadata)
+
+        # let the model handle the saving
+        if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'save_lora'):
+            # call the base model save lora method
+            self.base_model_ref().save_lora(save_dict, file, metadata)
+            return
+
         if os.path.splitext(file)[1] == ".safetensors":
             from safetensors.torch import save_file
             save_file(save_dict, file, metadata)
@@ -577,12 +584,15 @@ class ToolkitNetworkMixin:
         keymap = {} if keymap is None else keymap
 
         if isinstance(file, str):
-            if os.path.splitext(file)[1] == ".safetensors":
-                from safetensors.torch import load_file
-
-                weights_sd = load_file(file)
+            if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'load_lora'):
+                # call the base model load lora method
+                weights_sd = self.base_model_ref().load_lora(file)
             else:
-                weights_sd = torch.load(file, map_location="cpu")
+                if os.path.splitext(file)[1] == ".safetensors":
+                    from safetensors.torch import load_file
+                    weights_sd = load_file(file)
+                else:
+                    weights_sd = torch.load(file, map_location="cpu")
         else:
             # probably a state dict
             weights_sd = file
@@ -211,6 +211,12 @@ class StableDiffusion:
 
         self.sample_prompts_cache = None
 
+        self.is_multistage = False
+        # a list of multistage boundaries starting with train step 1000 to first idx
+        self.multistage_boundaries: List[float] = [0.0]
+        # a list of trainable multistage boundaries
+        self.trainable_multistage_boundaries: List[int] = [0]
+
     # properties for old arch for backwards compatibility
     @property
     def is_xl(self):
@@ -3123,3 +3129,6 @@ class StableDiffusion:
         if self.is_v2:
             return 'sd_2.1'
         return 'sd_1.5'
+
+    def get_model_to_train(self):
+        return self.unet
@@ -40,10 +40,25 @@ export default function SimpleJob({
 
   const isVideoModel = !!(modelArch?.group === 'video');
 
-  let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
-
-  if (modelArch?.disableSections?.includes('model.quantize')) {
-    topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
-  }
+  const numTopCards = useMemo(() => {
+    let count = 4; // job settings, model config, target config, save config
+    if (modelArch?.additionalSections?.includes('model.multistage')) {
+      count += 1; // add multistage card
+    }
+    if (!modelArch?.disableSections?.includes('model.quantize')) {
+      count += 1; // add quantization card
+    }
+    return count;
+  }, [modelArch]);
+
+  let topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 xl:grid-cols-4 gap-6';
+
+  if (numTopCards == 5) {
+    topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-5 gap-6';
+  }
+  if (numTopCards == 6) {
+    topBarClass = 'grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-6 gap-6';
+  }
 
   const transformerQuantizationOptions: GroupedSelectOption[] | SelectOption[] = useMemo(() => {
@@ -91,7 +106,7 @@ export default function SimpleJob({
     <>
       <form onSubmit={handleSubmit} className="space-y-8">
         <div className={topBarClass}>
-          <Card title="Job Settings">
+          <Card title="Job">
             <TextInput
               label="Training Name"
               value={jobConfig.config.name}
@@ -124,7 +139,7 @@ export default function SimpleJob({
           </Card>
 
           {/* Model Configuration Section */}
-          <Card title="Model Configuration">
+          <Card title="Model">
             <SelectInput
               label="Model Architecture"
               value={jobConfig.config.process[0].model.arch}
@@ -239,7 +254,32 @@ export default function SimpleJob({
             />
           </Card>
         )}
-        <Card title="Target Configuration">
+        {modelArch?.additionalSections?.includes('model.multistage') && (
+          <Card title="Multistage">
+            <FormGroup label="Stages to Train" docKey={'model.multistage'}>
+              <Checkbox
+                label="High Noise"
+                checked={jobConfig.config.process[0].model.model_kwargs?.train_high_noise || false}
+                onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.train_high_noise')}
+              />
+              <Checkbox
+                label="Low Noise"
+                checked={jobConfig.config.process[0].model.model_kwargs?.train_low_noise || false}
+                onChange={value => setJobConfig(value, 'config.process[0].model.model_kwargs.train_low_noise')}
+              />
+            </FormGroup>
+            <NumberInput
+              label="Switch Every"
+              value={jobConfig.config.process[0].train.switch_boundary_every}
+              onChange={value => setJobConfig(value, 'config.process[0].train.switch_boundary_every')}
+              placeholder="eg. 1"
+              docKey={'train.switch_boundary_every'}
+              min={1}
+              required
+            />
+          </Card>
+        )}
+        <Card title="Target">
           <SelectInput
             label="Target Type"
             value={jobConfig.config.process[0].network?.type ?? 'lora'}
@@ -295,7 +335,7 @@ export default function SimpleJob({
             </>
           )}
         </Card>
-        <Card title="Save Configuration">
+        <Card title="Save">
           <SelectInput
             label="Data Type"
             value={jobConfig.config.process[0].save.dtype}
@@ -325,7 +365,7 @@ export default function SimpleJob({
         </Card>
       </div>
       <div>
-        <Card title="Training Configuration">
+        <Card title="Training">
           <div className="grid grid-cols-1 md:grid-cols-3 lg:grid-cols-5 gap-6">
             <div>
               <NumberInput
@@ -645,7 +685,7 @@ export default function SimpleJob({
         </Card>
       </div>
       <div>
-        <Card title="Sample Configuration">
+        <Card title="Sample">
          <div
            className={
              isVideoModel
@@ -78,6 +78,7 @@ export const defaultJobConfig: JobConfig = {
       diff_output_preservation: false,
       diff_output_preservation_multiplier: 1.0,
       diff_output_preservation_class: 'person',
+      switch_boundary_every: 1,
     },
     model: {
       name_or_path: 'ostris/Flex.1-alpha',
@@ -3,7 +3,13 @@ import { GroupedSelectOption, SelectOption } from '@/types';
 type Control = 'depth' | 'line' | 'pose' | 'inpaint';
 
 type DisableableSections = 'model.quantize' | 'train.timestep_type' | 'network.conv';
-type AdditionalSections = 'datasets.control_path' | 'datasets.do_i2v' | 'sample.ctrl_img' | 'datasets.num_frames' | 'model.low_vram';
+type AdditionalSections =
+  | 'datasets.control_path'
+  | 'datasets.do_i2v'
+  | 'sample.ctrl_img'
+  | 'datasets.num_frames'
+  | 'model.multistage'
+  | 'model.low_vram';
 type ModelGroup = 'image' | 'video';
 
 export interface ModelArch {
@@ -121,7 +127,7 @@ export const modelArchs: ModelArch[] = [
       'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
       'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
       'config.process[0].sample.num_frames': [41, 1],
-      'config.process[0].sample.fps': [15, 1],
+      'config.process[0].sample.fps': [16, 1],
     },
     disableSections: ['network.conv'],
     additionalSections: ['datasets.num_frames', 'model.low_vram'],
@@ -139,7 +145,7 @@ export const modelArchs: ModelArch[] = [
       'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
       'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
       'config.process[0].sample.num_frames': [41, 1],
-      'config.process[0].sample.fps': [15, 1],
+      'config.process[0].sample.fps': [16, 1],
       'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
     },
     disableSections: ['network.conv'],
@@ -158,7 +164,7 @@ export const modelArchs: ModelArch[] = [
       'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
       'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
       'config.process[0].sample.num_frames': [41, 1],
-      'config.process[0].sample.fps': [15, 1],
+      'config.process[0].sample.fps': [16, 1],
       'config.process[0].train.timestep_type': ['weighted', 'sigmoid'],
     },
     disableSections: ['network.conv'],
@@ -177,11 +183,41 @@ export const modelArchs: ModelArch[] = [
       'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
       'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
       'config.process[0].sample.num_frames': [41, 1],
-      'config.process[0].sample.fps': [15, 1],
+      'config.process[0].sample.fps': [16, 1],
     },
     disableSections: ['network.conv'],
     additionalSections: ['datasets.num_frames', 'model.low_vram'],
   },
+  {
+    name: 'wan22_14b:t2v',
+    label: 'Wan 2.2 (14B)',
+    group: 'video',
+    isVideoModel: true,
+    defaults: {
+      // default updates when [selected, unselected] in the UI
+      'config.process[0].model.name_or_path': ['ai-toolkit/Wan2.2-T2V-A14B-Diffusers-bf16', defaultNameOrPath],
+      'config.process[0].model.quantize': [true, false],
+      'config.process[0].model.quantize_te': [true, false],
+      'config.process[0].sample.sampler': ['flowmatch', 'flowmatch'],
+      'config.process[0].train.noise_scheduler': ['flowmatch', 'flowmatch'],
+      'config.process[0].sample.num_frames': [41, 1],
+      'config.process[0].sample.fps': [16, 1],
+      'config.process[0].model.low_vram': [true, false],
+      'config.process[0].train.timestep_type': ['linear', 'sigmoid'],
+      'config.process[0].model.model_kwargs': [
+        {
+          train_high_noise: true,
+          train_low_noise: true,
+        },
+        {},
+      ],
+    },
+    disableSections: ['network.conv'],
+    additionalSections: ['datasets.num_frames', 'model.low_vram', 'model.multistage'],
+    // accuracyRecoveryAdapters: {
+    //   '3 bit with ARA': 'uint3|ostris/accuracy_recovery_adapters/wan22_14b_t2i_torchao_uint3.safetensors',
+    // },
+  },
   {
     name: 'wan22_5b',
     label: 'Wan 2.2 TI2V (5B)',
@@ -27,7 +27,7 @@ export default function SampleImages({ job }: SampleImagesProps) {
   // This way Tailwind can properly generate the class
   // I hate this, but it's the only way to make it work
   const gridColsClass = useMemo(() => {
-    const cols = Math.min(numSamples, 20);
+    const cols = Math.min(numSamples, 40);
 
     switch (cols) {
       case 1:
@@ -70,6 +70,46 @@ export default function SampleImages({ job }: SampleImagesProps) {
         return 'grid-cols-19';
       case 20:
         return 'grid-cols-20';
+      case 21:
+        return 'grid-cols-21';
+      case 22:
+        return 'grid-cols-22';
+      case 23:
+        return 'grid-cols-23';
+      case 24:
+        return 'grid-cols-24';
+      case 25:
+        return 'grid-cols-25';
+      case 26:
+        return 'grid-cols-26';
+      case 27:
+        return 'grid-cols-27';
+      case 28:
+        return 'grid-cols-28';
+      case 29:
+        return 'grid-cols-29';
+      case 30:
+        return 'grid-cols-30';
+      case 31:
+        return 'grid-cols-31';
+      case 32:
+        return 'grid-cols-32';
+      case 33:
+        return 'grid-cols-33';
+      case 34:
+        return 'grid-cols-34';
+      case 35:
+        return 'grid-cols-35';
+      case 36:
+        return 'grid-cols-36';
+      case 37:
+        return 'grid-cols-37';
+      case 38:
+        return 'grid-cols-38';
+      case 39:
+        return 'grid-cols-39';
+      case 40:
+        return 'grid-cols-40';
       default:
         return 'grid-cols-1';
     }
@@ -283,7 +283,7 @@ export const FormGroup: React.FC<FormGroupProps> = props => {
   return (
     <div className={classNames(className)}>
       {label && (
-        <label className={labelClasses}>
+        <label className={classNames(labelClasses, 'mb-2')}>
           {label}{' '}
           {doc && (
             <div className="inline-block ml-1 text-xs text-gray-500 cursor-pointer" onClick={() => openDoc(doc)}>
@@ -292,7 +292,7 @@ export const FormGroup: React.FC<FormGroupProps> = props => {
           )}
         </label>
       )}
-      <div className="px-4 space-y-2">{children}</div>
+      <div className="space-y-2">{children}</div>
     </div>
   );
 };
@@ -111,6 +111,36 @@ const docs: { [key: string]: ConfigDoc } = {
       </>
     ),
   },
+  'model.multistage': {
+    title: 'Stages to Train',
+    description: (
+      <>
+        Some models have multi-stage networks that are trained and used separately in the denoising process. Most
+        commonly there are 2 stages: one for high noise and one for low noise. You can choose to train both stages at
+        once or train them separately. If trained at the same time, the trainer will alternate between training each
+        model every so many steps and will output 2 different LoRAs. If you choose to train only one stage, the
+        trainer will only train that stage and output a single LoRA.
+      </>
+    ),
+  },
+  'train.switch_boundary_every': {
+    title: 'Switch Boundary Every',
+    description: (
+      <>
+        When training a model with multiple stages, this setting controls how often the trainer will switch between
+        training each stage.
+        <br />
+        <br />
+        For low vram settings, the model not being trained will be unloaded from the gpu to save memory. This takes
+        some time to do, so it is best to alternate less often when using low vram; a setting like 10 or 20 is
+        recommended there.
+        <br />
+        <br />
+        The swap happens at the batch level, meaning it can swap between gradient accumulation steps. To train both
+        stages in a single step, set them to switch every 1 step and set gradient accumulation to 2.
+      </>
+    ),
+  },
 };
 
 export const getDoc = (key: string | null | undefined): ConfigDoc | null => {
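To make the last paragraph of that doc concrete, a hypothetical train-section excerpt (the accumulation key name is an assumption for this sketch, not part of the diff):

train_section = {
    "switch_boundary_every": 1,   # swap stage after every batch
    "gradient_accumulation": 2,   # hypothetical key: two batches per optimizer step
}
# one optimizer step then sees one high-noise batch and one low-noise batch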
@@ -119,6 +119,7 @@ export interface TrainConfig {
   diff_output_preservation: boolean;
   diff_output_preservation_multiplier: number;
   diff_output_preservation_class: string;
+  switch_boundary_every: number;
 }
 
 export interface QuantizeKwargsConfig {
@@ -1 +1 @@
-VERSION = "0.4.1"
+VERSION = "0.5.0"