Wan22 14b training is working; it still needs tons of testing and some bug fixes
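For context, the new architecture registers under the arch key added in this diff. A hedged sketch of selecting it in a toolkit-style config dict follows; the field names mirror the model_config attributes used below, but the exact config shape is an assumption:

# hypothetical config sketch, not part of this commit
model_cfg = {
    "arch": "wan22_14b",  # registered by Wan2214bModel in this diff
    "low_vram": True,     # parks the inactive expert on CPU (see DualWanTransformer3DModel.forward)
    "quantize": True,     # quantizes each transformer as it is loaded
}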
@@ -3,7 +3,7 @@ from .hidream import HidreamModel, HidreamE1Model
 from .f_light import FLiteModel
 from .omnigen2 import OmniGen2Model
 from .flux_kontext import FluxKontextModel
-from .wan22 import Wan225bModel
+from .wan22 import Wan225bModel, Wan2214bModel
 from .qwen_image import QwenImageModel

 AI_TOOLKIT_MODELS = [
@@ -15,5 +15,6 @@ AI_TOOLKIT_MODELS = [
     OmniGen2Model,
     FluxKontextModel,
     Wan225bModel,
+    Wan2214bModel,
     QwenImageModel,
 ]
@@ -1 +1,2 @@
 from .wan22_5b_model import Wan225bModel
+from .wan22_14b_model import Wan2214bModel
extensions_built_in/diffusion_models/wan22/wan22_14b_model.py (new file, +470 lines)
@@ -0,0 +1,470 @@
from functools import partial
import os
from typing import Any, Dict, Optional, Union
from typing_extensions import Self

import torch
import yaml
from PIL import Image
from diffusers import UniPCMultistepScheduler, WanTransformer3DModel
from torchvision.transforms import functional as TF
from safetensors.torch import load_file, save_file

from toolkit.accelerator import unwrap_model
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.util.quantize import quantize_model
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.models.wan21.wan21 import AggressiveWanUnloadPipeline

from .wan22_pipeline import Wan22Pipeline
from .wan22_5b_model import (
    scheduler_config,
    time_text_monkeypatch,
    Wan225bModel,
)

# timestep boundary between the high-noise expert (transformer_1) and the
# low-noise expert (transformer_2), as a fraction of the 0-1000 range
boundary_ratio_t2v = 0.875
boundary_ratio_i2v = 0.9

# UniPC scheduler configuration used for generation/sampling
scheduler_configUniPC = {
    "_class_name": "UniPCMultistepScheduler",
    "_diffusers_version": "0.35.0.dev0",
    "beta_end": 0.02,
    "beta_schedule": "linear",
    "beta_start": 0.0001,
    "disable_corrector": [],
    "dynamic_thresholding_ratio": 0.995,
    "final_sigmas_type": "zero",
    "flow_shift": 3.0,
    "lower_order_final": True,
    "num_train_timesteps": 1000,
    "predict_x0": True,
    "prediction_type": "flow_prediction",
    "rescale_betas_zero_snr": False,
    "sample_max_value": 1.0,
    "solver_order": 2,
    "solver_p": None,
    "solver_type": "bh2",
    "steps_offset": 0,
    "thresholding": False,
    "time_shift_type": "exponential",
    "timestep_spacing": "linspace",
    "trained_betas": None,
    "use_beta_sigmas": False,
    "use_dynamic_shifting": False,
    "use_exponential_sigmas": False,
    "use_flow_sigmas": True,
    "use_karras_sigmas": False,
}

class DualWanTransformer3DModel(torch.nn.Module):
    def __init__(
        self,
        transformer_1: WanTransformer3DModel,
        transformer_2: WanTransformer3DModel,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        device: Optional[Union[str, torch.device]] = None,
        boundary_ratio: float = boundary_ratio_t2v,
        low_vram: bool = False,
    ) -> None:
        super().__init__()
        self.transformer_1: WanTransformer3DModel = transformer_1
        self.transformer_2: WanTransformer3DModel = transformer_2
        self.torch_dtype: torch.dtype = torch_dtype
        self.device_torch: torch.device = device
        self.boundary_ratio: float = boundary_ratio
        self.boundary: float = self.boundary_ratio * 1000
        self.low_vram: bool = low_vram
        self._active_transformer_name = "transformer_1"  # default to transformer_1

    @property
    def device(self) -> torch.device:
        return self.device_torch

    @property
    def dtype(self) -> torch.dtype:
        return self.torch_dtype

    @property
    def config(self):
        return self.transformer_1.config

    @property
    def transformer(self) -> WanTransformer3DModel:
        return getattr(self, self._active_transformer_name)

    def enable_gradient_checkpointing(self):
        """
        Enable gradient checkpointing for both transformers.
        """
        self.transformer_1.enable_gradient_checkpointing()
        self.transformer_2.enable_gradient_checkpointing()

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        # decide between the high-noise and low-noise expert by taking the
        # mean of the timestep batch. Timesteps are on a 0-1000 scale, so
        # the boundary (boundary_ratio * 1000) acts as the threshold.
        with torch.no_grad():
            if timestep.float().mean().item() >= self.boundary:
                t_name = "transformer_1"
            else:
                t_name = "transformer_2"

        # if the active transformer changes and low_vram is enabled, swap the
        # old one out to CPU and bring the new one into VRAM
        # todo swap the loras as well
        if t_name != self._active_transformer_name:
            if self.low_vram:
                getattr(self, self._active_transformer_name).to("cpu")
                getattr(self, t_name).to(self.device_torch)
                torch.cuda.empty_cache()
            self._active_transformer_name = t_name
        if self.transformer.device != hidden_states.device:
            raise ValueError(
                f"Transformer device {self.transformer.device} does not match hidden states device {hidden_states.device}"
            )

        return self.transformer(
            hidden_states=hidden_states,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_states_image=encoder_hidden_states_image,
            return_dict=return_dict,
            attention_kwargs=attention_kwargs,
        )

    def to(self, *args, **kwargs) -> Self:
        # intentionally a no-op: device placement is handled manually via the
        # per-expert swaps above, so blanket .to() calls are ignored
        return self

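For intuition, a minimal sketch of the mean-timestep routing rule above, assuming the default t2v boundary ratio of 0.875 on the 0-1000 timestep scale (the example values are illustrative):

import torch

boundary = 0.875 * 1000  # boundary_ratio_t2v scaled to the 0-1000 range

# a high-noise batch of timesteps sits above the boundary -> transformer_1
assert torch.tensor([950.0, 900.0]).float().mean().item() >= boundary

# a low-noise batch falls below it -> transformer_2
assert torch.tensor([300.0, 120.0]).float().mean().item() < boundary
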
class Wan2214bModel(Wan225bModel):
    arch = "wan22_14b"
    _wan_generation_scheduler_config = scheduler_configUniPC
    _wan_expand_timesteps = True
    _wan_vae_path = "ai-toolkit/wan2.1-vae"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device=device,
            model_config=model_config,
            dtype=dtype,
            custom_pipeline=custom_pipeline,
            noise_scheduler=noise_scheduler,
            **kwargs,
        )
        # target the wrapper module so LoRA attaches to both transformers
        self.target_lora_modules = ["DualWanTransformer3DModel"]
        self._wan_cache = None

    @property
    def max_step_saves_to_keep_multiplier(self):
        # the save-cleanup mechanism checks this to decide how many files to
        # keep. When training a LoRA we save two files per step (high noise
        # and low noise), so keep twice as many.
        if self.network is not None:
            return 2
        return 1

    def load_model(self):
        # load the model via the grandparent (Wan21); Wan225bModel's
        # load_model is intentionally skipped
        super(Wan225bModel, self).load_model()

        # split the dual model across the pipeline's two transformer slots
        self.pipeline.transformer = self.model.transformer_1
        self.pipeline.transformer_2 = self.model.transformer_2

        # patch the condition embedder on both transformers
        self.model.transformer_1.condition_embedder.forward = partial(
            time_text_monkeypatch, self.model.transformer_1.condition_embedder
        )
        self.model.transformer_2.condition_embedder.forward = partial(
            time_text_monkeypatch, self.model.transformer_2.condition_embedder
        )

    def get_bucket_divisibility(self):
        # 16x spatial compression from the VAE and a 2x2 patch size
        return 32

    def load_wan_transformer(self, transformer_path, subfolder=None):
        if self.model_config.split_model_over_gpus:
            raise ValueError(
                "Splitting model over gpus is not supported for Wan2.2 models"
            )

        if (
            self.model_config.assistant_lora_path is not None
            or self.model_config.inference_lora_path is not None
        ):
            raise ValueError(
                "Assistant LoRA is not supported for Wan2.2 models currently"
            )

        if self.model_config.lora_path is not None:
            raise ValueError(
                "Loading LoRA is not supported for Wan2.2 models currently"
            )

        # transformer_path can be a local directory ending in /transformer
        # or a Hugging Face repo path with a subfolder
        transformer_path_1 = transformer_path
        subfolder_1 = subfolder

        transformer_path_2 = transformer_path
        subfolder_2 = subfolder

        if subfolder_2 is None:
            # local path: point at the sibling transformer_2 folder
            transformer_path_2 = os.path.join(
                os.path.dirname(transformer_path_1), "transformer_2"
            )
        else:
            # HF path: use the transformer_2 subfolder
            subfolder_2 = "transformer_2"

        self.print_and_status_update("Loading transformer 1")
        dtype = self.torch_dtype
        transformer_1 = WanTransformer3DModel.from_pretrained(
            transformer_path_1,
            subfolder=subfolder_1,
            torch_dtype=dtype,
        ).to(dtype=dtype)

        flush()

        if not self.model_config.low_vram:
            # quantize on the device
            transformer_1.to(self.quantize_device, dtype=dtype)
            flush()

        if self.model_config.quantize:
            # todo handle two ARAs
            self.print_and_status_update("Quantizing Transformer 1")
            quantize_model(self, transformer_1)
            flush()

        if self.model_config.low_vram:
            self.print_and_status_update("Moving transformer 1 to CPU")
            transformer_1.to("cpu")

        self.print_and_status_update("Loading transformer 2")
        transformer_2 = WanTransformer3DModel.from_pretrained(
            transformer_path_2,
            subfolder=subfolder_2,
            torch_dtype=dtype,
        ).to(dtype=dtype)

        flush()

        if not self.model_config.low_vram:
            # quantize on the device
            transformer_2.to(self.quantize_device, dtype=dtype)
            flush()

        if self.model_config.quantize:
            # todo handle two ARAs
            self.print_and_status_update("Quantizing Transformer 2")
            quantize_model(self, transformer_2)
            flush()

        if self.model_config.low_vram:
            self.print_and_status_update("Moving transformer 2 to CPU")
            transformer_2.to("cpu")

        # wrap both experts in the combined module
        self.print_and_status_update("Creating DualWanTransformer3DModel")
        transformer = DualWanTransformer3DModel(
            transformer_1=transformer_1,
            transformer_2=transformer_2,
            torch_dtype=self.torch_dtype,
            device=self.device_torch,
            boundary_ratio=boundary_ratio_t2v,
            low_vram=self.model_config.low_vram,
        )

        return transformer

    def get_generation_pipeline(self):
        scheduler = UniPCMultistepScheduler(**self._wan_generation_scheduler_config)
        pipeline = Wan22Pipeline(
            vae=self.vae,
            transformer=self.model.transformer_1,
            transformer_2=self.model.transformer_2,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            scheduler=scheduler,
            expand_timesteps=self._wan_expand_timesteps,
            device=self.device_torch,
            aggressive_offload=self.model_config.low_vram,
            # todo detect if it is i2v or t2v
            boundary_ratio=boundary_ratio_t2v,
        )

        # pipeline = pipeline.to(self.device_torch)

        return pipeline

    @staticmethod
    def get_train_scheduler():
        # training uses the flow-match scheduler, not the UniPC sampler
        scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
        return scheduler

    def get_base_model_version(self):
        return "wan_2.2_14b"

    def generate_single_image(
        self,
        pipeline: AggressiveWanUnloadPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        return super().generate_single_image(
            pipeline=pipeline,
            gen_config=gen_config,
            conditional_embeds=conditional_embeds,
            unconditional_embeds=unconditional_embeds,
            generator=generator,
            extra=extra,
        )

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        batch: DataLoaderBatchDTO,
        **kwargs,
    ):
        # todo do we need to override this? Adjust timesteps?
        return super().get_noise_prediction(
            latent_model_input=latent_model_input,
            timestep=timestep,
            text_embeddings=text_embeddings,
            batch=batch,
            **kwargs,
        )

    def get_model_has_grad(self):
        return False

    def get_te_has_grad(self):
        return False

    def save_model(self, output_path, meta, save_dtype):
        transformer_combo: DualWanTransformer3DModel = unwrap_model(self.model)
        transformer_combo.transformer_1.save_pretrained(
            save_directory=os.path.join(output_path, "transformer"),
            safe_serialization=True,
        )
        transformer_combo.transformer_2.save_pretrained(
            save_directory=os.path.join(output_path, "transformer_2"),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, "aitk_meta.yaml")
        with open(meta_path, "w") as f:
            yaml.dump(meta, f)

    def save_lora(
        self,
        state_dict: Dict[str, torch.Tensor],
        output_path: str,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        if not self.network.network_config.split_multistage_loras:
            # just save as a combo lora
            save_file(state_dict, output_path, metadata=metadata)
            return

        # build separate dictionaries for the high-noise and low-noise LoRAs
        high_noise_lora = {}
        low_noise_lora = {}

        for key in state_dict:
            if ".transformer_1." in key:
                # high noise LoRA weight
                new_key = key.replace(".transformer_1.", ".")
                high_noise_lora[new_key] = state_dict[key]
            elif ".transformer_2." in key:
                # low noise LoRA weight
                new_key = key.replace(".transformer_2.", ".")
                low_noise_lora[new_key] = state_dict[key]

        # output names are either LORA_MODEL_NAME_000005000.safetensors or
        # LORA_MODEL_NAME.safetensors; suffix each split accordingly
        if len(high_noise_lora.keys()) > 0:
            high_noise_lora_path = output_path.replace(
                ".safetensors", "_high_noise.safetensors"
            )
            save_file(high_noise_lora, high_noise_lora_path, metadata=metadata)

        if len(low_noise_lora.keys()) > 0:
            low_noise_lora_path = output_path.replace(
                ".safetensors", "_low_noise.safetensors"
            )
            save_file(low_noise_lora, low_noise_lora_path, metadata=metadata)

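To make the key remapping concrete, a small sketch of the split and its inverse in load_lora below; the block path in the keys is a hypothetical example:

import torch

state_dict = {
    "diffusion_model.transformer_1.blocks.0.lora_A.weight": torch.zeros(4, 8),
    "diffusion_model.transformer_2.blocks.0.lora_A.weight": torch.zeros(4, 8),
}
# save side: strip the expert prefix before writing each split file
high = {k.replace(".transformer_1.", "."): v
        for k, v in state_dict.items() if ".transformer_1." in k}
assert "diffusion_model.blocks.0.lora_A.weight" in high
# load side: re-insert the prefix to rebuild the combined dict
restored = {k.replace("diffusion_model.", "diffusion_model.transformer_1."): v
            for k, v in high.items()}
assert set(restored) == {"diffusion_model.transformer_1.blocks.0.lora_A.weight"}
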
    def load_lora(self, file: str):
        # if the filename has neither suffix, it is a combo LoRA; load as-is
        if "_high_noise.safetensors" not in file and "_low_noise.safetensors" not in file:
            sd = load_file(file)
            return sd

        # we may have been passed either the high-noise or the low-noise
        # path, but we need to load both halves of the pair
        high_noise_lora_path = file.replace(
            "_low_noise.safetensors", "_high_noise.safetensors"
        )
        low_noise_lora_path = file.replace(
            "_high_noise.safetensors", "_low_noise.safetensors"
        )

        combined_dict = {}

        if os.path.exists(high_noise_lora_path):
            # re-prefix high noise keys onto transformer_1
            high_noise_lora = load_file(high_noise_lora_path)
            for key in high_noise_lora:
                new_key = key.replace(
                    "diffusion_model.", "diffusion_model.transformer_1."
                )
                combined_dict[new_key] = high_noise_lora[key]
        if os.path.exists(low_noise_lora_path):
            # re-prefix low noise keys onto transformer_2
            low_noise_lora = load_file(low_noise_lora_path)
            for key in low_noise_lora:
                new_key = key.replace(
                    "diffusion_model.", "diffusion_model.transformer_2."
                )
                combined_dict[new_key] = low_noise_lora[key]

        return combined_dict
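The split files travel as a matched pair; a small sketch of the suffix swap used above (the filename is a hypothetical example):

file = "my_lora_low_noise.safetensors"  # hypothetical path passed to load_lora
high_path = file.replace("_low_noise.safetensors", "_high_noise.safetensors")
low_path = file.replace("_high_noise.safetensors", "_low_noise.safetensors")
# whichever variant was passed in, both paths resolve to the matching pair
assert high_path == "my_lora_high_noise.safetensors" and low_path == file
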
@@ -12,7 +12,6 @@ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
 from typing import Any, Callable, Dict, List, Optional, Union
-
 
 
 class Wan22Pipeline(WanPipeline):
     def __init__(
         self,
@@ -52,6 +51,7 @@ class Wan22Pipeline(WanPipeline):
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator,
                                   List[torch.Generator]]] = None,
@@ -77,11 +77,15 @@ class Wan22Pipeline(WanPipeline):
         vae_device = self.vae.device
         transformer_device = self.transformer.device
         text_encoder_device = self.text_encoder.device
-        device = self.transformer.device
+        device = self._exec_device
 
         if self._aggressive_offload:
             print("Unloading vae")
             self.vae.to("cpu")
             print("Unloading transformer")
             self.transformer.to("cpu")
+            if self.transformer_2 is not None:
+                self.transformer_2.to("cpu")
             self.text_encoder.to(device)
             flush()
 
@@ -95,9 +99,14 @@ class Wan22Pipeline(WanPipeline):
             prompt_embeds,
             negative_prompt_embeds,
             callback_on_step_end_tensor_inputs,
+            guidance_scale_2
         )
 
+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
@@ -160,6 +169,13 @@ class Wan22Pipeline(WanPipeline):
         num_warmup_steps = len(timesteps) - \
             num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
+        current_model = self.transformer
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -167,6 +183,25 @@ class Wan22Pipeline(WanPipeline):
                     continue
 
                 self._current_timestep = t
+
+                if boundary_timestep is None or t >= boundary_timestep:
+                    if self._aggressive_offload and current_model != self.transformer:
+                        if self.transformer_2 is not None:
+                            self.transformer_2.to("cpu")
+                        self.transformer.to(device)
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    if self._aggressive_offload and current_model != self.transformer_2:
+                        if self.transformer is not None:
+                            self.transformer.to("cpu")
+                        if self.transformer_2 is not None:
+                            self.transformer_2.to(device)
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
                 latent_model_input = latents.to(device, transformer_dtype)
                 if self.config.expand_timesteps:
                     # seq_len: num_latent_frames * latent_height//2 * latent_width//2
@@ -176,7 +211,7 @@ class Wan22Pipeline(WanPipeline):
                 else:
                     timestep = t.expand(latents.shape[0])
 
-                noise_pred = self.transformer(
+                noise_pred = current_model(
                     hidden_states=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
@@ -185,14 +220,14 @@ class Wan22Pipeline(WanPipeline):
                 )[0]
 
                 if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                    noise_uncond = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
                         encoder_hidden_states=negative_prompt_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * \
+                    noise_pred = noise_uncond + current_guidance_scale * \
                         (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
@@ -259,10 +294,8 @@ class Wan22Pipeline(WanPipeline):
 
         # move transformer back to device
         if self._aggressive_offload:
-            print("Moving transformer back to device")
-            self.transformer.to(self._execution_device)
-            if self.transformer_2 is not None:
-                self.transformer_2.to(self._execution_device)
+            # print("Moving transformer back to device")
+            # self.transformer.to(self._execution_device)
             flush()
 
         if not return_dict:
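A quick sketch of the guidance fallback added above: when a boundary ratio is configured but no second scale is passed, the low-noise stage reuses the first scale (the numbers are illustrative):

guidance_scale, guidance_scale_2 = 5.0, None
boundary_ratio = 0.875  # stands in for self.config.boundary_ratio
if boundary_ratio is not None and guidance_scale_2 is None:
    guidance_scale_2 = guidance_scale
assert guidance_scale_2 == 5.0
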
@@ -437,19 +437,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # Combine and sort the lists
         combined_items = safetensors_files + directories + pt_files
         combined_items.sort(key=os.path.getctime)
 
+        num_saves_to_keep = self.save_config.max_step_saves_to_keep
+
+        if hasattr(self.sd, 'max_step_saves_to_keep_multiplier'):
+            num_saves_to_keep *= self.sd.max_step_saves_to_keep_multiplier
+
         # Use slicing with a check to avoid 'NoneType' error
         safetensors_to_remove = safetensors_files[
-            :-self.save_config.max_step_saves_to_keep] if safetensors_files else []
-        pt_files_to_remove = pt_files[:-self.save_config.max_step_saves_to_keep] if pt_files else []
-        directories_to_remove = directories[:-self.save_config.max_step_saves_to_keep] if directories else []
-        embeddings_to_remove = embed_files[:-self.save_config.max_step_saves_to_keep] if embed_files else []
-        critic_to_remove = critic_items[:-self.save_config.max_step_saves_to_keep] if critic_items else []
+            :-num_saves_to_keep] if safetensors_files else []
+        pt_files_to_remove = pt_files[:-num_saves_to_keep] if pt_files else []
+        directories_to_remove = directories[:-num_saves_to_keep] if directories else []
+        embeddings_to_remove = embed_files[:-num_saves_to_keep] if embed_files else []
+        critic_to_remove = critic_items[:-num_saves_to_keep] if critic_items else []
 
         items_to_remove = safetensors_to_remove + pt_files_to_remove + directories_to_remove + embeddings_to_remove + critic_to_remove
 
         # remove all but the latest max_step_saves_to_keep
-        # items_to_remove = combined_items[:-self.save_config.max_step_saves_to_keep]
+        # items_to_remove = combined_items[:-num_saves_to_keep]
 
         # remove duplicates
         items_to_remove = list(dict.fromkeys(items_to_remove))
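A minimal sketch of the keep-latest logic above, assuming max_step_saves_to_keep = 2 and the LoRA multiplier of 2 from Wan2214bModel (the filenames are illustrative):

files = [f"lora_{i:09d}.safetensors" for i in (1000, 2000, 3000, 4000, 5000)]  # oldest first
num_saves_to_keep = 2 * 2  # base setting of 2 times the LoRA multiplier of 2
to_remove = files[:-num_saves_to_keep] if files else []
assert to_remove == ["lora_000001000.safetensors"]
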
@@ -185,6 +185,9 @@ class NetworkConfig:
         self.conv_alpha = 9999999999
         # -1 automatically finds the largest factor
         self.lokr_factor = kwargs.get('lokr_factor', -1)
 
+        # for multi-stage models
+        self.split_multistage_loras = kwargs.get('split_multistage_loras', True)
+
 
 AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'control_net', 'control_lora', 'i2v']
@@ -310,6 +310,7 @@ class Wan21(BaseModel):
     arch = 'wan21'
     _wan_generation_scheduler_config = scheduler_configUniPC
     _wan_expand_timesteps = False
+    _wan_vae_path = None
 
     _comfy_te_file = ['text_encoders/umt5_xxl_fp16.safetensors', 'text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors']
     def __init__(
@@ -431,8 +432,14 @@ class Wan21(BaseModel):
         scheduler = Wan21.get_train_scheduler()
         self.print_and_status_update("Loading VAE")
         # todo, example does float 32? check if quality suffers
-        vae = AutoencoderKLWan.from_pretrained(
-            vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype)
+
+        if self._wan_vae_path is not None:
+            # load the vae from its individual repo
+            vae = AutoencoderKLWan.from_pretrained(
+                self._wan_vae_path, torch_dtype=dtype).to(dtype=dtype)
+        else:
+            vae = AutoencoderKLWan.from_pretrained(
+                vae_path, subfolder="vae", torch_dtype=dtype).to(dtype=dtype)
         flush()
 
         self.print_and_status_update("Making pipe")
@@ -565,6 +565,13 @@ class ToolkitNetworkMixin:
         if metadata is None:
             metadata = OrderedDict()
         metadata = add_model_hash_to_meta(save_dict, metadata)
+
+        # let the model handle the saving if it implements save_lora
+        if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'save_lora'):
+            # call the base model save lora method
+            self.base_model_ref().save_lora(save_dict, file, metadata)
+            return
+
         if os.path.splitext(file)[1] == ".safetensors":
             from safetensors.torch import save_file
             save_file(save_dict, file, metadata)
@@ -577,12 +584,15 @@ class ToolkitNetworkMixin:
         keymap = {} if keymap is None else keymap
 
         if isinstance(file, str):
-            if os.path.splitext(file)[1] == ".safetensors":
-                from safetensors.torch import load_file
-
-                weights_sd = load_file(file)
+            if self.base_model_ref is not None and hasattr(self.base_model_ref(), 'load_lora'):
+                # call the base model load lora method
+                weights_sd = self.base_model_ref().load_lora(file)
             else:
-                weights_sd = torch.load(file, map_location="cpu")
+                if os.path.splitext(file)[1] == ".safetensors":
+                    from safetensors.torch import load_file
+                    weights_sd = load_file(file)
+                else:
+                    weights_sd = torch.load(file, map_location="cpu")
         else:
             # probably a state dict
             weights_sd = file
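base_model_ref is called to dereference before use, which suggests it holds a weak reference back to the model; a minimal sketch of that delegation pattern (the weakref detail and the class name are assumptions):

import weakref

class _Model:  # hypothetical stand-in for a model that implements save_lora
    def save_lora(self, state_dict, path, metadata=None):
        print(f"model-specific save to {path}")

model = _Model()
base_model_ref = weakref.ref(model)
if base_model_ref is not None and hasattr(base_model_ref(), "save_lora"):
    base_model_ref().save_lora({}, "lora.safetensors")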